@@ -1,4 +1,3 @@
-import asyncio
 import copy
 import os
 import time
@@ -7,7 +6,6 @@ from functools import partial
 from typing import Any

 from openhands.core.config import LLMConfig
-from openhands.runtime.utils.shutdown_listener import should_continue

 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -18,41 +16,26 @@ from litellm.exceptions import (
     APIConnectionError,
     ContentPolicyViolationError,
     InternalServerError,
-    NotFoundError,
     OpenAIError,
     RateLimitError,
-    ServiceUnavailableError,
 )
 from litellm.types.utils import CostPerToken
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    retry_if_not_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)

-from openhands.core.exceptions import (
-    OperationCancelled,
-    UserCancelledError,
-)
-from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
-from openhands.runtime.utils.shutdown_listener import should_exit
+from openhands.llm.debug_mixin import DebugMixin
+from openhands.llm.retry_mixin import RetryMixin

 __all__ = ['LLM']

-message_separator = '\n\n----------\n\n'
-
 cache_prompting_supported_models = [
     'claude-3-5-sonnet-20240620',
     'claude-3-haiku-20240307',
 ]


-class LLM:
+class LLM(RetryMixin, DebugMixin):
     """The LLM class represents a Language Model instance.

     Attributes:
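
RetryMixin and DebugMixin (openhands/llm/retry_mixin.py and openhands/llm/debug_mixin.py) are not included in this diff; they absorb the tenacity plumbing and the llm_prompt_logger/llm_response_logger imports removed above. A minimal sketch of the DebugMixin interface assumed by the calls in this file, with the removed message_separator presumably moving along with it:

    # Sketch only: the real debug_mixin.py is not part of this diff.
    from openhands.core.logger import llm_prompt_logger, llm_response_logger

    MESSAGE_SEPARATOR = '\n\n----------\n\n'  # assumed; mirrors the removed message_separator


    class DebugMixin:
        def log_prompt(self, messages):
            # serialize and log the outgoing prompt messages
            messages = messages if isinstance(messages, list) else [messages]
            debug_message = MESSAGE_SEPARATOR.join(
                str(msg['content']) for msg in messages if msg.get('content')
            )
            llm_prompt_logger.debug(debug_message)

        def log_response(self, message_back):
            # log the raw completion text returned by the model
            if message_back:
                llm_response_logger.debug(message_back)
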
@@ -69,7 +52,8 @@ class LLM:
         Passing simple parameters always overrides config.

         Args:
-            config: The LLM configuration
+            config: The LLM configuration.
+            metrics: The metrics to use.
         """
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
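
For context, a minimal usage sketch of the constructor documented above (the model id is illustrative, and the import path for this file is assumed to be openhands/llm/llm.py):

    # Illustrative usage; not part of this change.
    from openhands.core.config import LLMConfig
    from openhands.core.metrics import Metrics
    from openhands.llm.llm import LLM  # assumed import path

    llm = LLM(config=LLMConfig(model='claude-3-5-sonnet-20240620'), metrics=Metrics())
    resp = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
    print(resp['choices'][0]['message']['content'])
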
@@ -135,30 +119,6 @@
         ):
             self.config.max_output_tokens = self.model_info['max_tokens']

-        # This only seems to work with Google as the provider, not with OpenRouter!
-        gemini_safety_settings = (
-            [
-                {
-                    'category': 'HARM_CATEGORY_HARASSMENT',
-                    'threshold': 'BLOCK_NONE',
-                },
-                {
-                    'category': 'HARM_CATEGORY_HATE_SPEECH',
-                    'threshold': 'BLOCK_NONE',
-                },
-                {
-                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
-                    'threshold': 'BLOCK_NONE',
-                },
-                {
-                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
-                    'threshold': 'BLOCK_NONE',
-                },
-            ]
-            if self.config.model.lower().startswith('gemini')
-            else None
-        )
-
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -171,11 +131,6 @@
             temperature=self.config.temperature,
             top_p=self.config.top_p,
             drop_params=self.config.drop_params,
-            **(
-                {'safety_settings': gemini_safety_settings}
-                if gemini_safety_settings is not None
-                else {}
-            ),
         )

         if self.vision_is_active():
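
The hard-coded Gemini safety_settings are no longer bound into the partial above. If a deployment still needs them, they can presumably be supplied per call instead, since extra keyword arguments flow through the wrapped litellm_completion (a sketch under that assumption, not verified for every provider):

    # Sketch: supply provider-specific params at call time rather than at
    # construction time; litellm forwards them to the underlying provider.
    resp = llm.completion(
        messages=[{'role': 'user', 'content': 'Hi'}],
        safety_settings=[
            {'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_NONE'},
        ],
    )
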
@@ -183,61 +138,12 @@

         completion_unwrapped = self._completion

-        def log_retry_attempt(retry_state):
-            """With before_sleep, this is called before `custom_completion_wait` and
-            ONLY if the retry is triggered by an exception."""
-            if should_exit():
-                raise OperationCancelled(
-                    'Operation cancelled.'
-                )  # exits the @retry loop
-            exception = retry_state.outcome.exception()
-            logger.error(
-                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
-                exc_info=False,
-            )
-
-        def custom_completion_wait(retry_state):
-            """Custom wait function for litellm completion."""
-            if not retry_state:
-                return 0
-            exception = retry_state.outcome.exception() if retry_state.outcome else None
-            if exception is None:
-                return 0
-
-            min_wait_time = self.config.retry_min_wait
-            max_wait_time = self.config.retry_max_wait
-
-            # for rate limit errors, wait 1 minute by default, max 4 minutes between retries
-            exception_type = type(exception).__name__
-            logger.error(f'\nexception_type: {exception_type}\n')
-
-            if exception_type == 'RateLimitError':
-                min_wait_time = 60
-                max_wait_time = 240
-            elif exception_type == 'BadRequestError' and exception.response:
-                # this should give us the burried, actual error message from
-                # the LLM model.
-                logger.error(f'\n\nBadRequestError: {exception.response}\n\n')
-
-            # Return the wait time using exponential backoff
-            exponential_wait = wait_exponential(
-                multiplier=self.config.retry_multiplier,
-                min=min_wait_time,
-                max=max_wait_time,
-            )
-
-            # Call the exponential wait function with retry_state to get the actual wait time
-            return exponential_wait(retry_state)
-
-        @retry(
-            before_sleep=log_retry_attempt,
-            stop=stop_after_attempt(self.config.num_retries),
-            reraise=True,
-            retry=(
-                retry_if_exception_type(self.retry_exceptions)
-                & retry_if_not_exception_type(OperationCancelled)
-            ),
-            wait=custom_completion_wait,
+        @self.retry_decorator(
+            num_retries=self.config.num_retries,
+            retry_exceptions=self.retry_exceptions,
+            retry_min_wait=self.config.retry_min_wait,
+            retry_max_wait=self.config.retry_max_wait,
+            retry_multiplier=self.config.retry_multiplier,
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
@@ -247,8 +153,14 @@
             else:
                 messages = args[1] if len(args) > 1 else []

-            # this serves to prevent empty messages and logging the messages
-            debug_message = self._get_debug_message(messages)
+            # if we have no messages, something went very wrong
+            if not messages:
+                raise ValueError(
+                    'The messages list is empty. At least one message is required.'
+                )
+
+            # log the entire LLM prompt
+            self.log_prompt(messages)

             if self.is_caching_prompt_active():
                 # Anthropic-specific prompt caching
@@ -257,239 +169,30 @@
                         'anthropic-beta': 'prompt-caching-2024-07-31',
                     }

-            # skip if messages is empty (thus debug_message is empty)
-            if debug_message:
-                llm_prompt_logger.debug(debug_message)
-                resp = completion_unwrapped(*args, **kwargs)
-            else:
-                logger.debug('No completion messages!')
-                resp = {'choices': [{'message': {'content': ''}}]}
+            resp = completion_unwrapped(*args, **kwargs)

+            # log for evals or other scripts that need the raw completion
             if self.config.log_completions:
                 self.llm_completions.append(
                     {
                         'messages': messages,
                         'response': resp,
                         'timestamp': time.time(),
-                        'cost': self.completion_cost(resp),
+                        'cost': self._completion_cost(resp),
                     }
                 )

-            # log the response
             message_back = resp['choices'][0]['message']['content']
-            if message_back:
-                llm_response_logger.debug(message_back)
-
-            # post-process to log costs
-            self._post_completion(resp)

-            return resp
-
-        self._completion = wrapper # type: ignore
+            # log the LLM response
+            self.log_response(message_back)

-        # Async version
-        self._async_completion = partial(
-            self._call_acompletion,
-            model=self.config.model,
-            api_key=self.config.api_key,
-            base_url=self.config.base_url,
-            api_version=self.config.api_version,
-            custom_llm_provider=self.config.custom_llm_provider,
-            max_tokens=self.config.max_output_tokens,
-            timeout=self.config.timeout,
-            temperature=self.config.temperature,
-            top_p=self.config.top_p,
-            drop_params=self.config.drop_params,
-            **(
-                {'safety_settings': gemini_safety_settings}
-                if gemini_safety_settings is not None
-                else {}
-            ),
-        )
-
-        async_completion_unwrapped = self._async_completion
-
-        @retry(
-            before_sleep=log_retry_attempt,
-            stop=stop_after_attempt(self.config.num_retries),
-            reraise=True,
-            retry=(
-                retry_if_exception_type(self.retry_exceptions)
-                & retry_if_not_exception_type(OperationCancelled)
-            ),
-            wait=custom_completion_wait,
-        )
-        async def async_completion_wrapper(*args, **kwargs):
-            """Async wrapper for the litellm acompletion function."""
-            # some callers might just send the messages directly
-            if 'messages' in kwargs:
-                messages = kwargs['messages']
-            else:
-                messages = args[1] if len(args) > 1 else []
+            # post-process the response
+            self._post_completion(resp)

-            # this serves to prevent empty messages and logging the messages
-            debug_message = self._get_debug_message(messages)
-
-            async def check_stopped():
-                while should_continue():
-                    if (
-                        hasattr(self.config, 'on_cancel_requested_fn')
-                        and self.config.on_cancel_requested_fn is not None
-                        and await self.config.on_cancel_requested_fn()
-                    ):
-                        raise UserCancelledError('LLM request cancelled by user')
-                    await asyncio.sleep(0.1)
-
-            stop_check_task = asyncio.create_task(check_stopped())
-
-            try:
-                # Directly call and await litellm_acompletion
-                if debug_message:
-                    llm_prompt_logger.debug(debug_message)
-                    resp = await async_completion_unwrapped(*args, **kwargs)
-                else:
-                    logger.debug('No completion messages!')
-                    resp = {'choices': [{'message': {'content': ''}}]}
-
-                # skip if messages is empty (thus debug_message is empty)
-                if debug_message:
-                    message_back = resp['choices'][0]['message']['content']
-                    llm_response_logger.debug(message_back)
-                else:
-                    resp = {'choices': [{'message': {'content': ''}}]}
-                self._post_completion(resp)
-
-                # We do not support streaming in this method, thus return resp
-                return resp
-
-            except UserCancelledError:
-                logger.info('LLM request cancelled by user.')
-                raise
-            except (
-                APIConnectionError,
-                ContentPolicyViolationError,
-                InternalServerError,
-                NotFoundError,
-                OpenAIError,
-                RateLimitError,
-                ServiceUnavailableError,
-            ) as e:
-                logger.error(f'Completion Error occurred:\n{e}')
-                raise
-
-            finally:
-                await asyncio.sleep(0.1)
-                stop_check_task.cancel()
-                try:
-                    await stop_check_task
-                except asyncio.CancelledError:
-                    pass
-
-        @retry(
-            before_sleep=log_retry_attempt,
-            stop=stop_after_attempt(self.config.num_retries),
-            reraise=True,
-            retry=(
-                retry_if_exception_type(self.retry_exceptions)
-                & retry_if_not_exception_type(OperationCancelled)
-            ),
-            wait=custom_completion_wait,
-        )
-        async def async_acompletion_stream_wrapper(*args, **kwargs):
-            """Async wrapper for the litellm acompletion with streaming function."""
-            # some callers might just send the messages directly
-            if 'messages' in kwargs:
-                messages = kwargs['messages']
-            else:
-                messages = args[1] if len(args) > 1 else []
-
-            # log the prompt
-            debug_message = ''
-            for message in messages:
-                debug_message += message_separator + message['content']
-            llm_prompt_logger.debug(debug_message)
-
-            try:
-                # Directly call and await litellm_acompletion
-                resp = await async_completion_unwrapped(*args, **kwargs)
-
-                # For streaming we iterate over the chunks
-                async for chunk in resp:
-                    # Check for cancellation before yielding the chunk
-                    if (
-                        hasattr(self.config, 'on_cancel_requested_fn')
-                        and self.config.on_cancel_requested_fn is not None
-                        and await self.config.on_cancel_requested_fn()
-                    ):
-                        raise UserCancelledError(
-                            'LLM request cancelled due to CANCELLED state'
-                        )
-                    # with streaming, it is "delta", not "message"!
-                    message_back = chunk['choices'][0]['delta']['content']
-                    llm_response_logger.debug(message_back)
-                    self._post_completion(chunk)
-
-                    yield chunk
-
-            except UserCancelledError:
-                logger.info('LLM request cancelled by user.')
-                raise
-            except (
-                APIConnectionError,
-                ContentPolicyViolationError,
-                InternalServerError,
-                NotFoundError,
-                OpenAIError,
-                RateLimitError,
-                ServiceUnavailableError,
-            ) as e:
-                logger.error(f'Completion Error occurred:\n{e}')
-                raise
-
-            finally:
-                if kwargs.get('stream', False):
-                    await asyncio.sleep(0.1)
-
-        self._async_completion = async_completion_wrapper # type: ignore
-        self._async_streaming_completion = async_acompletion_stream_wrapper # type: ignore
-
-    def _get_debug_message(self, messages):
-        if not messages:
-            return ''
-
-        messages = messages if isinstance(messages, list) else [messages]
-        return message_separator.join(
-            self._format_message_content(msg) for msg in messages if msg['content']
-        )
-
-    def _format_message_content(self, message):
-        content = message['content']
-        if isinstance(content, list):
-            return self._format_list_content(content)
-        return str(content)
-
-    def _format_list_content(self, content_list):
-        return '\n'.join(
-            self._format_content_element(element) for element in content_list
-        )
-
-    def _format_content_element(self, element):
-        if isinstance(element, dict):
-            if 'text' in element:
-                return element['text']
-            if (
-                self.vision_is_active()
-                and 'image_url' in element
-                and 'url' in element['image_url']
-            ):
-                return element['image_url']['url']
-        return str(element)
+            return resp

-    async def _call_acompletion(self, *args, **kwargs):
-        """This is a wrapper for the litellm acompletion function which
-        makes it mockable for testing.
-        """
-        return await litellm.acompletion(*args, **kwargs)
+        self._completion = wrapper

     @property
     def completion(self):
@@ -499,22 +202,6 @@
         """
         return self._completion

-    @property
-    def async_completion(self):
-        """Decorator for the async litellm acompletion function.
-
-        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
-        """
-        return self._async_completion
-
-    @property
-    def async_streaming_completion(self):
-        """Decorator for the async litellm acompletion function with streaming.
-
-        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
-        """
-        return self._async_streaming_completion
-
     def vision_is_active(self):
         return not self.config.disable_vision and self._supports_vision()
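
With async_completion and async_streaming_completion removed, async callers are left to wrap the synchronous completion themselves. One possible pattern (illustrative only, not part of this change):

    # Illustrative: run the blocking completion in a worker thread so the
    # asyncio event loop is not blocked while the LLM call is in flight.
    import asyncio


    async def acompletion(llm, messages):
        return await asyncio.to_thread(llm.completion, messages=messages)
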
@@ -540,14 +227,18 @@
         )

     def _post_completion(self, response) -> None:
-        """Post-process the completion response."""
+        """Post-process the completion response.
+
+        Logs the cost and usage stats of the completion call.
+        """
         try:
-            cur_cost = self.completion_cost(response)
+            cur_cost = self._completion_cost(response)
         except Exception:
             cur_cost = 0

         stats = ''
         if self.cost_metric_supported:
+            # keep track of the cost
             stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                 cur_cost,
                 self.metrics.accumulated_cost,
@@ -556,6 +247,7 @@
         usage = response.get('usage')

         if usage:
+            # keep track of the input and output tokens
             input_tokens = usage.get('prompt_tokens')
             output_tokens = usage.get('completion_tokens')

@@ -570,6 +262,7 @@
                 + '\n'
             )

+            # read the prompt caching status as received from the provider
             model_extra = usage.get('model_extra', {})

             cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
@@ -586,6 +279,7 @@
                     'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
                 )

+        # log the stats
         if stats:
             logger.info(stats)

@@ -604,7 +298,7 @@
             # TODO: this is to limit logspam in case token count is not supported
             return 0

-    def is_local(self):
+    def _is_local(self):
         """Determines if the system is using a locally running LLM.

         Returns:
@@ -619,7 +313,7 @@
                 return True
         return False

-    def completion_cost(self, response):
+    def _completion_cost(self, response):
         """Calculate the cost of a completion response based on the model. Local models are treated as free.
         Add the current cost into total cost in metrics.

@@ -644,7 +338,7 @@
             logger.info(f'Using custom cost per token: {cost_per_token}')
             extra_kwargs['custom_cost_per_token'] = cost_per_token

-        if not self.is_local():
+        if not self._is_local():
             try:
                 cost = litellm_completion_cost(
                     completion_response=response, **extra_kwargs
|