# memory.py
import threading

from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.utils.tenacity_stop import stop_if_should_exit
# The llama-index stack (and chromadb) is an optional, heavyweight dependency.
# Import it eagerly but tolerate its absence: LongTermMemory checks
# LLAMA_INDEX_AVAILABLE and raises a helpful ImportError only when actually used.
try:
    import chromadb
    import llama_index.embeddings.openai.base as llama_openai
    from llama_index.core import Document, VectorStoreIndex
    from llama_index.core.retrievers import VectorIndexRetriever
    from llama_index.vector_stores.chroma import ChromaVectorStore

    LLAMA_INDEX_AVAILABLE = True
except ImportError:
    LLAMA_INDEX_AVAILABLE = False
if LLAMA_INDEX_AVAILABLE:
    # Retry tuning for the patched embeddings call below.
    # TODO: this could be made configurable
    num_retries: int = 10
    retry_min_wait: int = 3
    retry_max_wait: int = 300

    # llama-index includes a retry decorator around openai.get_embeddings() function
    # it is initialized with hard-coded values and errors
    # this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
    # this block attempts to banish it and replace it with our decorator, to allow users to set their own limits
    if hasattr(llama_openai.get_embeddings, '__wrapped__'):
        # The library's decorator exposes the undecorated function via
        # __wrapped__ (functools.wraps convention) — grab the raw callable.
        original_get_embeddings = llama_openai.get_embeddings.__wrapped__
    else:
        # No __wrapped__ attribute: the library's own retry stays in effect,
        # so keep our retries at 1 to avoid multiplying the two retry loops.
        logger.warning('Cannot set custom retry limits.')
        num_retries = 1
        original_get_embeddings = llama_openai.get_embeddings

    def attempt_on_error(retry_state):
        # tenacity `after` hook: log each failed attempt (without traceback)
        # so users see why embeddings calls are being retried.
        logger.error(
            f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
            exc_info=False,
        )
        return None

    @retry(
        reraise=True,
        stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
        wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, InternalServerError)
        ),
        after=attempt_on_error,
    )
    def wrapper_get_embeddings(*args, **kwargs):
        # Thin pass-through; all retry behavior lives in the decorator above.
        return original_get_embeddings(*args, **kwargs)

    # Replace the library's function so every embeddings call in llama-index
    # goes through our retry policy instead of the hard-coded one.
    llama_openai.get_embeddings = wrapper_get_embeddings
  55. class EmbeddingsLoader:
  56. """Loader for embedding model initialization."""
  57. @staticmethod
  58. def get_embedding_model(strategy: str, llm_config: LLMConfig):
  59. supported_ollama_embed_models = [
  60. 'llama2',
  61. 'mxbai-embed-large',
  62. 'nomic-embed-text',
  63. 'all-minilm',
  64. 'stable-code',
  65. 'bge-m3',
  66. 'bge-large',
  67. 'paraphrase-multilingual',
  68. 'snowflake-arctic-embed',
  69. ]
  70. if strategy in supported_ollama_embed_models:
  71. from llama_index.embeddings.ollama import OllamaEmbedding
  72. return OllamaEmbedding(
  73. model_name=strategy,
  74. base_url=llm_config.embedding_base_url,
  75. ollama_additional_kwargs={'mirostat': 0},
  76. )
  77. elif strategy == 'openai':
  78. from llama_index.embeddings.openai import OpenAIEmbedding
  79. return OpenAIEmbedding(
  80. model='text-embedding-ada-002',
  81. api_key=llm_config.api_key,
  82. )
  83. elif strategy == 'azureopenai':
  84. from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
  85. return AzureOpenAIEmbedding(
  86. model='text-embedding-ada-002',
  87. deployment_name=llm_config.embedding_deployment_name,
  88. api_key=llm_config.api_key,
  89. azure_endpoint=llm_config.base_url,
  90. api_version=llm_config.api_version,
  91. )
  92. elif (strategy is not None) and (strategy.lower() == 'none'):
  93. # TODO: this works but is not elegant enough. The incentive is when
  94. # an agent using embeddings is not used, there is no reason we need to
  95. # initialize an embedding model
  96. return None
  97. else:
  98. from llama_index.embeddings.huggingface import HuggingFaceEmbedding
  99. return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
  100. class LongTermMemory:
  101. """Handles storing information for the agent to access later, using chromadb."""
  102. def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
  103. """Initialize the chromadb and set up ChromaVectorStore for later use."""
  104. if not LLAMA_INDEX_AVAILABLE:
  105. raise ImportError(
  106. 'llama_index and its dependencies are not installed. '
  107. 'To use LongTermMemory, please run: poetry install --with llama-index'
  108. )
  109. db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
  110. self.collection = db.get_or_create_collection(name='memories')
  111. vector_store = ChromaVectorStore(chroma_collection=self.collection)
  112. embedding_strategy = llm_config.embedding_model
  113. embed_model = EmbeddingsLoader.get_embedding_model(
  114. embedding_strategy, llm_config
  115. )
  116. self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
  117. self.sema = threading.Semaphore(value=memory_max_threads)
  118. self.thought_idx = 0
  119. self._add_threads: list[threading.Thread] = []
  120. def add_event(self, event: dict):
  121. """Adds a new event to the long term memory with a unique id.
  122. Parameters:
  123. - event (dict): The new event to be added to memory
  124. """
  125. id = ''
  126. t = ''
  127. if 'action' in event:
  128. t = 'action'
  129. id = event['action']
  130. elif 'observation' in event:
  131. t = 'observation'
  132. id = event['observation']
  133. doc = Document(
  134. text=json.dumps(event),
  135. doc_id=str(self.thought_idx),
  136. extra_info={
  137. 'type': t,
  138. 'id': id,
  139. 'idx': self.thought_idx,
  140. },
  141. )
  142. self.thought_idx += 1
  143. logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
  144. thread = threading.Thread(target=self._add_doc, args=(doc,))
  145. self._add_threads.append(thread)
  146. thread.start() # We add the doc concurrently so we don't have to wait ~500ms for the insert
  147. def _add_doc(self, doc):
  148. with self.sema:
  149. self.index.insert(doc)
  150. def search(self, query: str, k: int = 10):
  151. """Searches through the current memory using VectorIndexRetriever
  152. Parameters:
  153. - query (str): A query to match search results to
  154. - k (int): Number of top results to return
  155. Returns:
  156. - list[str]: list of top k results found in current memory
  157. """
  158. retriever = VectorIndexRetriever(
  159. index=self.index,
  160. similarity_top_k=k,
  161. )
  162. results = retriever.retrieve(query)
  163. return [r.get_text() for r in results]