Engel Nyst 1 год назад
Родитель
Сommit
9d0e6a24bc

+ 2 - 2
.github/workflows/py-unit-tests.yml

@@ -93,7 +93,7 @@ jobs:
         id: buildx
         uses: docker/setup-buildx-action@v3
       - name: Run Tests
-        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit
+        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml ./tests/unit --ignore=tests/unit/test_memory.py
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4
         env:
@@ -125,7 +125,7 @@ jobs:
       - name: Build Environment
         run: make build
       - name: Run Tests
-        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit
+        run: poetry run pytest --forked --cov=agenthub --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4
         env:

+ 1 - 1
config.template.toml

@@ -185,7 +185,7 @@ model = "gpt-4o-mini"
 #memory_enabled = false
 
 # Memory maximum threads
-#memory_max_threads = 2
+#memory_max_threads = 3
 
 # LLM config group to use
 #llm_config = 'your-llm-config-group'

+ 1 - 1
openhands/core/config/agent_config.py

@@ -16,7 +16,7 @@ class AgentConfig:
 
     micro_agent_name: str | None = None
     memory_enabled: bool = False
-    memory_max_threads: int = 2
+    memory_max_threads: int = 3
     llm_config: str | None = None
 
     def defaults_to_dict(self) -> dict:

+ 1 - 1
openhands/events/serialization/event.py

@@ -96,7 +96,7 @@ def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
 
 def truncate_content(content: str, max_chars: int) -> str:
     """Truncate the middle of the observation content if it is too long."""
-    if len(content) <= max_chars:
+    if len(content) <= max_chars or max_chars == -1:
         return content
 
     # truncate the middle and include a message to the LLM about it

+ 143 - 145
openhands/memory/memory.py

@@ -1,189 +1,187 @@
-import threading
-
-from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
+import json
 
-from openhands.core.config import LLMConfig
+from openhands.core.config import AgentConfig, LLMConfig
 from openhands.core.logger import openhands_logger as logger
-from openhands.core.utils import json
-from openhands.utils.tenacity_stop import stop_if_should_exit
-
-try:
-    import chromadb
-    import llama_index.embeddings.openai.base as llama_openai
-    from llama_index.core import Document, VectorStoreIndex
-    from llama_index.core.retrievers import VectorIndexRetriever
-    from llama_index.vector_stores.chroma import ChromaVectorStore
-
-    LLAMA_INDEX_AVAILABLE = True
-except ImportError:
-    LLAMA_INDEX_AVAILABLE = False
+from openhands.events.event import Event
+from openhands.events.serialization.event import event_to_memory
+from openhands.events.stream import EventStream
+from openhands.utils.embeddings import (
+    LLAMA_INDEX_AVAILABLE,
+    EmbeddingsLoader,
+    check_llama_index,
+)
 
+# Conditional imports based on llama_index availability
 if LLAMA_INDEX_AVAILABLE:
