"""Long-term memory for agent events, backed by ChromaDB and LlamaIndex."""
  1. import os
  2. import chromadb
  3. from llama_index.core import Document
  4. from llama_index.core.retrievers import VectorIndexRetriever
  5. from llama_index.core import VectorStoreIndex
  6. from llama_index.vector_stores.chroma import ChromaVectorStore
  7. from . import json
  8. embedding_strategy = os.getenv("LLM_EMBEDDING_MODEL", "local")
  9. # TODO: More embeddings: https://docs.llamaindex.ai/en/stable/examples/embeddings/OpenAI/
  10. # There's probably a more programmatic way to do this.
  11. if embedding_strategy == "llama2":
  12. from llama_index.embeddings.ollama import OllamaEmbedding
  13. embed_model = OllamaEmbedding(
  14. model_name="llama2",
  15. base_url=os.getenv("LLM_BASE_URL", "http://localhost:8000"),
  16. ollama_additional_kwargs={"mirostat": 0},
  17. )
  18. elif embedding_strategy == "openai":
  19. from llama_index.embeddings.openai import OpenAIEmbedding
  20. embed_model = OpenAIEmbedding(
  21. base_url=os.getenv("LLM_BASE_URL"),
  22. )
  23. elif embedding_strategy == "azureopenai":
  24. from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding # Need to instruct to set these env variables in documentation
  25. embed_model = AzureOpenAIEmbedding(
  26. model="text-embedding-ada-002",
  27. deployment_name=os.getenv("LLM_DEPLOYMENT_NAME"),
  28. api_key=os.getenv("LLM_API_KEY"),
  29. azure_endpoint=os.getenv("LLM_BASE_URL"),
  30. api_version=os.getenv("LLM_API_VERSION"),
  31. )
  32. else:
  33. from llama_index.embeddings.huggingface import HuggingFaceEmbedding
  34. embed_model = HuggingFaceEmbedding(
  35. model_name="BAAI/bge-small-en-v1.5"
  36. )
  37. class LongTermMemory:
  38. def __init__(self):
  39. db = chromadb.Client()
  40. self.collection = db.get_or_create_collection(name="memories")
  41. vector_store = ChromaVectorStore(chroma_collection=self.collection)
  42. self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
  43. self.thought_idx = 0
  44. def add_event(self, event):
  45. doc = Document(
  46. text=json.dumps(event),
  47. doc_id=str(self.thought_idx),
  48. extra_info={
  49. "type": event["action"],
  50. "idx": self.thought_idx,
  51. },
  52. )
  53. self.thought_idx += 1
  54. self.index.insert(doc)
  55. def search(self, query, k=10):
  56. retriever = VectorIndexRetriever(
  57. index=self.index,
  58. similarity_top_k=k,
  59. )
  60. results = retriever.retrieve(query)
  61. return [r.get_text() for r in results]