# async_llm.py
import asyncio
from functools import partial

from litellm import acompletion
from litellm import completion as litellm_acompletion

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.runtime.utils.shutdown_listener import should_continue
  8. class AsyncLLM(LLM):
  9. """Asynchronous LLM class."""
  10. def __init__(self, *args, **kwargs):
  11. super().__init__(*args, **kwargs)
  12. self._async_completion = partial(
  13. self._call_acompletion,
  14. model=self.config.model,
  15. api_key=self.config.api_key,
  16. base_url=self.config.base_url,
  17. api_version=self.config.api_version,
  18. custom_llm_provider=self.config.custom_llm_provider,
  19. max_tokens=self.config.max_output_tokens,
  20. timeout=self.config.timeout,
  21. temperature=self.config.temperature,
  22. top_p=self.config.top_p,
  23. drop_params=self.config.drop_params,
  24. )
  25. async_completion_unwrapped = self._async_completion
  26. @self.retry_decorator(
  27. num_retries=self.config.num_retries,
  28. retry_exceptions=self.retry_exceptions,
  29. retry_min_wait=self.config.retry_min_wait,
  30. retry_max_wait=self.config.retry_max_wait,
  31. retry_multiplier=self.config.retry_multiplier,
  32. )
  33. async def async_completion_wrapper(*args, **kwargs):
  34. """Wrapper for the litellm acompletion function."""
  35. # some callers might just send the messages directly
  36. if 'messages' in kwargs:
  37. messages = kwargs['messages']
  38. else:
  39. messages = args[1] if len(args) > 1 else []
  40. if not messages:
  41. raise ValueError(
  42. 'The messages list is empty. At least one message is required.'
  43. )
  44. self.log_prompt(messages)
  45. async def check_stopped():
  46. while should_continue():
  47. if (
  48. hasattr(self.config, 'on_cancel_requested_fn')
  49. and self.config.on_cancel_requested_fn is not None
  50. and await self.config.on_cancel_requested_fn()
  51. ):
  52. raise UserCancelledError('LLM request cancelled by user')
  53. await asyncio.sleep(0.1)
  54. stop_check_task = asyncio.create_task(check_stopped())
  55. try:
  56. # Directly call and await litellm_acompletion
  57. resp = await async_completion_unwrapped(*args, **kwargs)
  58. message_back = resp['choices'][0]['message']['content']
  59. self.log_response(message_back)
  60. self._post_completion(resp)
  61. # We do not support streaming in this method, thus return resp
  62. return resp
  63. except UserCancelledError:
  64. logger.info('LLM request cancelled by user.')
  65. raise
  66. except Exception as e:
  67. logger.error(f'Completion Error occurred:\n{e}')
  68. raise
  69. finally:
  70. await asyncio.sleep(0.1)
  71. stop_check_task.cancel()
  72. try:
  73. await stop_check_task
  74. except asyncio.CancelledError:
  75. pass
  76. self._async_completion = async_completion_wrapper # type: ignore
  77. async def _call_acompletion(self, *args, **kwargs):
  78. """Wrapper for the litellm acompletion function."""
  79. # Used in testing?
  80. return await litellm_acompletion(*args, **kwargs)
  81. @property
  82. def async_completion(self):
  83. """Decorator for the async litellm acompletion function."""
  84. try:
  85. return self._async_completion
  86. except Exception as e:
  87. raise LLMResponseError(e)