# memory.py
from threading import Thread
import chromadb
from llama_index.core import Document
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from opendevin import config
from opendevin.logger import opendevin_logger as logger
from . import json

# Which embedding backend to use; read once at import time.
embedding_strategy = config.get('LLM_EMBEDDING_MODEL')

# TODO: More embeddings: https://docs.llamaindex.ai/en/stable/examples/embeddings/OpenAI/
# There's probably a more programmatic way to do this.
# Provider imports are deferred into each branch so only the selected
# provider's package must be installed/importable.
if embedding_strategy == 'llama2':
    from llama_index.embeddings.ollama import OllamaEmbedding
    embed_model = OllamaEmbedding(
        model_name='llama2',
        base_url=config.get('LLM_BASE_URL', required=True),
        ollama_additional_kwargs={'mirostat': 0},
    )
elif embedding_strategy == 'openai':
    from llama_index.embeddings.openai import OpenAIEmbedding
    embed_model = OpenAIEmbedding(
        model='text-embedding-ada-002',
        api_key=config.get('LLM_API_KEY', required=True)
    )
elif embedding_strategy == 'azureopenai':
    # Need to instruct to set these env variables in documentation
    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
    embed_model = AzureOpenAIEmbedding(
        model='text-embedding-ada-002',
        deployment_name=config.get('LLM_DEPLOYMENT_NAME', required=True),
        api_key=config.get('LLM_API_KEY', required=True),
        azure_endpoint=config.get('LLM_BASE_URL', required=True),
        api_version=config.get('LLM_API_VERSION', required=True),
    )
else:
    # Fallback: local HuggingFace model, no API key required.
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    embed_model = HuggingFaceEmbedding(
        model_name='BAAI/bge-small-en-v1.5'
    )
  41. class LongTermMemory:
  42. """
  43. Responsible for storing information that the agent can call on later for better insights and context.
  44. Uses chromadb to store and search through memories.
  45. """
  46. def __init__(self):
  47. """
  48. Initialize the chromadb and set up ChromaVectorStore for later use.
  49. """
  50. db = chromadb.Client()
  51. self.collection = db.get_or_create_collection(name='memories')
  52. vector_store = ChromaVectorStore(chroma_collection=self.collection)
  53. self.index = VectorStoreIndex.from_vector_store(
  54. vector_store, embed_model=embed_model)
  55. self.thought_idx = 0
  56. def add_event(self, event: dict):
  57. """
  58. Adds a new event to the long term memory with a unique id.
  59. Parameters:
  60. - event (dict): The new event to be added to memory
  61. """
  62. id = ''
  63. t = ''
  64. if 'action' in event:
  65. t = 'action'
  66. id = event['action']
  67. elif 'observation' in event:
  68. t = 'observation'
  69. id = event['observation']
  70. doc = Document(
  71. text=json.dumps(event),
  72. doc_id=str(self.thought_idx),
  73. extra_info={
  74. 'type': t,
  75. 'id': id,
  76. 'idx': self.thought_idx,
  77. },
  78. )
  79. self.thought_idx += 1
  80. logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
  81. thread = Thread(target=self._add_doc, args=(doc,))
  82. thread.start() # We add the doc concurrently so we don't have to wait ~500ms for the insert
  83. def _add_doc(self, doc):
  84. self.index.insert(doc)
  85. def search(self, query: str, k: int = 10):
  86. """
  87. Searches through the current memory using VectorIndexRetriever
  88. Parameters:
  89. - query (str): A query to match search results to
  90. - k (int): Number of top results to return
  91. Returns:
  92. - List[str]: List of top k results found in current memory
  93. """
  94. retriever = VectorIndexRetriever(
  95. index=self.index,
  96. similarity_top_k=k,
  97. )
  98. results = retriever.retrieve(query)
  99. return [r.get_text() for r in results]