| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import asyncio
- from functools import partial
- from typing import Any
- from litellm import acompletion as litellm_acompletion
- from openhands.core.exceptions import UserCancelledError
- from openhands.core.logger import openhands_logger as logger
- from openhands.llm.llm import LLM, LLM_RETRY_EXCEPTIONS
- from openhands.utils.shutdown_listener import should_continue
class AsyncLLM(LLM):
    """Asynchronous LLM class.

    Extends ``LLM`` with ``async_completion``: an awaitable, non-streaming
    completion call that is assembled once in ``__init__`` (config-bound
    parameters + retry policy + logging + cancellation watchdog).
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base ``LLM`` and build the async completion callable.

        All positional and keyword arguments are forwarded unchanged to
        ``LLM.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Pre-bind the request parameters read from the config. Because they
        # are bound as keywords via ``partial``, a caller may still override
        # any of them by passing the same keyword at call time.
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        # Keep a handle on the bare partial: ``self._async_completion`` is
        # rebound to the wrapper below, and the wrapper must call the bare one.
        async_completion_unwrapped = self._async_completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Wrapper for the litellm acompletion function that adds logging and cost tracking.

            Accepts either ``(model, messages, **kwargs)`` positionally or
            ``messages=...`` as a keyword. Streaming is not supported here;
            the full response object is returned.

            Raises:
                ValueError: if no messages were supplied.
                UserCancelledError: if ``config.on_cancel_requested_fn``
                    reports a cancellation before the response arrives.
            """
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            # see llm.py for more details
            if len(args) > 1:
                messages = args[1]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                # (the positional model is dropped; the config-bound one is used)
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            async def check_stopped() -> bool:
                """Poll for a stop condition every 0.1 s.

                Returns True when the cancel callback reports a cancellation,
                False when ``should_continue()`` turns false.
                """
                while should_continue():
                    on_cancel_requested = getattr(
                        self.config, 'on_cancel_requested_fn', None
                    )
                    if on_cancel_requested is not None and await on_cancel_requested():
                        return True
                    await asyncio.sleep(0.1)
                return False

            # Run the request and the watchdog side by side so that a
            # cancellation can actually interrupt the in-flight request.
            completion_task = asyncio.ensure_future(
                async_completion_unwrapped(*args, **kwargs)
            )
            stop_check_task = asyncio.create_task(check_stopped())

            try:
                done, _ = await asyncio.wait(
                    {completion_task, stop_check_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # If the response is already in, it wins over a simultaneous
                # cancellation. Otherwise the watchdog finished first:
                # ``result()`` is True for a user cancellation and re-raises
                # any error the cancel callback itself produced.
                if completion_task not in done and stop_check_task.result():
                    raise UserCancelledError('LLM request cancelled by user')

                # Either the response is ready, or the watchdog ended without a
                # user cancellation; in both cases wait for the response.
                resp = await completion_task

                message_back = resp['choices'][0]['message']['content']
                self.log_response(message_back)

                # log costs and tokens used
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp

            except UserCancelledError:
                logger.debug('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise

            finally:
                # Tear down whatever is still running. ``cancel()`` is a no-op
                # on finished tasks. ``return_exceptions=True`` collects the
                # tasks' own outcomes instead of raising them, so they cannot
                # mask the response or the original error, while a
                # cancellation of *this* coroutine still propagates.
                completion_task.cancel()
                stop_check_task.cancel()
                await asyncio.gather(
                    completion_task, stop_check_task, return_exceptions=True
                )

        self._async_completion = async_completion_wrapper  # type: ignore

    async def _call_acompletion(self, *args, **kwargs):
        """Forward the call unchanged to litellm's ``acompletion`` and return its result."""
        # NOTE(review): appears to exist as a seam for tests to patch - confirm.
        return await litellm_acompletion(*args, **kwargs)

    @property
    def async_completion(self):
        """Return the async completion callable built in ``__init__``.

        This is the retry-decorated wrapper around litellm's ``acompletion``,
        not a decorator itself.
        """
        return self._async_completion
|