# memory.py
  1. import json
  2. from openhands.core.config import AgentConfig, LLMConfig
  3. from openhands.core.logger import openhands_logger as logger
  4. from openhands.events.event import Event
  5. from openhands.events.serialization.event import event_to_memory
  6. from openhands.events.stream import EventStream
  7. from openhands.utils.embeddings import (
  8. LLAMA_INDEX_AVAILABLE,
  9. EmbeddingsLoader,
  10. check_llama_index,
  11. )
  12. # Conditional imports based on llama_index availability
  13. if LLAMA_INDEX_AVAILABLE:
  14. import chromadb
  15. from llama_index.core import Document
  16. from llama_index.core.indices.vector_store.base import VectorStoreIndex
  17. from llama_index.core.indices.vector_store.retrievers.retriever import (
  18. VectorIndexRetriever,
  19. )
  20. from llama_index.core.schema import TextNode
  21. from llama_index.vector_stores.chroma import ChromaVectorStore
  22. class LongTermMemory:
  23. """Handles storing information for the agent to access later, using chromadb."""
  24. event_stream: EventStream
  25. def __init__(
  26. self,
  27. llm_config: LLMConfig,
  28. agent_config: AgentConfig,
  29. event_stream: EventStream,
  30. ):
  31. """Initialize the chromadb and set up ChromaVectorStore for later use."""
  32. check_llama_index()
  33. # initialize the chromadb client
  34. db = chromadb.PersistentClient(
  35. path=f'./cache/sessions/{event_stream.sid}/memory',
  36. # FIXME anonymized_telemetry=False,
  37. )
  38. self.collection = db.get_or_create_collection(name='memories')
  39. vector_store = ChromaVectorStore(chroma_collection=self.collection)
  40. # embedding model
  41. embedding_strategy = llm_config.embedding_model
  42. self.embed_model = EmbeddingsLoader.get_embedding_model(
  43. embedding_strategy, llm_config
  44. )
  45. logger.debug(f'Using embedding model: {self.embed_model}')
  46. # instantiate the index
  47. self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
  48. self.thought_idx = 0
  49. # initialize the event stream
  50. self.event_stream = event_stream
  51. # max of threads to run the pipeline
  52. self.memory_max_threads = agent_config.memory_max_threads
  53. def add_event(self, event: Event):
  54. """Adds a new event to the long term memory with a unique id.
  55. Parameters:
  56. - event: The new event to be added to memory
  57. """
  58. try:
  59. # convert the event to a memory-friendly format, and don't truncate
  60. event_data = event_to_memory(event, -1)
  61. except (json.JSONDecodeError, KeyError, ValueError) as e:
  62. logger.warning(f'Failed to process event: {e}')
  63. return
  64. # determine the event type and ID
  65. event_type = ''
  66. event_id = ''
  67. if 'action' in event_data:
  68. event_type = 'action'
  69. event_id = event_data['action']
  70. elif 'observation' in event_data:
  71. event_type = 'observation'
  72. event_id = event_data['observation']
  73. # create a Document instance for the event
  74. doc = Document(
  75. text=json.dumps(event_data),
  76. doc_id=str(self.thought_idx),
  77. extra_info={
  78. 'type': event_type,
  79. 'id': event_id,
  80. 'idx': self.thought_idx,
  81. },
  82. )
  83. self.thought_idx += 1
  84. logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
  85. self._add_document(document=doc)
  86. def _add_document(self, document: 'Document'):
  87. """Inserts a single document into the index."""
  88. self.index.insert_nodes([self._create_node(document)])
  89. def _create_node(self, document: 'Document') -> 'TextNode':
  90. """Create a TextNode from a Document instance."""
  91. return TextNode(
  92. text=document.text,
  93. doc_id=document.doc_id,
  94. extra_info=document.extra_info,
  95. )
  96. def search(self, query: str, k: int = 10) -> list[str]:
  97. """Searches through the current memory using VectorIndexRetriever.
  98. Parameters:
  99. - query (str): A query to match search results to
  100. - k (int): Number of top results to return
  101. Returns:
  102. - list[str]: List of top k results found in current memory
  103. """
  104. retriever = VectorIndexRetriever(
  105. index=self.index,
  106. similarity_top_k=k,
  107. )
  108. results = retriever.retrieve(query)
  109. for result in results:
  110. logger.debug(
  111. f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
  112. )
  113. return [r.get_text() for r in results]
  114. def _events_to_docs(self) -> list['Document']:
  115. """Convert all events from the EventStream to documents for batch insert into the index."""
  116. try:
  117. events = self.event_stream.get_events()
  118. except Exception as e:
  119. logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
  120. return []
  121. documents: list[Document] = []
  122. for event in events:
  123. try:
  124. # convert the event to a memory-friendly format, and don't truncate
  125. event_data = event_to_memory(event, -1)
  126. # determine the event type and ID
  127. event_type = ''
  128. event_id = ''
  129. if 'action' in event_data:
  130. event_type = 'action'
  131. event_id = event_data['action']
  132. elif 'observation' in event_data:
  133. event_type = 'observation'
  134. event_id = event_data['observation']
  135. # create a Document instance for the event
  136. doc = Document(
  137. text=json.dumps(event_data),
  138. doc_id=str(self.thought_idx),
  139. extra_info={
  140. 'type': event_type,
  141. 'id': event_id,
  142. 'idx': self.thought_idx,
  143. },
  144. )
  145. documents.append(doc)
  146. self.thought_idx += 1
  147. except (json.JSONDecodeError, KeyError, ValueError) as e:
  148. logger.warning(f'Failed to process event: {e}')
  149. continue
  150. if documents:
  151. logger.debug(f'Batch inserting {len(documents)} documents into the index.')
  152. else:
  153. logger.debug('No valid documents found to insert into the index.')
  154. return documents
  155. def create_nodes(self, documents: list['Document']) -> list['TextNode']:
  156. """Create nodes from a list of documents."""
  157. return [self._create_node(doc) for doc in documents]