memory.py

import json

from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_memory
from openhands.events.stream import EventStream
from openhands.utils.embeddings import (
    LLAMA_INDEX_AVAILABLE,
    EmbeddingsLoader,
    check_llama_index,
)

# Conditional imports based on llama_index availability
if LLAMA_INDEX_AVAILABLE:
    import chromadb
    from llama_index.core import Document
    from llama_index.core.indices.vector_store.base import VectorStoreIndex
    from llama_index.core.indices.vector_store.retrievers.retriever import (
        VectorIndexRetriever,
    )
    from llama_index.core.schema import TextNode
    from llama_index.vector_stores.chroma import ChromaVectorStore

class LongTermMemory:
    """Handles storing information for the agent to access later, using chromadb."""

    event_stream: EventStream

    def __init__(
        self,
        llm_config: LLMConfig,
        agent_config: AgentConfig,
        event_stream: EventStream,
    ):
        """Initialize the chromadb and set up ChromaVectorStore for later use."""
        check_llama_index()

        # initialize the chromadb client
        db = chromadb.PersistentClient(
            path=f'./cache/sessions/{event_stream.sid}/memory',
            # FIXME anonymized_telemetry=False,
        )
        self.collection = db.get_or_create_collection(name='memories')
        vector_store = ChromaVectorStore(chroma_collection=self.collection)

        # embedding model
        embedding_strategy = llm_config.embedding_model
        self.embed_model = EmbeddingsLoader.get_embedding_model(
            embedding_strategy, llm_config
        )

        # instantiate the index
        self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)
        self.thought_idx = 0

        # initialize the event stream
        self.event_stream = event_stream

        # max number of threads to run the pipeline
        self.memory_max_threads = agent_config.memory_max_threads

    def add_event(self, event: Event):
        """Adds a new event to the long term memory with a unique id.

        Parameters:
        - event: The new event to be added to memory
        """
        try:
            # convert the event to a memory-friendly format, and don't truncate
            event_data = event_to_memory(event, -1)
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning(f'Failed to process event: {e}')
            return

        # determine the event type and ID
        event_type = ''
        event_id = ''
        if 'action' in event_data:
            event_type = 'action'
            event_id = event_data['action']
        elif 'observation' in event_data:
            event_type = 'observation'
            event_id = event_data['observation']

        # create a Document instance for the event
        doc = Document(
            text=json.dumps(event_data),
            doc_id=str(self.thought_idx),
            extra_info={
                'type': event_type,
                'id': event_id,
                'idx': self.thought_idx,
            },
        )
        self.thought_idx += 1
        logger.debug('Adding %s event to memory: %d', event_type, self.thought_idx)
        self._add_document(document=doc)

    def _add_document(self, document: 'Document'):
        """Inserts a single document into the index."""
        self.index.insert_nodes([self._create_node(document)])

    def _create_node(self, document: 'Document') -> 'TextNode':
        """Create a TextNode from a Document instance."""
        return TextNode(
            text=document.text,
            doc_id=document.doc_id,
            extra_info=document.extra_info,
        )

    def search(self, query: str, k: int = 10) -> list[str]:
        """Searches through the current memory using VectorIndexRetriever.

        Parameters:
        - query (str): A query to match search results to
        - k (int): Number of top results to return

        Returns:
        - list[str]: List of top k results found in current memory
        """
        retriever = VectorIndexRetriever(
            index=self.index,
            similarity_top_k=k,
        )
        results = retriever.retrieve(query)
        for result in results:
            logger.debug(
                f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
            )
        return [r.get_text() for r in results]

    def _events_to_docs(self) -> list['Document']:
        """Convert all events from the EventStream to documents for batch insert into the index."""
        try:
            events = self.event_stream.get_events()
        except Exception as e:
            logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
            return []

        documents: list[Document] = []
        for event in events:
            try:
                # convert the event to a memory-friendly format, and don't truncate
                event_data = event_to_memory(event, -1)

                # determine the event type and ID
                event_type = ''
                event_id = ''
                if 'action' in event_data:
                    event_type = 'action'
                    event_id = event_data['action']
                elif 'observation' in event_data:
                    event_type = 'observation'
                    event_id = event_data['observation']

                # create a Document instance for the event
                doc = Document(
                    text=json.dumps(event_data),
                    doc_id=str(self.thought_idx),
                    extra_info={
                        'type': event_type,
                        'id': event_id,
                        'idx': self.thought_idx,
                    },
                )
                documents.append(doc)
                self.thought_idx += 1
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f'Failed to process event: {e}')
                continue

        if documents:
            logger.debug(f'Batch inserting {len(documents)} documents into the index.')
        else:
            logger.debug('No valid documents found to insert into the index.')
        return documents

    def create_nodes(self, documents: list['Document']) -> list['TextNode']:
        """Create nodes from a list of documents."""
        return [self._create_node(doc) for doc in documents]
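

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of the intended flow, assuming default-constructible configs
# and an already-created EventStream for the session. Exact constructor
# arguments for LLMConfig/AgentConfig/EventStream vary across OpenHands
# versions, so treat the setup lines below as assumptions, not the canonical
# API.
#
#     llm_config = LLMConfig()        # assumed: defaults select an embedding model
#     agent_config = AgentConfig()    # assumed: defaults include memory_max_threads
#     event_stream = ...              # an EventStream for the current session
#
#     memory = LongTermMemory(llm_config, agent_config, event_stream)
#     for event in event_stream.get_events():
#         memory.add_event(event)     # embeds and indexes one event at a time
#     hits = memory.search('errors while installing dependencies', k=5)
#
# search() returns the raw JSON text of the top-k matching events, which the
# agent can feed back into its context.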