ソースを参照

fix: corrected bedrock model list (#1513)

- auto set environment variable
- add criteria for querying the AWS bedrock model
zhaoninge 1 年間 前
コミット
6150ab6a3e

+ 3 - 0
opendevin/core/config.py

@@ -23,6 +23,9 @@ DEFAULT_CONFIG: dict = {
     ConfigType.LLM_API_KEY: None,
     ConfigType.LLM_API_KEY: None,
     ConfigType.LLM_BASE_URL: None,
     ConfigType.LLM_BASE_URL: None,
     ConfigType.LLM_CUSTOM_LLM_PROVIDER: None,
     ConfigType.LLM_CUSTOM_LLM_PROVIDER: None,
+    ConfigType.AWS_ACCESS_KEY_ID: None,
+    ConfigType.AWS_SECRET_ACCESS_KEY: None,
+    ConfigType.AWS_REGION_NAME: None,
     ConfigType.WORKSPACE_BASE: os.getcwd(),
     ConfigType.WORKSPACE_BASE: os.getcwd(),
     ConfigType.WORKSPACE_MOUNT_PATH: None,
     ConfigType.WORKSPACE_MOUNT_PATH: None,
     ConfigType.WORKSPACE_MOUNT_PATH_IN_SANDBOX: '/workspace',
     ConfigType.WORKSPACE_MOUNT_PATH_IN_SANDBOX: '/workspace',

+ 3 - 0
opendevin/core/schema/config.py

@@ -10,6 +10,9 @@ class ConfigType(str, Enum):
     LLM_TIMEOUT = 'LLM_TIMEOUT'
     LLM_TIMEOUT = 'LLM_TIMEOUT'
     LLM_API_KEY = 'LLM_API_KEY'
     LLM_API_KEY = 'LLM_API_KEY'
     LLM_BASE_URL = 'LLM_BASE_URL'
     LLM_BASE_URL = 'LLM_BASE_URL'
+    AWS_ACCESS_KEY_ID = 'AWS_ACCESS_KEY_ID'
+    AWS_SECRET_ACCESS_KEY = 'AWS_SECRET_ACCESS_KEY'
+    AWS_REGION_NAME = 'AWS_REGION_NAME'
     WORKSPACE_BASE = 'WORKSPACE_BASE'
     WORKSPACE_BASE = 'WORKSPACE_BASE'
     WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
     WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
     WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
     WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'

+ 53 - 0
opendevin/llm/bedrock.py

@@ -0,0 +1,53 @@
+import os
+
+import boto3
+
+from opendevin.core import config
+from opendevin.core.logger import opendevin_logger as logger
+from opendevin.core.schema import ConfigType
+
+AWS_ACCESS_KEY_ID = config.get(ConfigType.AWS_ACCESS_KEY_ID)
+AWS_SECRET_ACCESS_KEY = config.get(ConfigType.AWS_SECRET_ACCESS_KEY)
+AWS_REGION_NAME = config.get(ConfigType.AWS_REGION_NAME)
+
+# It needs to be set as an environment variable, if the variable is configured in the Config file.
+if AWS_ACCESS_KEY_ID is not None:
+    os.environ[ConfigType.AWS_ACCESS_KEY_ID] = AWS_ACCESS_KEY_ID
+if AWS_SECRET_ACCESS_KEY is not None:
+    os.environ[ConfigType.AWS_SECRET_ACCESS_KEY] = AWS_SECRET_ACCESS_KEY
+if AWS_REGION_NAME is not None:
+    os.environ[ConfigType.AWS_REGION_NAME] = AWS_REGION_NAME
+
+
+def list_foundation_models():
+    try:
+        # The AWS bedrock model id is not queried, if no AWS parameters are configured.
+        if (
+            AWS_REGION_NAME is None
+            or AWS_ACCESS_KEY_ID is None
+            or AWS_SECRET_ACCESS_KEY is None
+        ):
+            return []
+
+        client = boto3.client(
+            service_name='bedrock',
+            region_name=AWS_REGION_NAME,
+            aws_access_key_id=AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+        )
+        foundation_models_list = client.list_foundation_models(
+            byOutputModality='TEXT', byInferenceType='ON_DEMAND'
+        )
+        model_summaries = foundation_models_list['modelSummaries']
+        return ['bedrock/' + model['modelId'] for model in model_summaries]
+    except Exception as err:
+        logger.warning(
+            '%s. Please config AWS_REGION_NAME AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY'
+            ' if you want use bedrock model.',
+            err,
+        )
+        return []
+
+
+def remove_error_modelId(model_list):
+    return list(filter(lambda m: not m.startswith('bedrock'), model_list))

+ 9 - 1
opendevin/server/listen.py

@@ -15,6 +15,7 @@ from opendevin.controller.agent import Agent
 from opendevin.core import config
 from opendevin.core import config
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema.config import ConfigType
 from opendevin.core.schema.config import ConfigType
+from opendevin.llm import bedrock
 from opendevin.runtime import files
 from opendevin.runtime import files
 from opendevin.server.agent import agent_manager
 from opendevin.server.agent import agent_manager
 from opendevin.server.auth import get_sid_from_token, sign_token
 from opendevin.server.auth import get_sid_from_token, sign_token
@@ -50,7 +51,14 @@ async def get_litellm_models():
     """
     """
     Get all models supported by LiteLLM.
     Get all models supported by LiteLLM.
     """
     """
-    return list(set(litellm.model_list + list(litellm.model_cost.keys())))
+    litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
+    litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
+        litellm_model_list
+    )
+    bedrock_model_list = bedrock.list_foundation_models()
+    model_list = litellm_model_list_without_bedrock + bedrock_model_list
+
+    return list(set(model_list))
 
 
 
 
 @app.get('/api/agents')
 @app.get('/api/agents')