embeddings.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import importlib.util
  2. import os
  3. from joblib import Parallel, delayed
  4. from openhands.core.config import LLMConfig
  5. from openhands.core.logger import openhands_logger as logger
  6. try:
  7. # check if those we need later are available using importlib
  8. if importlib.util.find_spec('chromadb') is None:
  9. raise ImportError(
  10. 'chromadb is not available. Please install it using poetry install --with llama-index'
  11. )
  12. if (
  13. importlib.util.find_spec(
  14. 'llama_index.core.indices.vector_store.retrievers.retriever'
  15. )
  16. is None
  17. or importlib.util.find_spec('llama_index.core.indices.vector_store.base')
  18. is None
  19. ):
  20. raise ImportError(
  21. 'llama_index is not available. Please install it using poetry install --with llama-index'
  22. )
  23. from llama_index.core import Document, VectorStoreIndex
  24. from llama_index.core.base.embeddings.base import BaseEmbedding
  25. from llama_index.core.ingestion import IngestionPipeline
  26. from llama_index.core.schema import TextNode
  27. LLAMA_INDEX_AVAILABLE = True
  28. except ImportError:
  29. LLAMA_INDEX_AVAILABLE = False
  30. # Define supported embedding models
  31. SUPPORTED_OLLAMA_EMBED_MODELS = [
  32. 'llama2',
  33. 'mxbai-embed-large',
  34. 'nomic-embed-text',
  35. 'all-minilm',
  36. 'stable-code',
  37. 'bge-m3',
  38. 'bge-large',
  39. 'paraphrase-multilingual',
  40. 'snowflake-arctic-embed',
  41. ]
  42. def check_llama_index():
  43. """Utility function to check the availability of llama_index.
  44. Raises:
  45. ImportError: If llama_index is not available.
  46. """
  47. if not LLAMA_INDEX_AVAILABLE:
  48. raise ImportError(
  49. 'llama_index and its dependencies are not installed. '
  50. 'To use memory features, please run: poetry install --with llama-index.'
  51. )
  52. class EmbeddingsLoader:
  53. """Loader for embedding model initialization."""
  54. @staticmethod
  55. def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding':
  56. """Initialize and return the appropriate embedding model based on the strategy.
  57. Parameters:
  58. - strategy: The embedding strategy to use.
  59. - llm_config: Configuration for the LLM.
  60. Returns:
  61. - An instance of the selected embedding model or None.
  62. """
  63. if strategy in SUPPORTED_OLLAMA_EMBED_MODELS:
  64. from llama_index.embeddings.ollama import OllamaEmbedding
  65. return OllamaEmbedding(
  66. model_name=strategy,
  67. base_url=llm_config.embedding_base_url,
  68. ollama_additional_kwargs={'mirostat': 0},
  69. )
  70. elif strategy == 'openai':
  71. from llama_index.embeddings.openai import OpenAIEmbedding
  72. return OpenAIEmbedding(
  73. model='text-embedding-ada-002',
  74. api_key=llm_config.api_key,
  75. )
  76. elif strategy == 'azureopenai':
  77. from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
  78. return AzureOpenAIEmbedding(
  79. model='text-embedding-ada-002',
  80. deployment_name=llm_config.embedding_deployment_name,
  81. api_key=llm_config.api_key,
  82. azure_endpoint=llm_config.base_url,
  83. api_version=llm_config.api_version,
  84. )
  85. elif strategy == 'voyage':
  86. from llama_index.embeddings.voyageai import VoyageEmbedding
  87. return VoyageEmbedding(
  88. model_name='voyage-code-3',
  89. )
  90. elif (strategy is not None) and (strategy.lower() == 'none'):
  91. # TODO: this works but is not elegant enough. The incentive is when
  92. # an agent using embeddings is not used, there is no reason we need to
  93. # initialize an embedding model
  94. return None
  95. else:
  96. from llama_index.embeddings.huggingface import HuggingFaceEmbedding
  97. # initialize the local embedding model
  98. local_embed_model = HuggingFaceEmbedding(
  99. model_name='BAAI/bge-small-en-v1.5'
  100. )
  101. # for local embeddings, we need torch
  102. import torch
  103. # choose the best device
  104. # first determine what is available: CUDA, MPS, or CPU
  105. if torch.cuda.is_available():
  106. device = 'cuda'
  107. elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
  108. device = 'mps'
  109. else:
  110. device = 'cpu'
  111. os.environ['CUDA_VISIBLE_DEVICES'] = ''
  112. os.environ['PYTORCH_FORCE_CPU'] = (
  113. '1' # try to force CPU to avoid errors
  114. )
  115. # override CUDA availability
  116. torch.cuda.is_available = lambda: False
  117. # disable MPS to avoid errors
  118. if device != 'mps' and hasattr(torch.backends, 'mps'):
  119. torch.backends.mps.is_available = lambda: False
  120. torch.backends.mps.is_built = False
  121. # the device being used
  122. logger.debug(f'Using device for embeddings: {device}')
  123. return local_embed_model
  124. # --------------------------------------------------------------------------
  125. # Utility functions to run pipelines, split out for profiling
  126. # --------------------------------------------------------------------------
  127. def run_pipeline(
  128. embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int
  129. ) -> list['TextNode']:
  130. """Run a pipeline embedding documents."""
  131. # set up a pipeline with the transformations to make
  132. pipeline = IngestionPipeline(
  133. transformations=[
  134. embed_model,
  135. ],
  136. )
  137. # run the pipeline with num_workers
  138. nodes = pipeline.run(
  139. documents=documents, show_progress=True, num_workers=num_workers
  140. )
  141. return nodes
  142. def insert_batch_docs(
  143. index: 'VectorStoreIndex', documents: list['Document'], num_workers: int
  144. ) -> list['TextNode']:
  145. """Run the document indexing in parallel."""
  146. results = Parallel(n_jobs=num_workers, backend='threading')(
  147. delayed(index.insert)(doc) for doc in documents
  148. )
  149. return results