|
@@ -1,26 +1,54 @@
|
|
|
import os
|
|
import os
|
|
|
import uuid
|
|
import uuid
|
|
|
|
|
|
|
|
-from litellm import completion as litellm_completion
|
|
|
|
|
|
|
+from litellm.router import Router
|
|
|
from functools import partial
|
|
from functools import partial
|
|
|
|
|
|
|
|
from opendevin import config
|
|
from opendevin import config
|
|
|
|
|
|
|
|
-DEFAULT_MODEL = config.get_or_default("LLM_MODEL", "gpt-4-0125-preview")
|
|
|
|
|
|
|
+DEFAULT_MODEL_NAME = config.get_or_default("LLM_MODEL", "gpt-4-0125-preview")
|
|
|
DEFAULT_API_KEY = config.get_or_none("LLM_API_KEY")
|
|
DEFAULT_API_KEY = config.get_or_none("LLM_API_KEY")
|
|
|
DEFAULT_BASE_URL = config.get_or_none("LLM_BASE_URL")
|
|
DEFAULT_BASE_URL = config.get_or_none("LLM_BASE_URL")
|
|
|
|
|
+DEFAULT_LLM_NUM_RETRIES = config.get_or_default("LLM_NUM_RETRIES", 6)
|
|
|
|
|
+DEFAULT_LLM_COOLDOWN_TIME = config.get_or_default("LLM_COOLDOWN_TIME", 1)
|
|
|
PROMPT_DEBUG_DIR = config.get_or_default("PROMPT_DEBUG_DIR", "")
|
|
PROMPT_DEBUG_DIR = config.get_or_default("PROMPT_DEBUG_DIR", "")
|
|
|
|
|
|
|
|
class LLM:
|
|
class LLM:
|
|
|
- def __init__(self, model=DEFAULT_MODEL, api_key=DEFAULT_API_KEY, base_url=DEFAULT_BASE_URL, debug_dir=PROMPT_DEBUG_DIR):
|
|
|
|
|
- self.model = model if model else DEFAULT_MODEL
|
|
|
|
|
|
|
+ def __init__(self,
|
|
|
|
|
+ model=DEFAULT_MODEL_NAME,
|
|
|
|
|
+ api_key=DEFAULT_API_KEY,
|
|
|
|
|
+ base_url=DEFAULT_BASE_URL,
|
|
|
|
|
+ num_retries=DEFAULT_LLM_NUM_RETRIES,
|
|
|
|
|
+ cooldown_time=DEFAULT_LLM_COOLDOWN_TIME,
|
|
|
|
|
+ debug_dir=PROMPT_DEBUG_DIR
|
|
|
|
|
+ ):
|
|
|
|
|
+ self.model_name = model if model else DEFAULT_MODEL_NAME
|
|
|
self.api_key = api_key if api_key else DEFAULT_API_KEY
|
|
self.api_key = api_key if api_key else DEFAULT_API_KEY
|
|
|
self.base_url = base_url if base_url else DEFAULT_BASE_URL
|
|
self.base_url = base_url if base_url else DEFAULT_BASE_URL
|
|
|
|
|
+ self.num_retries = num_retries if num_retries else DEFAULT_LLM_NUM_RETRIES
|
|
|
|
|
+ self.cooldown_time = cooldown_time if cooldown_time else DEFAULT_LLM_COOLDOWN_TIME
|
|
|
self._debug_dir = debug_dir if debug_dir else PROMPT_DEBUG_DIR
|
|
self._debug_dir = debug_dir if debug_dir else PROMPT_DEBUG_DIR
|
|
|
self._debug_idx = 0
|
|
self._debug_idx = 0
|
|
|
self._debug_id = uuid.uuid4().hex
|
|
self._debug_id = uuid.uuid4().hex
|
|
|
|
|
|
|
|
- self._completion = partial(litellm_completion, model=self.model, api_key=self.api_key, base_url=self.base_url)
|
|
|
|
|
|
|
+ # We use litellm's Router in order to support retries (especially rate limit backoff retries).
|
|
|
|
|
+ # Typically you would use a whole model list, but it's unnecessary with our implementation's structure
|
|
|
|
|
+ self._router = Router(
|
|
|
|
|
+ model_list=[{
|
|
|
|
|
+ "model_name": self.model_name,
|
|
|
|
|
+ "litellm_params": {
|
|
|
|
|
+ "model": self.model_name,
|
|
|
|
|
+ "api_key": self.api_key,
|
|
|
|
|
+ "api_base": self.base_url
|
|
|
|
|
+ }
|
|
|
|
|
+ }],
|
|
|
|
|
+ num_retries=self.num_retries,
|
|
|
|
|
+ allowed_fails=self.num_retries, # We allow all retries to fail, so they can retry instead of going into "cooldown"
|
|
|
|
|
+ cooldown_time=self.cooldown_time,
|
|
|
|
|
+ # set_verbose=True,
|
|
|
|
|
+ # debug_level="DEBUG"
|
|
|
|
|
+ )
|
|
|
|
|
+ self._completion = partial(self._router.completion, model=self.model_name)
|
|
|
|
|
|
|
|
if self._debug_dir:
|
|
if self._debug_dir:
|
|
|
print(f"Logging prompts to {self._debug_dir}/{self._debug_id}")
|
|
print(f"Logging prompts to {self._debug_dir}/{self._debug_id}")
|