-    # TODO: this could be made configurable
-    num_retries: int = 10
-    retry_min_wait: int = 3
-    retry_max_wait: int = 300
-
-    # llama-index includes a retry decorator around openai.get_embeddings() function
-    # it is initialized with hard-coded values and errors
-    # this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
-    # this block attempts to banish it and replace it with our decorator, to allow users to set their own limits
-
-    if hasattr(llama_openai.get_embeddings, '__wrapped__'):
-        original_get_embeddings = llama_openai.get_embeddings.__wrapped__
-    else:
-        logger.warning('Cannot set custom retry limits.')
-        num_retries = 1
-        original_get_embeddings = llama_openai.get_embeddings
-
-    def attempt_on_error(retry_state):
-        logger.error(
-            f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
-            exc_info=False,
-        )
-        return None
-
-    @retry(
-        reraise=True,
-        stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
-        wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
-        retry=retry_if_exception_type(
-            (RateLimitError, APIConnectionError, InternalServerError)
-        ),
-        after=attempt_on_error,
+    import chromadb
+    from llama_index.core import Document
+    from llama_index.core.indices.vector_store.base import VectorStoreIndex
+    from llama_index.core.indices.vector_store.retrievers.retriever import (
+        VectorIndexRetriever,
     )
-    def wrapper_get_embeddings(*args, **kwargs):
-        return original_get_embeddings(*args, **kwargs)
-
-    llama_openai.get_embeddings = wrapper_get_embeddings
-
-    class EmbeddingsLoader:
-        """Loader for embedding model initialization."""
-
-        @staticmethod
-        def get_embedding_model(strategy: str, llm_config: LLMConfig):
-            supported_ollama_embed_models = [
-                'llama2',
-                'mxbai-embed-large',
-                'nomic-embed-text',
-                'all-minilm',
-                'stable-code',
-                'bge-m3',
-                'bge-large',
-                'paraphrase-multilingual',
-                'snowflake-arctic-embed',
-            ]
-            if strategy in supported_ollama_embed_models:
-                from llama_index.embeddings.ollama import OllamaEmbedding
-
-                return OllamaEmbedding(
-                    model_name=strategy,
-                    base_url=llm_config.embedding_base_url,
-                    ollama_additional_kwargs={'mirostat': 0},
-                )
-            elif strategy == 'openai':
-                from llama_index.embeddings.openai import OpenAIEmbedding
-
-                return OpenAIEmbedding(
-                    model='text-embedding-ada-002',
-                    api_key=llm_config.api_key,
-                )
-            elif strategy == 'azureopenai':
-                from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
-
-                return AzureOpenAIEmbedding(
-                    model='text-embedding-ada-002',
-                    deployment_name=llm_config.embedding_deployment_name,
-                    api_key=llm_config.api_key,
-                    azure_endpoint=llm_config.base_url,
-                    api_version=llm_config.api_version,
-                )
-            elif (strategy is not None) and (strategy.lower() == 'none'):
-                # TODO: this works but is not elegant enough. The incentive is when
-                # an agent using embeddings is not used, there is no reason we need to
-                # initialize an embedding model
-                return None
-            else:
-                from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-
-                return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
+    from llama_index.core.schema import TextNode
+    from llama_index.vector_stores.chroma import ChromaVectorStore
 
 
 class LongTermMemory:
     """Handles storing information for the agent to access later, using chromadb."""
 
-    def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
+    event_stream: EventStream
+
+    def __init__(
+        self,
+        llm_config: LLMConfig,
+        agent_config: AgentConfig,
+        event_stream: EventStream,
+    ):
         """Initialize the chromadb and set up ChromaVectorStore for later use."""
-        if not LLAMA_INDEX_AVAILABLE:
-            raise ImportError(
-                'llama_index and its dependencies are not installed. '
-                'To use LongTermMemory, please run: poetry install --with llama-index'
-            )
 
-        db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
+        check_llama_index()
+
+        # initialize the chromadb client
+        db = chromadb.PersistentClient(
+            path=f'./cache/sessions/{event_stream.sid}/memory',
+            # FIXME anonymized_telemetry=False,
+        )
         self.collection = db.get_or_create_collection(name='memories')
         vector_store = ChromaVectorStore(chroma_collection=self.collection)
+
+        # embedding model
         embedding_strategy = llm_config.embedding_model
-        embed_model = EmbeddingsLoader.get_embedding_model(
+        self.embed_model = EmbeddingsLoader.get_embedding_model(
             embedding_strategy, llm_config
         )
-        self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
-        self.sema = threading.Semaphore(value=memory_max_threads)
+
+        # instantiate the index
+        self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
         self.thought_idx = 0
-        self._add_threads: list[threading.Thread] = []
 
-    def add_event(self, event: dict):
+        # initialize the event stream
+        self.event_stream = event_stream
+
+        # max of threads to run the pipeline
+        self.memory_max_threads = agent_config.memory_max_threads
+
+    def add_event(self, event: Event):
         """Adds a new event to the long term memory with a unique id.
 
         Parameters:
-        - event (dict): The new event to be added to memory
+        - event: The new event to be added to memory
         """
-        id = ''
-        t = ''
-        if 'action' in event:
-            t = 'action'
-            id = event['action']
-        elif 'observation' in event:
-            t = 'observation'
-            id = event['observation']
+        try:
+            # convert the event to a memory-friendly format, and don't truncate
+            event_data = event_to_memory(event, -1)
+        except (json.JSONDecodeError, KeyError, ValueError) as e:
+            logger.warning(f'Failed to process event: {e}')
+            return
+
+        # determine the event type and ID
+        event_type = ''
+        event_id = ''
+        if 'action' in event_data:
+            event_type = 'action'
+            event_id = event_data['action']
+        elif 'observation' in event_data:
+            event_type = 'observation'
+            event_id = event_data['observation']
+
+        # create a Document instance for the event
         doc = Document(
-            text=json.dumps(event),
+            text=json.dumps(event_data),
             doc_id=str(self.thought_idx),
             extra_info={
-                'type': t,
-                'id': id,
+                'type': event_type,
+                'id': event_id,
                 'idx': self.thought_idx,
             },
         )
         self.thought_idx += 1
-        logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
-        thread = threading.Thread(target=self._add_doc, args=(doc,))
-        self._add_threads.append(thread)
-        thread.start()  # We add the doc concurrently so we don't have to wait ~500ms for the insert
-
-    def _add_doc(self, doc):
-        with self.sema:
-            self.index.insert(doc)
+        logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
+        self._add_document(document=doc)
+
+    def _add_document(self, document: 'Document'):
+        """Inserts a single document into the index."""
+        self.index.insert_nodes([self._create_node(document)])
+
+    def _create_node(self, document: 'Document') -> 'TextNode':
+        """Create a TextNode from a Document instance."""
+        return TextNode(
+            text=document.text,
+            doc_id=document.doc_id,
+            extra_info=document.extra_info,
+        )
 
-    def search(self, query: str, k: int = 10):
-        """Searches through the current memory using VectorIndexRetriever
+    def search(self, query: str, k: int = 10) -> list[str]:
+        """Searches through the current memory using VectorIndexRetriever.
 
         Parameters:
         - query (str): A query to match search results to
         - k (int): Number of top results to return
 
         Returns:
-        - list[str]: list of top k results found in current memory
+        - list[str]: List of top k results found in current memory
         """
         retriever = VectorIndexRetriever(
             index=self.index,
             similarity_top_k=k,
         )
         results = retriever.retrieve(query)
+
+        for result in results:
+            logger.debug(
+                f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
+            )
+
         return [r.get_text() for r in results]
+
+    def _events_to_docs(self) -> list['Document']:
+        """Convert all events from the EventStream to documents for batch insert into the index."""
+        try:
+            events = self.event_stream.get_events()
+        except Exception as e:
+            logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
+            return []
+
+        documents: list[Document] = []
+
+        for event in events:
+            try:
+                # convert the event to a memory-friendly format, and don't truncate
+                event_data = event_to_memory(event, -1)
+
+                # determine the event type and ID
+                event_type = ''
+                event_id = ''
+                if 'action' in event_data:
+                    event_type = 'action'
+                    event_id = event_data['action']
+                elif 'observation' in event_data:
+                    event_type = 'observation'
+                    event_id = event_data['observation']
+
+                # create a Document instance for the event
+                doc = Document(
+                    text=json.dumps(event_data),
+                    doc_id=str(self.thought_idx),
+                    extra_info={
+                        'type': event_type,
+                        'id': event_id,
+                        'idx': self.thought_idx,
+                    },
+                )
+                documents.append(doc)
+                self.thought_idx += 1
+            except (json.JSONDecodeError, KeyError, ValueError) as e:
+                logger.warning(f'Failed to process event: {e}')
+                continue
+
+        if documents:
+            logger.debug(f'Batch inserting {len(documents)} documents into the index.')
+        else:
+            logger.debug('No valid documents found to insert into the index.')
+
+        return documents
+
+    def create_nodes(self, documents: list['Document']) -> list['TextNode']:
+        """Create nodes from a list of documents."""
+        return [self._create_node(doc) for doc in documents]

+ 176 - 0
openhands/utils/embeddings.py

@@ -0,0 +1,176 @@
+import importlib.util
+import os
+
+from joblib import Parallel, delayed
+
+from openhands.core.config import LLMConfig
+
+try:
+    # check if those we need later are available using importlib
+    if importlib.util.find_spec('chromadb') is None:
+        raise ImportError(
+            'chromadb is not available. Please install it using poetry install --with llama-index'
+        )
+
+    if (
+        importlib.util.find_spec(
+            'llama_index.core.indices.vector_store.retrievers.retriever'
+        )
+        is None
+        or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
+        is None
+    ):
+        raise ImportError(
+            'llama_index is not available. Please install it using poetry install --with llama-index'
+        )
+
+    from llama_index.core import Document, VectorStoreIndex
+    from llama_index.core.base.embeddings.base import BaseEmbedding
+    from llama_index.core.ingestion import IngestionPipeline
+    from llama_index.core.schema import TextNode
+
+    LLAMA_INDEX_AVAILABLE = True
+
+except ImportError:
+    LLAMA_INDEX_AVAILABLE = False
+
+# Define supported embedding models
+SUPPORTED_OLLAMA_EMBED_MODELS = [
+    'llama2',
+    'mxbai-embed-large',
+    'nomic-embed-text',
+    'all-minilm',
+    'stable-code',
+    'bge-m3',
+    'bge-large',
+    'paraphrase-multilingual',
+    'snowflake-arctic-embed',
+]
+
+
+def check_llama_index():
+    """Utility function to check the availability of llama_index.
+
+    Raises:
+        ImportError: If llama_index is not available.
+    """
+    if not LLAMA_INDEX_AVAILABLE:
+        raise ImportError(
+            'llama_index and its dependencies are not installed. '
+            'To use memory features, please run: poetry install --with llama-index.'
+        )
+
+
+class EmbeddingsLoader:
+    """Loader for embedding model initialization."""
+
+    @staticmethod
+    def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
+        """Initialize and return the appropriate embedding model based on the strategy.
+
+        Parameters:
+        - strategy: The embedding strategy to use.
+        - llm_config: Configuration for the LLM.
+
+        Returns:
+        - An instance of the selected embedding model or None.
+        """
+
+        if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
+            from llama_index.embeddings.ollama import OllamaEmbedding
+
+            return OllamaEmbedding(
+                model_name=strategy,
+                base_url=llm_config.embedding_base_url,
+                ollama_additional_kwargs={'mirostat': 0},
+            )
+        elif strategy == 'openai':
+            from llama_index.embeddings.openai import OpenAIEmbedding
+
+            return OpenAIEmbedding(
+                model='text-embedding-ada-002',
+                api_key=llm_config.api_key,
+            )
+        elif strategy == 'azureopenai':
+            from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+
+            return AzureOpenAIEmbedding(
+                model='text-embedding-ada-002',
+                deployment_name=llm_config.embedding_deployment_name,
+                api_key=llm_config.api_key,
+                azure_endpoint=llm_config.base_url,
+                api_version=llm_config.api_version,
+            )
+        elif (strategy is not None) and (strategy.lower() == 'none'):
+            # TODO: this works but is not elegant enough. The incentive is when
+            # an agent using embeddings is not used, there is no reason we need to
+            # initialize an embedding model
+            return None
+        else:
+            from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+            # initialize the local embedding model
+            local_embed_model = HuggingFaceEmbedding(
+                model_name='BAAI/bge-small-en-v1.5'
+            )
+
+            # for local embeddings, we need torch
+            import torch
+
+            # choose the best device
+            # first determine what is available: CUDA, MPS, or CPU
+            if torch.cuda.is_available():
+                device = 'cuda'
+            elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+                device = 'mps'
+            else:
+                device = 'cpu'
+                os.environ['CUDA_VISIBLE_DEVICES'] = ''
+                os.environ['PYTORCH_FORCE_CPU'] = (
+                    '1'  # try to force CPU to avoid errors
+                )
+
+                # override CUDA availability
+                torch.cuda.is_available = lambda: False
+
+            # disable MPS to avoid errors
+            if device != 'mps' and hasattr(torch.backends, 'mps'):
+                torch.backends.mps.is_available = lambda: False
+                torch.backends.mps.is_built = False
+
+            # the device being used
+            print(f'Using device for embeddings: {device}')
+
+            return local_embed_model
+
+
+# --------------------------------------------------------------------------
+# Utility functions to run pipelines, split out for profiling
+# --------------------------------------------------------------------------
+def run_pipeline(
+    embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
+) -> list['TextNode']:
+    """Run a pipeline embedding documents."""
+
+    # set up a pipeline with the transformations to make
+    pipeline = IngestionPipeline(
+        transformations=[
+            embed_model,
+        ],
+    )
+
+    # run the pipeline with num_workers
+    nodes = pipeline.run(
+        documents=documents, show_progress=True, num_workers=num_workers
+    )
+    return nodes
+
+
+def insert_batch_docs(
+    index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
+) -> list['TextNode']:
+    """Run the document indexing in parallel."""
+    results = Parallel(n_jobs=num_workers, backend='threading')(
+        delayed(index.insert)(doc) for doc in documents
+    )
+    return results

+ 246 - 0
tests/unit/test_memory.py

@@ -0,0 +1,246 @@
+import json
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from openhands.core.config import AgentConfig, LLMConfig
+from openhands.events.event import Event, EventSource
+from openhands.events.stream import EventStream
+from openhands.memory.memory import LongTermMemory
+from openhands.storage.files import FileStore
+
+
+@pytest.fixture
+def mock_llm_config() -> LLMConfig:
+    config = MagicMock(spec=LLMConfig)
+    config.embedding_model = 'test_embedding_model'
+    config.api_key = 'test_api_key'
+    config.api_version = 'v1'
+    return config
+
+
+@pytest.fixture
+def mock_agent_config() -> AgentConfig:
+    config = AgentConfig(
+        micro_agent_name='test_micro_agent',
+        memory_enabled=True,
+        memory_max_threads=4,
+        llm_config='test_llm_config',
+    )
+    return config
+
+
+@pytest.fixture
+def mock_file_store() -> FileStore:
+    store = MagicMock(spec=FileStore)
+    store.sid = 'test_session'
+    return store
+
+
+@pytest.fixture
+def mock_event_stream(mock_file_store: FileStore) -> EventStream:
+    with patch('openhands.events.stream.EventStream') as MockEventStream:
+        instance = MockEventStream.return_value
+        instance.sid = 'test_session'
+        instance.get_events = MagicMock()
+        return instance
+
+
+@pytest.fixture
+def long_term_memory(
+    mock_llm_config: LLMConfig,
+    mock_agent_config: AgentConfig,
+    mock_event_stream: EventStream,
+) -> LongTermMemory:
+    with patch(
+        'openhands.memory.memory.chromadb.PersistentClient'
+    ) as mock_chroma_client:
+        mock_collection = MagicMock()
+        mock_chroma_client.return_value.get_or_create_collection.return_value = (
+            mock_collection
+        )
+        memory = LongTermMemory(
+            llm_config=mock_llm_config,
+            agent_config=mock_agent_config,
+            event_stream=mock_event_stream,
+        )
+        memory.collection = mock_collection
+        return memory
+
+
+def _create_action_event(action: str) -> Event:
+    """Helper function to create an action event."""
+    event = Event()
+    event._id = -1
+    event._timestamp = datetime.now(timezone.utc).isoformat()
+    event._source = EventSource.AGENT
+    event.action = action
+    return event
+
+
+def _create_observation_event(observation: str) -> Event:
+    """Helper function to create an observation event."""
+    event = Event()
+    event._id = -1
+    event._timestamp = datetime.now(timezone.utc).isoformat()
+    event._source = EventSource.USER
+    event.observation = observation
+    return event
+
+
+def test_add_event_with_action(long_term_memory: LongTermMemory):
+    event = _create_action_event('test_action')
+    long_term_memory._add_document = MagicMock()
+    long_term_memory.add_event(event)
+    assert long_term_memory.thought_idx == 1
+    long_term_memory._add_document.assert_called_once()
+    _, kwargs = long_term_memory._add_document.call_args
+    assert kwargs['document'].extra_info['type'] == 'action'
+    assert kwargs['document'].extra_info['id'] == 'test_action'
+
+
+def test_add_event_with_observation(long_term_memory: LongTermMemory):
+    event = _create_observation_event('test_observation')
+    long_term_memory._add_document = MagicMock()
+    long_term_memory.add_event(event)
+    assert long_term_memory.thought_idx == 1
+    long_term_memory._add_document.assert_called_once()
+    _, kwargs = long_term_memory._add_document.call_args
+    assert kwargs['document'].extra_info['type'] == 'observation'
+    assert kwargs['document'].extra_info['id'] == 'test_observation'
+
+
+def test_add_event_with_missing_keys(long_term_memory: LongTermMemory):
+    # Creating an event with additional unexpected attributes
+    event = Event()
+    event._id = -1
+    event._timestamp = datetime.now(timezone.utc).isoformat()
+    event._source = EventSource.AGENT
+    event.action = 'test_action'
+    event.unexpected_key = 'value'
+
+    long_term_memory._add_document = MagicMock()
+    long_term_memory.add_event(event)
+    assert long_term_memory.thought_idx == 1
+    long_term_memory._add_document.assert_called_once()
+    _, kwargs = long_term_memory._add_document.call_args
+    assert kwargs['document'].extra_info['type'] == 'action'
+    assert kwargs['document'].extra_info['id'] == 'test_action'
+
+
+def test_events_to_docs_no_events(
+    long_term_memory: LongTermMemory, mock_event_stream: EventStream
+):
+    mock_event_stream.get_events.side_effect = FileNotFoundError
+
+    # convert events to documents
+    documents = long_term_memory._events_to_docs()
+
+    # since get_events raises, documents should be empty
+    assert len(documents) == 0
+
+    # thought_idx remains unchanged
+    assert long_term_memory.thought_idx == 0
+
+
+def test_load_events_into_index_with_invalid_json(
+    long_term_memory: LongTermMemory, mock_event_stream: EventStream
+):
+    """Test loading events with malformed event data."""
+    # Simulate an event that causes event_to_memory to raise a JSONDecodeError
+    with patch(
+        'openhands.memory.memory.event_to_memory',
+        side_effect=json.JSONDecodeError('Expecting value', '', 0),
+    ):
+        event = _create_action_event('invalid_action')
+        mock_event_stream.get_events.return_value = [event]
+
+        # convert events to documents
+        documents = long_term_memory._events_to_docs()
+
+        # since event_to_memory raises, documents should be empty
+        assert len(documents) == 0
+
+    # thought_idx remains unchanged
+    assert long_term_memory.thought_idx == 0
+
+
+def test_embeddings_inserted_into_chroma(long_term_memory: LongTermMemory):
+    event = _create_action_event('test_action')
+    long_term_memory._add_document = MagicMock()
+    long_term_memory.add_event(event)
+    long_term_memory._add_document.assert_called()
+    _, kwargs = long_term_memory._add_document.call_args
+    assert 'document' in kwargs
+    assert (
+        kwargs['document'].text
+        == '{"source": "agent", "action": "test_action", "args": {}}'
+    )
+
+
+def test_search_returns_correct_results(long_term_memory: LongTermMemory):
+    mock_retriever = MagicMock()
+    mock_retriever.retrieve.return_value = [
+        MagicMock(get_text=MagicMock(return_value='result1')),
+        MagicMock(get_text=MagicMock(return_value='result2')),
+    ]
+    with patch(
+        'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever
+    ):
+        results = long_term_memory.search(query='test query', k=2)
+        assert results == ['result1', 'result2']
+        mock_retriever.retrieve.assert_called_once_with('test query')
+
+
+def test_search_with_no_results(long_term_memory: LongTermMemory):
+    mock_retriever = MagicMock()
+    mock_retriever.retrieve.return_value = []
+    with patch(
+        'openhands.memory.memory.VectorIndexRetriever', return_value=mock_retriever
+    ):
+        results = long_term_memory.search(query='no results', k=5)
+        assert results == []
+        mock_retriever.retrieve.assert_called_once_with('no results')
+
+
+def test_add_event_increment_thought_idx(long_term_memory: LongTermMemory):
+    event1 = _create_action_event('action1')
+    event2 = _create_observation_event('observation1')
+    long_term_memory.add_event(event1)
+    long_term_memory.add_event(event2)
+    assert long_term_memory.thought_idx == 2
+
+
+def test_load_events_batch_insert(
+    long_term_memory: LongTermMemory, mock_event_stream: EventStream
+):
+    event1 = _create_action_event('action1')
+    event2 = _create_observation_event('observation1')
+    event3 = _create_action_event('action2')
+    mock_event_stream.get_events.return_value = [event1, event2, event3]
+
+    # Mock insert_batch_docs
+    with patch('openhands.utils.embeddings.insert_batch_docs') as mock_run_docs:
+        # convert events to documents
+        documents = long_term_memory._events_to_docs()
+
+        # Mock the insert_batch_docs to simulate document insertion
+        mock_run_docs.return_value = []
+
+        # Call insert_batch_docs with the documents
+        mock_run_docs(
+            index=long_term_memory.index,
+            documents=documents,
+            num_workers=long_term_memory.memory_max_threads,
+        )
+
+        # Assert that insert_batch_docs was called with the correct arguments
+        mock_run_docs.assert_called_once_with(
+            index=long_term_memory.index,
+            documents=documents,
+            num_workers=long_term_memory.memory_max_threads,
+        )
+
+    # Check if thought_idx was incremented correctly
+    assert long_term_memory.thought_idx == 3