import copy
import os
import time
import warnings
from functools import partial
from typing import Any

import requests

from openhands.core.config import LLMConfig

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm

from litellm import Message as LiteLLMMessage
from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.utils import create_pretrained_tokenizer

from openhands.core.exceptions import CloudFlareBlockageError
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.fn_call_converter import (
    STOP_WORDS,
    convert_fncall_messages_to_non_fncall_messages,
    convert_non_fncall_messages_to_fncall_messages,
)
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin

__all__ = ['LLM']

# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    # FIXME: APIError is useful on 502 from a proxy for example,
    # but it also retries on other errors that are permanent
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

# models that support prompt caching
# remove this list when gemini and deepseek are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-haiku-20241022',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
]

# models that support function calling
FUNCTION_CALLING_SUPPORTED_MODELS = [
    'claude-3-5-sonnet',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'claude-3.5-haiku',
    'claude-3-5-haiku-20241022',
    'gpt-4o-mini',
    'gpt-4o',
]
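
# Note: the lists above are matched against the full configured model name and
# against the segment after the last '/' (to handle provider prefixes such as
# 'anthropic/claude-3-5-sonnet-20241022'); function calling additionally allows
# a substring match. See is_caching_prompt_active() and
# is_function_calling_active() below.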

class LLM(RetryMixin, DebugMixin):
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use.
        """
        self._tried_model_info = False
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)

        # litellm actually uses base Exception here for unknown model
        self.model_info: ModelInfo | None = None

        if self.config.log_completions:
            if self.config.log_completions_folder is None:
                raise RuntimeError(
                    'log_completions_folder is required when log_completions is enabled'
                )
            os.makedirs(self.config.log_completions_folder, exist_ok=True)

        # call init_model_info to initialize config.max_output_tokens
        # which is used in the partial function
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.init_model_info()
        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')
        if self.is_function_calling_active():
            logger.debug('LLM: model supports function calling')

        # if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
        if self.config.custom_tokenizer is not None:
            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
        else:
            self.tokenizer = None

        # set up the completion function
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        self._completion_unwrapped = self._completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            from openhands.core.utils import json

            messages: list[dict[str, Any]] | dict[str, Any] = []
            mock_function_calling = kwargs.pop('mock_function_calling', False)

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            if len(args) > 1:
                # ignore the first argument if it's provided (it would be the model)
                # design wise: we don't allow overriding the configured values
                # implementation wise: the partial function already sets the model
                # as a kwarg, as well as the other kwargs
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            original_fncall_messages = copy.deepcopy(messages)
            mock_fncall_tools = None
            if mock_function_calling:
                assert (
                    'tools' in kwargs
                ), "'tools' must be in kwargs when mock_function_calling is True"
                messages = convert_fncall_messages_to_non_fncall_messages(
                    messages, kwargs['tools']
                )
                kwargs['messages'] = messages
                kwargs['stop'] = STOP_WORDS
                mock_fncall_tools = kwargs.pop('tools')

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
            self.log_prompt(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            try:
                # Record start time for latency measurement
                start_time = time.time()

                # we don't support streaming here, thus we get a ModelResponse
                resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)

                # Calculate and record latency
                latency = time.time() - start_time
                response_id = resp.get('id', 'unknown')
                self.metrics.add_response_latency(latency, response_id)

                non_fncall_response = copy.deepcopy(resp)
                if mock_function_calling:
                    assert len(resp.choices) == 1
                    assert mock_fncall_tools is not None
                    non_fncall_response_message = resp.choices[0].message
                    fn_call_messages_with_response = (
                        convert_non_fncall_messages_to_fncall_messages(
                            messages + [non_fncall_response_message], mock_fncall_tools
                        )
                    )
                    fn_call_response_message = fn_call_messages_with_response[-1]
                    if not isinstance(fn_call_response_message, LiteLLMMessage):
                        fn_call_response_message = LiteLLMMessage(
                            **fn_call_response_message
                        )
                    resp.choices[0].message = fn_call_response_message

                message_back: str = resp['choices'][0]['message']['content'] or ''
                tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
                if tool_calls:
                    for tool_call in tool_calls:
                        fn_name = tool_call.function.name
                        fn_args = tool_call.function.arguments
                        message_back += f'\nFunction call: {fn_name}({fn_args})'

                # log the LLM response
                self.log_response(message_back)

                # post-process the response first to calculate cost
                cost = self._post_completion(resp)

                # log for evals or other scripts that need the raw completion
                if self.config.log_completions:
                    assert self.config.log_completions_folder is not None
                    log_file = os.path.join(
                        self.config.log_completions_folder,
                        # use the metric model name (for draft editor)
                        f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
                    )

                    # set up the dict to be logged
                    _d = {
                        'messages': messages,
                        'response': resp,
                        'args': args,
                        'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
                        'timestamp': time.time(),
                        'cost': cost,
                    }

                    # if non-native function calling, save messages/response separately
                    if mock_function_calling:
                        # Overwrite response as non-fncall to be consistent with messages
                        _d['response'] = non_fncall_response

                        # Save fncall_messages/response separately
                        _d['fncall_messages'] = original_fncall_messages
                        _d['fncall_response'] = resp
                    with open(log_file, 'w') as f:
                        f.write(json.dumps(_d))

                return resp
            except APIError as e:
                if 'Attention Required! | Cloudflare' in str(e):
                    raise CloudFlareBlockageError(
                        'Request blocked by CloudFlare'
                    ) from e
                raise

        self._completion = wrapper

    @property
    def completion(self):
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion
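
    # Illustrative call pattern for models without native tool support (sketch
    # only; the message and tool payloads below are made-up examples, not part
    # of this module):
    #
    #   resp = llm.completion(
    #       messages=[{'role': 'user', 'content': 'List the files in /tmp'}],
    #       tools=[{'type': 'function', 'function': {'name': 'run', 'parameters': {}}}],
    #       mock_function_calling=True,
    #   )
    #
    # The wrapper converts the tool schema into plain prompting via
    # convert_fncall_messages_to_non_fncall_messages and maps the reply back
    # into a tool call with convert_non_fncall_messages_to_fncall_messages.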

    def init_model_info(self):
        if self._tried_model_info:
            return
        self._tried_model_info = True
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
        except Exception as e:
            logger.debug(f'Error getting model info: {e}')

        if self.config.model.startswith('litellm_proxy/'):
            # If we are using a LiteLLM proxy, get model info from the proxy:
            # GET {base_url}/v1/model/info
            response = requests.get(
                f'{self.config.base_url}/v1/model/info',
                headers={'Authorization': f'Bearer {self.config.api_key}'},
            )
            resp_json = response.json()
            if 'data' not in resp_json:
                logger.error(
                    f'Error getting model info from LiteLLM proxy: {resp_json}'
                )
            all_model_info = resp_json.get('data', [])
            current_model_info = next(
                (
                    info
                    for info in all_model_info
                    if info['model_name']
                    == self.config.model.removeprefix('litellm_proxy/')
                ),
                None,
            )
            if current_model_info:
                self.model_info = current_model_info['model_info']

        # Last two attempts to get model info from the model NAME
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
            # noinspection PyBroadException
            except Exception:
                pass
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split('/')[-1]
                )
            # noinspection PyBroadException
            except Exception:
                pass
        logger.debug(f'Model info: {self.model_info}')

        if self.config.model.startswith('huggingface'):
            # HF doesn't support the OpenAI default value for top_p (1)
            logger.debug(
                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
            )
            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

    def vision_is_active(self) -> bool:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self) -> bool:
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. Return False if model not supported by litellm.
        """
        # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
        # but model_info will have the correct value for some reason.
        # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
        # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
        # Check both the full model name and the name after proxy prefix for vision support
        return (
            litellm.supports_vision(self.config.model)
            or litellm.supports_vision(self.config.model.split('/')[-1])
            or (
                self.model_info is not None
                and self.model_info.get('supports_vision', False)
            )
        )

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is supported and enabled for the current model.

        Returns:
            boolean: True if prompt caching is supported and enabled for the given model.
        """
        return (
            self.config.caching_prompt is True
            and (
                self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
                or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
            )
            # We don't need to look up model_info, because only Anthropic models need the explicit caching breakpoint
        )

    def is_function_calling_active(self) -> bool:
        # Check if model name is in supported list before checking model_info
        model_name_supported = (
            self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
            or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
            or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
        )
        return model_name_supported

    def _post_completion(self, response: ModelResponse) -> float:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.
        """
        try:
            cur_cost = self._completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            # keep track of the cost
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        # Add latency to stats if available
        if self.metrics.response_latencies:
            latest_latency = self.metrics.response_latencies[-1]
            stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency

        usage: Usage | None = response.get('usage')

        if usage:
            # keep track of the input and output tokens
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            # read the prompt cache hit, if any
            prompt_tokens_details: PromptTokensDetails = usage.get(
                'prompt_tokens_details'
            )
            cache_hit_tokens = (
                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
            )
            if cache_hit_tokens:
                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

            # For Anthropic, the cache writes have a different cost than regular input tokens
            # but litellm doesn't separate them in the usage stats
            # so we can read it from the provider-specific extra field
            model_extra = usage.get('model_extra', {})
            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_write_tokens:
                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

        # log the stats
        if stats:
            logger.debug(stats)

        return cur_cost

    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
        """Get the number of tokens in a list of messages. Use dicts for better token counting.

        Args:
            messages (list): A list of messages, either as a list of dicts or as a list of Message objects.

        Returns:
            int: The number of tokens.
        """
        # attempt to convert Message objects to dicts, litellm expects dicts
        if (
            isinstance(messages, list)
            and len(messages) > 0
            and isinstance(messages[0], Message)
        ):
            logger.info(
                'Message objects now include serialized tool calls in token counting'
            )
            messages = self.format_messages_for_llm(messages)  # type: ignore

        # try to get the token count with the default litellm tokenizers
        # or the custom tokenizer if set for this LLM configuration
        try:
            return litellm.token_counter(
                model=self.config.model,
                messages=messages,
                custom_tokenizer=self.tokenizer,
            )
        except Exception as e:
            # limit logspam in case token count is not supported
            logger.error(
                f'Error getting token count for\n model {self.config.model}\n{e}'
                + (
                    f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
                    if self.config.custom_tokenizer is not None
                    else ''
                )
            )
            return 0

    def _is_local(self) -> bool:
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False

    def _completion_cost(self, response) -> float:
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into the total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.debug(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        try:
            # try to get response_cost directly from the response
            cost = getattr(response, '_hidden_params', {}).get('response_cost', None)
            if cost is None:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
            self.metrics.add_cost(cost)
            return cost
        except Exception:
            self.cost_metric_supported = False
            logger.debug('Cost calculation not supported for this model.')
        return 0.0

    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self) -> None:
        self.metrics.reset()

    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            messages = [messages]

        # set flags to know how to serialize the messages
        for message in messages:
            message.cache_enabled = self.is_caching_prompt_active()
            message.vision_enabled = self.vision_is_active()
            message.function_calling_enabled = self.is_function_calling_active()

        # let pydantic handle the serialization
        return [message.model_dump() for message in messages]
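
# --- Illustrative usage (sketch only; LLMConfig fields beyond `model` and the
# exact message payload are assumptions, not part of this module) ---
#
#   from openhands.core.config import LLMConfig
#   from openhands.llm.llm import LLM
#
#   llm = LLM(LLMConfig(model='gpt-4o'))
#   resp = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
#   print(resp['choices'][0]['message']['content'])
#   print(llm.metrics.accumulated_cost)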