# streaming_llm.py
import asyncio
from functools import partial

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.async_llm import AsyncLLM
  6. class StreamingLLM(AsyncLLM):
  7. """Streaming LLM class."""
  8. def __init__(self, *args, **kwargs):
  9. super().__init__(*args, **kwargs)
  10. self._async_streaming_completion = partial(
  11. self._call_acompletion,
  12. model=self.config.model,
  13. api_key=self.config.api_key,
  14. base_url=self.config.base_url,
  15. api_version=self.config.api_version,
  16. custom_llm_provider=self.config.custom_llm_provider,
  17. max_tokens=self.config.max_output_tokens,
  18. timeout=self.config.timeout,
  19. temperature=self.config.temperature,
  20. top_p=self.config.top_p,
  21. drop_params=self.config.drop_params,
  22. stream=True, # Ensure streaming is enabled
  23. )
  24. async_streaming_completion_unwrapped = self._async_streaming_completion
  25. @self.retry_decorator(
  26. num_retries=self.config.num_retries,
  27. retry_exceptions=self.retry_exceptions,
  28. retry_min_wait=self.config.retry_min_wait,
  29. retry_max_wait=self.config.retry_max_wait,
  30. retry_multiplier=self.config.retry_multiplier,
  31. )
  32. async def async_streaming_completion_wrapper(*args, **kwargs):
  33. # some callers might just send the messages directly
  34. if 'messages' in kwargs:
  35. messages = kwargs['messages']
  36. else:
  37. messages = args[1] if len(args) > 1 else []
  38. if not messages:
  39. raise ValueError(
  40. 'The messages list is empty. At least one message is required.'
  41. )
  42. self.log_prompt(messages)
  43. try:
  44. # Directly call and await litellm_acompletion
  45. resp = await async_streaming_completion_unwrapped(*args, **kwargs)
  46. # For streaming we iterate over the chunks
  47. async for chunk in resp:
  48. # Check for cancellation before yielding the chunk
  49. if (
  50. hasattr(self.config, 'on_cancel_requested_fn')
  51. and self.config.on_cancel_requested_fn is not None
  52. and await self.config.on_cancel_requested_fn()
  53. ):
  54. raise UserCancelledError(
  55. 'LLM request cancelled due to CANCELLED state'
  56. )
  57. # with streaming, it is "delta", not "message"!
  58. message_back = chunk['choices'][0]['delta'].get('content', '')
  59. if message_back:
  60. self.log_response(message_back)
  61. self._post_completion(chunk)
  62. yield chunk
  63. except UserCancelledError:
  64. logger.info('LLM request cancelled by user.')
  65. raise
  66. except Exception as e:
  67. logger.error(f'Completion Error occurred:\n{e}')
  68. raise
  69. finally:
  70. # sleep for 0.1 seconds to allow the stream to be flushed
  71. if kwargs.get('stream', False):
  72. await asyncio.sleep(0.1)
  73. self._async_streaming_completion = async_streaming_completion_wrapper
  74. @property
  75. def async_streaming_completion(self):
  76. """Decorator for the async litellm acompletion function with streaming."""
  77. try:
  78. return self._async_streaming_completion
  79. except Exception as e:
  80. raise LLMResponseError(e)