| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import asyncio
- from functools import partial
- from typing import Any
- from openhands.core.exceptions import UserCancelledError
- from openhands.core.logger import openhands_logger as logger
- from openhands.llm.async_llm import LLM_RETRY_EXCEPTIONS, AsyncLLM
class StreamingLLM(AsyncLLM):
    """Async LLM whose completion call always requests a streamed response.

    ``async_streaming_completion`` is an async generator function. Iterating
    its result with ``async for`` yields the raw streaming chunks returned by
    ``_call_acompletion``. Along the way, the prompt and every non-empty
    content delta are logged, and each chunk is passed to ``_post_completion``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Pre-bind connection and sampling options from the config so callers
        # only pass messages plus any per-call overrides. Keyword arguments
        # supplied at call time take precedence over these partial keywords.
        self._async_streaming_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            stream=True,  # Ensure streaming is enabled
        )

        # Keep a handle on the raw partial. The attribute is overwritten with
        # the wrapper below, and the wrapper must call the partial, not itself.
        async_streaming_completion_unwrapped = self._async_streaming_completion

        # NOTE(review): the decorated function is an async generator, so
        # calling it returns immediately and API errors surface only while the
        # caller iterates. Whether retry_decorator retries those errors depends
        # on its implementation, which is not visible here. Confirm.
        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any):
            """Stream a completion, yielding each chunk as it arrives.

            Raises:
                ValueError: if no messages were supplied.
                UserCancelledError: if ``config.on_cancel_requested_fn``
                    reports a cancellation request while streaming.
            """
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # Some callers pass model and messages positionally, litellm-style:
            # completion(model, messages, **kwargs) (see llm.py). The
            # positional model is dropped because the partial already binds
            # model from config. The messages move into kwargs.
            if len(args) > 1:
                messages = args[1]
                kwargs['messages'] = messages
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # Normalise to a list for validation and logging only. kwargs
            # still carries the caller's original value.
            messages = messages if isinstance(messages, list) else [messages]

            # If we have no messages, something went very wrong.
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            try:
                resp = await async_streaming_completion_unwrapped(*args, **kwargs)

                async for chunk in resp:
                    # Honour a cancellation request before handing out each chunk.
                    on_cancel_requested = getattr(
                        self.config, 'on_cancel_requested_fn', None
                    )
                    if on_cancel_requested is not None and await on_cancel_requested():
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )

                    # With streaming, content arrives under "delta", not
                    # "message". Some chunks (e.g. a trailing usage-only chunk)
                    # may carry no choices. Skip logging those instead of
                    # failing with IndexError.
                    choices = chunk['choices']
                    if choices:
                        message_back = choices[0]['delta'].get('content', '')
                        if message_back:
                            self.log_response(message_back)

                    self._post_completion(chunk)
                    yield chunk

            except UserCancelledError:
                logger.debug('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                # Pause briefly so the stream can be flushed. ``stream`` is
                # normally bound in the partial above rather than passed by the
                # caller, so treat a missing key as streaming. Only an explicit
                # stream=False from the caller skips the pause.
                if kwargs.get('stream', True):
                    await asyncio.sleep(0.1)

        self._async_streaming_completion = async_streaming_completion_wrapper

    @property
    def async_streaming_completion(self):
        """Return the retry-decorated streaming completion function.

        Call it with the completion arguments (messages as a keyword, or
        litellm-style ``(model, messages, ...)`` positionals) and iterate the
        result with ``async for`` to receive the streamed chunks.
        """
        return self._async_streaming_completion
|