Просмотр исходного кода

Remove global config from bedrock (#2954)

Graham Neubig 1 год назад
Родитель
Сommit
2c982582d7
2 измененных файлов с 23 добавлено и 32 удалено
  1. 6 27
      opendevin/llm/bedrock.py
  2. 17 5
      opendevin/server/listen.py

+ 6 - 27
opendevin/llm/bedrock.py

@@ -1,39 +1,18 @@
-import os
-
 import boto3
 
-from opendevin.core.config import config
 from opendevin.core.logger import opendevin_logger as logger
 
-# TODO: this assumes AWS-specific configs are under default 'llm' group
-AWS_ACCESS_KEY_ID = config.get_llm_config().aws_access_key_id
-AWS_SECRET_ACCESS_KEY = config.get_llm_config().aws_secret_access_key
-AWS_REGION_NAME = config.get_llm_config().aws_region_name
-
-# It needs to be set as an environment variable, if the variable is configured in the Config file.
-if AWS_ACCESS_KEY_ID is not None:
-    os.environ['AWS_ACCESS_KEY_ID'] = AWS_ACCESS_KEY_ID
-if AWS_SECRET_ACCESS_KEY is not None:
-    os.environ['AWS_SECRET_ACCESS_KEY'] = AWS_SECRET_ACCESS_KEY
-if AWS_REGION_NAME is not None:
-    os.environ['AWS_REGION_NAME'] = AWS_REGION_NAME
 
-
-def list_foundation_models():
+def list_foundation_models(
+    aws_region_name: str, aws_access_key_id: str, aws_secret_access_key: str
+) -> list[str]:
     try:
         # The AWS bedrock model id is not queried, if no AWS parameters are configured.
-        if (
-            AWS_REGION_NAME is None
-            or AWS_ACCESS_KEY_ID is None
-            or AWS_SECRET_ACCESS_KEY is None
-        ):
-            return []
-
         client = boto3.client(
             service_name='bedrock',
-            region_name=AWS_REGION_NAME,
-            aws_access_key_id=AWS_ACCESS_KEY_ID,
-            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+            region_name=aws_region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
         )
         foundation_models_list = client.list_foundation_models(
             byOutputModality='TEXT', byInferenceType='ON_DEMAND'

+ 17 - 5
opendevin/server/listen.py

@@ -29,7 +29,7 @@ from fastapi.staticfiles import StaticFiles
 
 import agenthub  # noqa F401 (we import this to get the agents registered)
 from opendevin.controller.agent import Agent
-from opendevin.core.config import config
+from opendevin.core.config import LLMConfig, config
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema import AgentState  # Add this import
 from opendevin.events.action import ChangeAgentStateAction, NullAction
@@ -280,8 +280,9 @@ async def websocket_endpoint(websocket: WebSocket):
 
 
 @app.get('/api/options/models')
-async def get_litellm_models():
-    """Get all models supported by LiteLLM.
+async def get_litellm_models() -> list[str]:
+    """
+    Get all models supported by LiteLLM.
 
     This function combines models from litellm and Bedrock, removing any
     error-prone Bedrock models.
@@ -298,8 +299,19 @@ async def get_litellm_models():
     litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
         litellm_model_list
     )
-    bedrock_model_list = bedrock.list_foundation_models()
-    model_list = litellm_model_list_without_bedrock + bedrock_model_list
+    # TODO: for bedrock, this is using the default config
+    llm_config: LLMConfig = config.get_llm_config()
+    if (
+        llm_config.aws_region_name
+        and llm_config.aws_access_key_id
+        and llm_config.aws_secret_access_key
+    ):
+        bedrock_model_list = bedrock.list_foundation_models(
+            llm_config.aws_region_name,
+            llm_config.aws_access_key_id,
+            llm_config.aws_secret_access_key,
+        )
+        model_list = litellm_model_list_without_bedrock + bedrock_model_list
     for llm_config in config.llms.values():
         ollama_base_url = llm_config.ollama_base_url
         if llm_config.model.startswith('ollama'):