@@ -7,36 +7,37 @@ from llama_index.vector_stores.chroma import ChromaVectorStore

 from opendevin import config

 from . import json

-embedding_strategy = config.get("LLM_EMBEDDING_MODEL")
+embedding_strategy = config.get('LLM_EMBEDDING_MODEL')

 # TODO: More embeddings: https://docs.llamaindex.ai/en/stable/examples/embeddings/OpenAI/
 # There's probably a more programmatic way to do this.
-if embedding_strategy == "llama2":
+if embedding_strategy == 'llama2':
     from llama_index.embeddings.ollama import OllamaEmbedding
     embed_model = OllamaEmbedding(
-        model_name="llama2",
-        base_url=config.get_or_error("LLM_BASE_URL"),
-        ollama_additional_kwargs={"mirostat": 0},
+        model_name='llama2',
+        base_url=config.get('LLM_BASE_URL', required=True),
+        ollama_additional_kwargs={'mirostat': 0},
     )
-elif embedding_strategy == "openai":
+elif embedding_strategy == 'openai':
     from llama_index.embeddings.openai import OpenAIEmbedding
     embed_model = OpenAIEmbedding(
-        model="text-embedding-ada-002",
-        api_key=config.get_or_error("LLM_API_KEY")
+        model='text-embedding-ada-002',
+        api_key=config.get('LLM_API_KEY', required=True)
     )
-elif embedding_strategy == "azureopenai":
-    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding  # Need to instruct to set these env variables in documentation
+elif embedding_strategy == 'azureopenai':
+    # Need to instruct to set these env variables in documentation
+    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
     embed_model = AzureOpenAIEmbedding(
-        model="text-embedding-ada-002",
-        deployment_name=config.get_or_error("LLM_DEPLOYMENT_NAME"),
-        api_key=config.get_or_error("LLM_API_KEY"),
-        azure_endpoint=config.get_or_error("LLM_BASE_URL"),
-        api_version=config.get_or_error("LLM_API_VERSION"),
+        model='text-embedding-ada-002',
+        deployment_name=config.get('LLM_DEPLOYMENT_NAME', required=True),
+        api_key=config.get('LLM_API_KEY', required=True),
+        azure_endpoint=config.get('LLM_BASE_URL', required=True),
+        api_version=config.get('LLM_API_VERSION', required=True),
     )
 else:
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
     embed_model = HuggingFaceEmbedding(
-        model_name="BAAI/bge-small-en-v1.5"
+        model_name='BAAI/bge-small-en-v1.5'
     )
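Aside: the "more programmatic way" the comment in this hunk asks for could be a lookup table of factories. The sketch below is hypothetical and not part of this change; it reuses the config.get() helper from the diff, while EMBEDDING_FACTORIES and the _make_* helper names are invented for illustration (only the llama2 entry is spelled out):

def _make_ollama():
    # Keep the provider import local, matching the lazy imports in the diff.
    from llama_index.embeddings.ollama import OllamaEmbedding
    return OllamaEmbedding(
        model_name='llama2',
        base_url=config.get('LLM_BASE_URL', required=True),
        ollama_additional_kwargs={'mirostat': 0},
    )

def _make_huggingface():
    # Fallback used when LLM_EMBEDDING_MODEL names no known strategy,
    # mirroring the else branch above.
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')

# 'openai' and 'azureopenai' entries would follow the same pattern.
EMBEDDING_FACTORIES = {
    'llama2': _make_ollama,
}

embed_model = EMBEDDING_FACTORIES.get(embedding_strategy, _make_huggingface)()

A table like this keeps each provider's import inside its own factory, so adding a provider is one entry rather than another elif branch.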
@@ -51,9 +52,10 @@ class LongTermMemory:
         Initialize the chromadb and set up ChromaVectorStore for later use.
         """
         db = chromadb.Client()
-        self.collection = db.get_or_create_collection(name="memories")
+        self.collection = db.get_or_create_collection(name='memories')
         vector_store = ChromaVectorStore(chroma_collection=self.collection)
-        self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
+        self.index = VectorStoreIndex.from_vector_store(
+            vector_store, embed_model=embed_model)
         self.thought_idx = 0

     def add_event(self, event: dict):
@@ -63,27 +65,27 @@ class LongTermMemory:
         Parameters:
         - event (dict): The new event to be added to memory
         """
-        id = ""
-        t = ""
-        if "action" in event:
-            t = "action"
-            id = event["action"]
-        elif "observation" in event:
-            t = "observation"
-            id = event["observation"]
+        id = ''
+        t = ''
+        if 'action' in event:
+            t = 'action'
+            id = event['action']
+        elif 'observation' in event:
+            t = 'observation'
+            id = event['observation']
         doc = Document(
             text=json.dumps(event),
             doc_id=str(self.thought_idx),
             extra_info={
-                "type": t,
-                "id": id,
-                "idx": self.thought_idx,
+                'type': t,
+                'id': id,
+                'idx': self.thought_idx,
             },
         )
         self.thought_idx += 1
         self.index.insert(doc)

-    def search(self, query: str, k: int=10):
+    def search(self, query: str, k: int = 10):
         """
         Searches through the current memory using VectorIndexRetriever

@@ -100,5 +102,3 @@ class LongTermMemory:
         )
         results = retriever.retrieve(query)
         return [r.get_text() for r in results]
-
-
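Taken together: LongTermMemory serializes each event dict to JSON, stores it in the 'memories' Chroma collection with type/id/idx metadata, and search() embeds the query and returns the text of the k most similar documents. A minimal usage sketch, assuming the module imports cleanly outside OpenDevin's runtime; the event payloads below are illustrative, not taken from the diff:

memory = LongTermMemory()

# Events are plain dicts; an 'action' or 'observation' key determines
# the type and id recorded in the document metadata.
memory.add_event({'action': 'run', 'args': {'command': 'ls'}})
memory.add_event({'observation': 'run', 'content': 'README.md'})

# Returns the text of up to k stored events ranked by vector similarity.
matches = memory.search('list files in the workspace', k=2)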