@@ -1,5 +1,4 @@
 import copy
-import os
 import time
 import warnings
 from functools import partial
@@ -10,16 +9,16 @@ from openhands.core.config import LLMConfig
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
     import litellm
+from litellm import ModelInfo
 from litellm import completion as litellm_completion
 from litellm import completion_cost as litellm_completion_cost
 from litellm.exceptions import (
     APIConnectionError,
-    ContentPolicyViolationError,
     InternalServerError,
-    OpenAIError,
     RateLimitError,
+    ServiceUnavailableError,
 )
-from litellm.types.utils import CostPerToken
+from litellm.types.utils import CostPerToken, ModelResponse, Usage
 
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
@@ -29,9 +28,23 @@ from openhands.llm.retry_mixin import RetryMixin
 
 __all__ = ['LLM']
 
-cache_prompting_supported_models = [
+# tuple of exceptions to retry on
+LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
+    APIConnectionError,
+    InternalServerError,
+    RateLimitError,
+    ServiceUnavailableError,
+)
+
+# models that support prompt caching
+# remove this when gemini and deepseek are supported
+CACHE_PROMPT_SUPPORTED_MODELS = [
     'claude-3-5-sonnet-20240620',
     'claude-3-haiku-20240307',
+    'claude-3-opus-20240229',
+    'anthropic/claude-3-opus-20240229',
+    'anthropic/claude-3-haiku-20240307',
+    'anthropic/claude-3-5-sonnet-20240620',
 ]
 
 
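For context, the module-level tuple is consumed by `RetryMixin.retry_decorator` further down. The mixin itself is not part of this diff; below is a minimal sketch, assuming a tenacity-based decorator with this shape, of how `LLM_RETRY_EXCEPTIONS` and the config values might be consumed:

```python
# A sketch only -- RetryMixin is not shown in this diff; the signature and
# tenacity usage here are assumptions for illustration.
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


def retry_decorator(num_retries, retry_exceptions, retry_min_wait, retry_max_wait, retry_multiplier):
    # build a tenacity retry decorator from the config values and the
    # module-level exception tuple
    return retry(
        reraise=True,
        stop=stop_after_attempt(num_retries),
        wait=wait_exponential(multiplier=retry_multiplier, min=retry_min_wait, max=retry_max_wait),
        retry=retry_if_exception_type(retry_exceptions),
    )
```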
@@ -55,23 +68,17 @@ class LLM(RetryMixin, DebugMixin):
             config: The LLM configuration.
             metrics: The metrics to use.
         """
-        self.metrics = metrics if metrics is not None else Metrics()
-        self.cost_metric_supported = True
-        self.config = copy.deepcopy(config)
-
-        os.environ['OR_SITE_URL'] = self.config.openrouter_site_url
-        os.environ['OR_APP_NAME'] = self.config.openrouter_app_name
+        self.metrics: Metrics = metrics if metrics is not None else Metrics()
+        self.cost_metric_supported: bool = True
+        self.config: LLMConfig = copy.deepcopy(config)
 
         # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
         # - 'messages': list of messages
         # - 'response': response from the LLM
         self.llm_completions: list[dict[str, Any]] = []
 
-        # Set up config attributes with default values to prevent AttributeError
-        LLMConfig.set_missing_attributes(self.config)
-
         # litellm actually uses base Exception here for unknown model
-        self.model_info = None
+        self.model_info: ModelInfo | None = None
         try:
             if self.config.model.startswith('openrouter'):
                 self.model_info = litellm.get_model_info(self.config.model)
@@ -83,15 +90,6 @@ class LLM(RetryMixin, DebugMixin):
         except Exception as e:
             logger.warning(f'Could not get model info for {config.model}:\n{e}')
 
-        # Tuple of exceptions to retry on
-        self.retry_exceptions = (
-            APIConnectionError,
-            ContentPolicyViolationError,
-            InternalServerError,
-            OpenAIError,
-            RateLimitError,
-        )
-
         # Set the max tokens in an LM-specific way if not set
         if self.config.max_input_tokens is None:
             if (
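For reference, `litellm.get_model_info` returns a mapping whose keys this diff relies on; an illustrative lookup (the key names are the ones the checks in this diff read, not an exhaustive list):

```python
import litellm

# keys inferred from the checks used elsewhere in this diff
info = litellm.get_model_info('claude-3-5-sonnet-20240620')
print(info.get('max_input_tokens'))                # used to default config.max_input_tokens
print(info.get('supports_vision', False))          # read by vision_is_active()
print(info.get('supports_prompt_caching', False))  # read by is_caching_prompt_active()
```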
@@ -135,23 +133,39 @@ class LLM(RetryMixin, DebugMixin):
 
         if self.vision_is_active():
             logger.debug('LLM: model has vision enabled')
+        if self.is_caching_prompt_active():
+            logger.debug('LLM: caching prompt enabled')
 
         completion_unwrapped = self._completion
 
         @self.retry_decorator(
             num_retries=self.config.num_retries,
-            retry_exceptions=self.retry_exceptions,
+            retry_exceptions=LLM_RETRY_EXCEPTIONS,
             retry_min_wait=self.config.retry_min_wait,
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
-            # some callers might just send the messages directly
-            if 'messages' in kwargs:
+            messages: list[dict[str, Any]] | dict[str, Any] = []
+
+            # some callers might send the model and messages directly
+            # litellm allows positional args, like completion(model, messages, **kwargs)
+            if len(args) > 1:
+                # ignore the first argument if it's provided (it would be the model)
+                # design-wise: we don't allow overriding the configured values
+                # implementation-wise: the partial function sets the model as a kwarg already,
+                # as well as other kwargs
+                messages = args[1] if len(args) > 1 else args[0]
+                kwargs['messages'] = messages
+
+                # remove the first two args; they're sent in kwargs
+                args = args[2:]
+            elif 'messages' in kwargs:
                 messages = kwargs['messages']
-            else:
-                messages = args[1] if len(args) > 1 else []
+
+            # ensure we work with a list of messages
+            messages = messages if isinstance(messages, list) else [messages]
 
             # if we have no messages, something went very wrong
             if not messages:
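The wrapper now accepts both of litellm's calling styles. For clarity, here is the extraction logic restated as a standalone function (the helper name is illustrative, not part of the diff):

```python
def extract_messages(args: tuple, kwargs: dict) -> list:
    """Illustrative restatement of the wrapper's message extraction."""
    messages = []
    if len(args) > 1:
        # positional style: completion(model, messages, ...); the model is
        # ignored because the configured value was already bound via partial()
        messages = args[1]
        kwargs['messages'] = messages
    elif 'messages' in kwargs:
        # keyword style: completion(messages=[...])
        messages = kwargs['messages']
    # a single message dict is promoted to a one-element list
    return messages if isinstance(messages, list) else [messages]
```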
@@ -169,7 +183,8 @@ class LLM(RetryMixin, DebugMixin):
                     'anthropic-beta': 'prompt-caching-2024-07-31',
                 }
 
-            resp = completion_unwrapped(*args, **kwargs)
+            # we don't support streaming here, thus we get a ModelResponse
+            resp: ModelResponse = completion_unwrapped(*args, **kwargs)
 
             # log for evals or other scripts that need the raw completion
             if self.config.log_completions:
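When caching is active, the wrapped call effectively amounts to the following (model and messages are placeholders; the header value is the one set above):

```python
from litellm import completion as litellm_completion

# requires provider credentials in the environment to actually run
resp = litellm_completion(
    model='claude-3-5-sonnet-20240620',
    messages=[{'role': 'user', 'content': 'Hello'}],
    extra_headers={'anthropic-beta': 'prompt-caching-2024-07-31'},
)
# non-streaming, so resp is a ModelResponse
print(resp['choices'][0]['message']['content'])
```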
@@ -182,7 +197,7 @@ class LLM(RetryMixin, DebugMixin):
                 }
             )
 
-            message_back = resp['choices'][0]['message']['content']
+            message_back: str = resp['choices'][0]['message']['content']
 
             # log the LLM response
             self.log_response(message_back)
@@ -211,22 +226,29 @@ class LLM(RetryMixin, DebugMixin):
         Returns:
             bool: True if model is vision capable. If model is not supported by litellm, it will return False.
         """
-        try:
-            return litellm.supports_vision(self.config.model)
-        except Exception:
-            return False
+        # litellm.supports_vision currently returns False for prefixed model names
+        # like 'openai/gpt-...' or 'anthropic/claude-...', but model_info has the correct value.
+        # we can go with it, but we need to keep an eye on whether model_info is correct for Vertex or other providers.
+        # remove this when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
+        return litellm.supports_vision(self.config.model) or (
+            self.model_info is not None
+            and self.model_info.get('supports_vision', False)
+        )
 
     def is_caching_prompt_active(self) -> bool:
-        """Check if prompt caching is enabled and supported for current model.
+        """Check if prompt caching is supported and enabled for the current model.
 
         Returns:
-            boolean: True if prompt caching is active for the given model.
+            boolean: True if prompt caching is supported and enabled for the given model.
         """
-        return self.config.caching_prompt is True and any(
-            model in self.config.model for model in cache_prompting_supported_models
+        return (
+            self.config.caching_prompt is True
+            and self.model_info is not None
+            and self.model_info.get('supports_prompt_caching', False)
+            and self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
         )
 
-    def _post_completion(self, response) -> None:
+    def _post_completion(self, response: ModelResponse) -> None:
         """Post-process the completion response.
 
         Logs the cost and usage stats of the completion call.
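Both capability checks, restated as pure functions to make the fallback and allow-list logic explicit (a sketch, not the class methods themselves; the allow-list is passed in rather than read from the module constant):

```python
import litellm


def vision_is_active(model: str, model_info: dict | None) -> bool:
    # supports_vision() can miss provider-prefixed names; fall back to model_info
    return litellm.supports_vision(model) or bool(
        model_info and model_info.get('supports_vision', False)
    )


def is_caching_prompt_active(caching_prompt: bool, model: str, model_info: dict | None, supported: list[str]) -> bool:
    # user opt-in AND litellm support flag AND explicit allow-list membership
    return (
        caching_prompt
        and model_info is not None
        and model_info.get('supports_prompt_caching', False)
        and model in supported
    )
```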
@@ -244,7 +266,7 @@ class LLM(RetryMixin, DebugMixin):
             self.metrics.accumulated_cost,
         )
 
-        usage = response.get('usage')
+        usage: Usage | None = response.get('usage')
 
         if usage:
             # keep track of the input and output tokens
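The usage bookkeeping reads OpenAI-compatible token fields off litellm's `Usage` object; a minimal sketch of what it consumes (the helper name is illustrative):

```python
from litellm.types.utils import Usage


def token_counts(usage: Usage | None) -> tuple[int, int]:
    # prompt_tokens / completion_tokens are the OpenAI-compatible fields
    if usage is None:
        return 0, 0
    return usage.prompt_tokens or 0, usage.completion_tokens or 0
```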
@@ -366,5 +388,12 @@ class LLM(RetryMixin, DebugMixin):
 
     def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
         if isinstance(messages, Message):
-            return [messages.model_dump()]
+            messages = [messages]
+
+        # set flags to know how to serialize the messages
+        for message in messages:
+            message.cache_enabled = self.is_caching_prompt_active()
+            message.vision_enabled = self.vision_is_active()
+
+        # let pydantic handle the serialization
         return [message.model_dump() for message in messages]
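The new flags delegate the serialization decision to the `Message` model itself. A minimal sketch (not the real openhands `Message`, whose content is richer) of how a pydantic `model_serializer` can key off these flags; the `cache_control` payload follows Anthropic's prompt-caching format:

```python
from pydantic import BaseModel, model_serializer


class Message(BaseModel):
    # a minimal stand-in for illustration only
    role: str
    content: str
    cache_enabled: bool = False
    vision_enabled: bool = False

    @model_serializer
    def serialize_model(self) -> dict:
        # vision or caching forces the list-of-blocks content format;
        # cache_enabled additionally tags the block for Anthropic caching
        if self.cache_enabled or self.vision_enabled:
            block: dict = {'type': 'text', 'text': self.content}
            if self.cache_enabled:
                block['cache_control'] = {'type': 'ephemeral'}
            return {'role': self.role, 'content': [block]}
        return {'role': self.role, 'content': self.content}


print(Message(role='user', content='hi', cache_enabled=True).model_dump())
```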