streaming_llm.py

import asyncio
from functools import partial
from typing import Any

from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.async_llm import LLM_RETRY_EXCEPTIONS, AsyncLLM


class StreamingLLM(AsyncLLM):
    """Streaming LLM class."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._async_streaming_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            stream=True,  # Ensure streaming is enabled
        )
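
        # Keep a reference to the plain partial; the wrapper defined below
        # layers retry handling and message normalization on top of it.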
        async_streaming_completion_unwrapped = self._async_streaming_completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_streaming_completion_wrapper(*args, **kwargs):
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            # see llm.py for more details
            if len(args) > 1:
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_streaming_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )

                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta'].get('content', '')
                    if message_back:
                        self.log_response(message_back)
                        self._post_completion(chunk)

                    yield chunk
            except UserCancelledError:
                logger.debug('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                # sleep for 0.1 seconds to allow the stream to be flushed
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_streaming_completion = async_streaming_completion_wrapper

    @property
    def async_streaming_completion(self):
        """Decorator for the async litellm acompletion function with streaming."""
        return self._async_streaming_completion
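

# Usage sketch: one hypothetical way to consume the streaming completion.
# Assumes an LLMConfig with at least `model` and `api_key`; the import path,
# field values, model name and key below are illustrative placeholders.
if __name__ == '__main__':
    from openhands.core.config import LLMConfig

    async def _demo() -> None:
        llm = StreamingLLM(config=LLMConfig(model='gpt-4o', api_key='sk-...'))
        messages = [{'role': 'user', 'content': 'Say hello in one sentence.'}]
        # The property returns the retry-wrapped async generator; iterating it
        # yields litellm streaming chunks whose text lives under "delta".
        async for chunk in llm.async_streaming_completion(messages=messages):
            print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)

    asyncio.run(_demo())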