|
|
@@ -1,30 +1,55 @@
|
|
|
|
|
|
from litellm import completion as litellm_completion
|
|
|
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
|
|
|
-from litellm.exceptions import APIConnectionError, RateLimitError
|
|
|
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
|
|
|
+from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError
|
|
|
from functools import partial
|
|
|
|
|
|
from opendevin import config
|
|
|
-from opendevin.logger import llm_prompt_logger, llm_response_logger, opendevin_logger
|
|
|
+from opendevin.logger import llm_prompt_logger, llm_response_logger
|
|
|
+from opendevin.logger import opendevin_logger as logger
|
|
|
+
|
|
|
|
|
|
DEFAULT_API_KEY = config.get('LLM_API_KEY')
|
|
|
DEFAULT_BASE_URL = config.get('LLM_BASE_URL')
|
|
|
DEFAULT_MODEL_NAME = config.get('LLM_MODEL')
|
|
|
-DEFAULT_LLM_NUM_RETRIES = config.get('LLM_NUM_RETRIES')
|
|
|
-DEFAULT_LLM_COOLDOWN_TIME = config.get('LLM_COOLDOWN_TIME')
|
|
|
DEFAULT_API_VERSION = config.get('LLM_API_VERSION')
|
|
|
+LLM_NUM_RETRIES = config.get('LLM_NUM_RETRIES')
|
|
|
+LLM_RETRY_MIN_WAIT = config.get('LLM_RETRY_MIN_WAIT')
|
|
|
+LLM_RETRY_MAX_WAIT = config.get('LLM_RETRY_MAX_WAIT')
|
|
|
|
|
|
|
|
|
class LLM:
|
|
|
+ """
|
|
|
+ The LLM class represents a Language Model instance.
|
|
|
+ """
|
|
|
+
|
|
|
def __init__(self,
|
|
|
model=DEFAULT_MODEL_NAME,
|
|
|
api_key=DEFAULT_API_KEY,
|
|
|
base_url=DEFAULT_BASE_URL,
|
|
|
- num_retries=DEFAULT_LLM_NUM_RETRIES,
|
|
|
- cooldown_time=DEFAULT_LLM_COOLDOWN_TIME,
|
|
|
api_version=DEFAULT_API_VERSION,
|
|
|
+ num_retries=LLM_NUM_RETRIES,
|
|
|
+ retry_min_wait=LLM_RETRY_MIN_WAIT,
|
|
|
+ retry_max_wait=LLM_RETRY_MAX_WAIT,
|
|
|
):
|
|
|
- opendevin_logger.info(f'Initializing LLM with model: {model}')
|
|
|
+ """
|
|
|
+ Args:
|
|
|
+ model (str, optional): The name of the language model. Defaults to LLM_MODEL.
|
|
|
+ api_key (str, optional): The API key for accessing the language model. Defaults to LLM_API_KEY.
|
|
|
+ base_url (str, optional): The base URL for the language model API. Defaults to LLM_BASE_URL. Not necessary for OpenAI.
|
|
|
+ api_version (str, optional): The version of the API to use. Defaults to LLM_API_VERSION. Not necessary for OpenAI.
|
|
|
+ num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
|
|
|
+            retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_WAIT.
|
|
|
+            retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_WAIT.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ model_name (str): The name of the language model.
|
|
|
+ api_key (str): The API key for accessing the language model.
|
|
|
+ base_url (str): The base URL for the language model API.
|
|
|
+ api_version (str): The version of the API to use.
|
|
|
+            completion (function): The litellm completion function, wrapped with retry logic.
|
|
|
+ """
|
|
|
+ logger.info(f'Initializing LLM with model: {model}')
|
|
|
self.model_name = model
|
|
|
self.api_key = api_key
|
|
|
self.base_url = base_url
|
|
|
@@ -35,15 +60,13 @@ class LLM:
|
|
|
|
|
|
completion_unwrapped = self._completion
|
|
|
|
|
|
- def my_wait(retry_state):
|
|
|
- seconds = (retry_state.attempt_number) * cooldown_time
|
|
|
- opendevin_logger.warning(f'LLM error: {retry_state.outcome.exception()}')
|
|
|
- opendevin_logger.info(f'Attempt #{retry_state.attempt_number} | Sleeping for {seconds}s')
|
|
|
- return seconds
|
|
|
+ def attempt_on_error(retry_state):
|
|
|
+ logger.error(f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.', exc_info=False)
|
|
|
+ return True
|
|
|
|
|
|
@retry(reraise=True,
|
|
|
stop=stop_after_attempt(num_retries),
|
|
|
- wait=my_wait, retry=retry_if_exception_type((APIConnectionError, RateLimitError)))
|
|
|
+ wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait), retry=retry_if_exception_type((RateLimitError, APIConnectionError, ServiceUnavailableError)), after=attempt_on_error)
|
|
|
def wrapper(*args, **kwargs):
|
|
|
if 'messages' in kwargs:
|
|
|
messages = kwargs['messages']
|