async_llm.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import asyncio
  2. from functools import partial
  3. from typing import Any
  4. from litellm import acompletion as litellm_acompletion
  5. from openhands.core.exceptions import UserCancelledError
  6. from openhands.core.logger import openhands_logger as logger
  7. from openhands.llm.llm import LLM, LLM_RETRY_EXCEPTIONS
  8. from openhands.utils.shutdown_listener import should_continue
  9. class AsyncLLM(LLM):
  10. """Asynchronous LLM class."""
  11. def __init__(self, *args, **kwargs):
  12. super().__init__(*args, **kwargs)
  13. self._async_completion = partial(
  14. self._call_acompletion,
  15. model=self.config.model,
  16. api_key=self.config.api_key,
  17. base_url=self.config.base_url,
  18. api_version=self.config.api_version,
  19. custom_llm_provider=self.config.custom_llm_provider,
  20. max_tokens=self.config.max_output_tokens,
  21. timeout=self.config.timeout,
  22. temperature=self.config.temperature,
  23. top_p=self.config.top_p,
  24. drop_params=self.config.drop_params,
  25. )
  26. async_completion_unwrapped = self._async_completion
  27. @self.retry_decorator(
  28. num_retries=self.config.num_retries,
  29. retry_exceptions=LLM_RETRY_EXCEPTIONS,
  30. retry_min_wait=self.config.retry_min_wait,
  31. retry_max_wait=self.config.retry_max_wait,
  32. retry_multiplier=self.config.retry_multiplier,
  33. )
  34. async def async_completion_wrapper(*args, **kwargs):
  35. """Wrapper for the litellm acompletion function that adds logging and cost tracking."""
  36. messages: list[dict[str, Any]] | dict[str, Any] = []
  37. # some callers might send the model and messages directly
  38. # litellm allows positional args, like completion(model, messages, **kwargs)
  39. # see llm.py for more details
  40. if len(args) > 1:
  41. messages = args[1] if len(args) > 1 else args[0]
  42. kwargs['messages'] = messages
  43. # remove the first args, they're sent in kwargs
  44. args = args[2:]
  45. elif 'messages' in kwargs:
  46. messages = kwargs['messages']
  47. # ensure we work with a list of messages
  48. messages = messages if isinstance(messages, list) else [messages]
  49. # if we have no messages, something went very wrong
  50. if not messages:
  51. raise ValueError(
  52. 'The messages list is empty. At least one message is required.'
  53. )
  54. self.log_prompt(messages)
  55. async def check_stopped():
  56. while should_continue():
  57. if (
  58. hasattr(self.config, 'on_cancel_requested_fn')
  59. and self.config.on_cancel_requested_fn is not None
  60. and await self.config.on_cancel_requested_fn()
  61. ):
  62. return
  63. await asyncio.sleep(0.1)
  64. stop_check_task = asyncio.create_task(check_stopped())
  65. try:
  66. # Directly call and await litellm_acompletion
  67. resp = await async_completion_unwrapped(*args, **kwargs)
  68. message_back = resp['choices'][0]['message']['content']
  69. self.log_response(message_back)
  70. # log costs and tokens used
  71. self._post_completion(resp)
  72. # We do not support streaming in this method, thus return resp
  73. return resp
  74. except UserCancelledError:
  75. logger.debug('LLM request cancelled by user.')
  76. raise
  77. except Exception as e:
  78. logger.error(f'Completion Error occurred:\n{e}')
  79. raise
  80. finally:
  81. await asyncio.sleep(0.1)
  82. stop_check_task.cancel()
  83. try:
  84. await stop_check_task
  85. except asyncio.CancelledError:
  86. pass
  87. self._async_completion = async_completion_wrapper # type: ignore
  88. async def _call_acompletion(self, *args, **kwargs):
  89. """Wrapper for the litellm acompletion function."""
  90. # Used in testing?
  91. return await litellm_acompletion(*args, **kwargs)
  92. @property
  93. def async_completion(self):
  94. """Decorator for the async litellm acompletion function."""
  95. return self._async_completion