Engel Nyst 1 год назад
Родитель
Commit
0a03c802f5

+ 14 - 2
docs/modules/usage/llms/llms.md

@@ -59,9 +59,21 @@ We have a few guides for running OpenHands with specific model providers:
 
 ### API retries and rate limits
 
-Some LLMs have rate limits and may require retries. OpenHands will automatically retry requests if it receives a 429 error or API connection error.
-You can set the following environment variables to control the number of retries and the time between retries:
+LLM providers typically have rate limits, sometimes very low, and may require retries. OpenHands will automatically retry requests if it receives a Rate Limit Error (429 error code), API connection error, or other transient errors.
+
+You can customize these options as you need for the provider you're using. Check their documentation, and set the following environment variables to control the number of retries and the time between retries:
 
 * `LLM_NUM_RETRIES` (Default of 8)
 * `LLM_RETRY_MIN_WAIT` (Default of 15 seconds)
 * `LLM_RETRY_MAX_WAIT` (Default of 120 seconds)
+* `LLM_RETRY_MULTIPLIER` (Default of 2)
+
+If you are running `openhands` in development mode, you can also set these options to the values you need in the `config.toml` file:
+
+```toml
+[llm]
+num_retries = 8
+retry_min_wait = 15
+retry_max_wait = 120
+retry_multiplier = 2
+```

+ 107 - 0
openhands/llm/async_llm.py

@@ -0,0 +1,107 @@
import asyncio
from functools import partial

# Bug fix: the async path must await litellm's *async* entry point.
# `litellm.completion` is synchronous and returns a ModelResponse, which is
# not awaitable — awaiting it raises TypeError. `litellm.acompletion` is the
# coroutine function this class is built around.
from litellm import acompletion as litellm_acompletion

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
from openhands.runtime.utils.shutdown_listener import should_continue


class AsyncLLM(LLM):
    """Asynchronous LLM class.

    Wraps litellm's async completion with the retry behavior configured on
    the base LLM class, plus cooperative user-cancellation polling while a
    request is in flight.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Bind the static model/provider options once; per-call arguments
        # (messages, etc.) are supplied by the caller.
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        async_completion_unwrapped = self._async_completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=self.retry_exceptions,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Wrapper for the litellm acompletion function.

            Logs the prompt and response, tracks cost, and polls for a
            user-requested cancellation while the request runs.
            """
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            async def check_stopped():
                # Poll the optional cancellation callback until shutdown.
                while should_continue():
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await the bound async completion
                resp = await async_completion_unwrapped(*args, **kwargs)

                message_back = resp['choices'][0]['message']['content']
                self.log_response(message_back)
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp

            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise

            finally:
                # Let the watcher observe one more tick, then shut it down.
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        self._async_completion = async_completion_wrapper  # type: ignore

    async def _call_acompletion(self, *args, **kwargs):
        """Call litellm's async completion; kept separate so tests can mock it."""
        return await litellm_acompletion(*args, **kwargs)

    @property
    def async_completion(self):
        """The decorated async litellm acompletion function."""
        try:
            return self._async_completion
        except Exception as e:
            raise LLMResponseError(e)

+ 53 - 0
openhands/llm/debug_mixin.py

@@ -0,0 +1,53 @@
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger

MESSAGE_SEPARATOR = '\n\n----------\n\n'


class DebugMixin:
    """Mixin providing prompt/response debug logging for LLM classes."""

    def log_prompt(self, messages):
        """Render all non-empty messages and log them to the prompt logger."""
        if not messages:
            logger.debug('No completion messages!')
            return

        if not isinstance(messages, list):
            messages = [messages]

        rendered = [
            self._format_message_content(message)
            for message in messages
            if message['content']
        ]
        debug_message = MESSAGE_SEPARATOR.join(rendered)

        if debug_message:
            llm_prompt_logger.debug(debug_message)
        else:
            logger.debug('No completion messages!')

    def log_response(self, message_back):
        """Log a non-empty LLM response to the response logger."""
        if message_back:
            llm_response_logger.debug(message_back)

    def _format_message_content(self, message):
        """Render one message's content (string or list of elements)."""
        content = message['content']
        if not isinstance(content, list):
            return str(content)
        parts = [self._format_content_element(item) for item in content]
        return '\n'.join(parts)

    def _format_content_element(self, element):
        """Render one content element: text, image URL (vision only), or repr."""
        if isinstance(element, dict):
            if 'text' in element:
                return element['text']
            # Image URLs are only logged when vision is active.
            if self.vision_is_active() and 'image_url' in element:
                if 'url' in element['image_url']:
                    return element['image_url']['url']
        return str(element)

    def _log_stats(self, stats):
        """Log accumulated cost/usage stats, if any."""
        if stats:
            logger.info(stats)

    # This method should be implemented in the class that uses DebugMixin
    def vision_is_active(self):
        raise NotImplementedError

+ 40 - 346
openhands/llm/llm.py

@@ -1,4 +1,3 @@
-import asyncio
 import copy
 import os
 import time
@@ -7,7 +6,6 @@ from functools import partial
 from typing import Any
 
 from openhands.core.config import LLMConfig
-from openhands.runtime.utils.shutdown_listener import should_continue
 
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -18,41 +16,26 @@ from litellm.exceptions import (
     APIConnectionError,
     ContentPolicyViolationError,
     InternalServerError,
-    NotFoundError,
     OpenAIError,
     RateLimitError,
-    ServiceUnavailableError,
 )
 from litellm.types.utils import CostPerToken
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    retry_if_not_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
-from openhands.core.exceptions import (
-    OperationCancelled,
-    UserCancelledError,
-)
-from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
-from openhands.runtime.utils.shutdown_listener import should_exit
+from openhands.llm.debug_mixin import DebugMixin
+from openhands.llm.retry_mixin import RetryMixin
 
 __all__ = ['LLM']
 
-message_separator = '\n\n----------\n\n'
-
 cache_prompting_supported_models = [
     'claude-3-5-sonnet-20240620',
     'claude-3-haiku-20240307',
 ]
 
 
-class LLM:
+class LLM(RetryMixin, DebugMixin):
     """The LLM class represents a Language Model instance.
 
     Attributes:
@@ -69,7 +52,8 @@ class LLM:
         Passing simple parameters always overrides config.
 
         Args:
-            config: The LLM configuration
+            config: The LLM configuration.
+            metrics: The metrics to use.
         """
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
@@ -135,30 +119,6 @@ class LLM:
                 ):
                     self.config.max_output_tokens = self.model_info['max_tokens']
 
-        # This only seems to work with Google as the provider, not with OpenRouter!
-        gemini_safety_settings = (
-            [
-                {
-                    'category': 'HARM_CATEGORY_HARASSMENT',
-                    'threshold': 'BLOCK_NONE',
-                },
-                {
-                    'category': 'HARM_CATEGORY_HATE_SPEECH',
-                    'threshold': 'BLOCK_NONE',
-                },
-                {
-                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
-                    'threshold': 'BLOCK_NONE',
-                },
-                {
-                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
-                    'threshold': 'BLOCK_NONE',
-                },
-            ]
-            if self.config.model.lower().startswith('gemini')
-            else None
-        )
-
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -171,11 +131,6 @@ class LLM:
             temperature=self.config.temperature,
             top_p=self.config.top_p,
             drop_params=self.config.drop_params,
-            **(
-                {'safety_settings': gemini_safety_settings}
-                if gemini_safety_settings is not None
-                else {}
-            ),
         )
 
         if self.vision_is_active():
@@ -183,61 +138,12 @@ class LLM:
 
         completion_unwrapped = self._completion
 
-        def log_retry_attempt(retry_state):
-            """With before_sleep, this is called before `custom_completion_wait` and
-            ONLY if the retry is triggered by an exception."""
-            if should_exit():
-                raise OperationCancelled(
-                    'Operation cancelled.'
-                )  # exits the @retry loop
-            exception = retry_state.outcome.exception()
-            logger.error(
-                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
-                exc_info=False,
-            )
-
-        def custom_completion_wait(retry_state):
-            """Custom wait function for litellm completion."""
-            if not retry_state:
-                return 0
-            exception = retry_state.outcome.exception() if retry_state.outcome else None
-            if exception is None:
-                return 0
-
-            min_wait_time = self.config.retry_min_wait
-            max_wait_time = self.config.retry_max_wait
-
-            # for rate limit errors, wait 1 minute by default, max 4 minutes between retries
-            exception_type = type(exception).__name__
-            logger.error(f'\nexception_type: {exception_type}\n')
-
-            if exception_type == 'RateLimitError':
-                min_wait_time = 60
-                max_wait_time = 240
-            elif exception_type == 'BadRequestError' and exception.response:
-                # this should give us the burried, actual error message from
-                # the LLM model.
-                logger.error(f'\n\nBadRequestError: {exception.response}\n\n')
-
-            # Return the wait time using exponential backoff
-            exponential_wait = wait_exponential(
-                multiplier=self.config.retry_multiplier,
-                min=min_wait_time,
-                max=max_wait_time,
-            )
-
-            # Call the exponential wait function with retry_state to get the actual wait time
-            return exponential_wait(retry_state)
-
-        @retry(
-            before_sleep=log_retry_attempt,
-            stop=stop_after_attempt(self.config.num_retries),
-            reraise=True,
-            retry=(
-                retry_if_exception_type(self.retry_exceptions)
-                & retry_if_not_exception_type(OperationCancelled)
-            ),
-            wait=custom_completion_wait,
+        @self.retry_decorator(
+            num_retries=self.config.num_retries,
+            retry_exceptions=self.retry_exceptions,
+            retry_min_wait=self.config.retry_min_wait,
+            retry_max_wait=self.config.retry_max_wait,
+            retry_multiplier=self.config.retry_multiplier,
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
@@ -247,8 +153,14 @@ class LLM:
             else:
                 messages = args[1] if len(args) > 1 else []
 
-            # this serves to prevent empty messages and logging the messages
-            debug_message = self._get_debug_message(messages)
+            # if we have no messages, something went very wrong
+            if not messages:
+                raise ValueError(
+                    'The messages list is empty. At least one message is required.'
+                )
+
+            # log the entire LLM prompt
+            self.log_prompt(messages)
 
             if self.is_caching_prompt_active():
                 # Anthropic-specific prompt caching
@@ -257,239 +169,30 @@ class LLM:
                         'anthropic-beta': 'prompt-caching-2024-07-31',
                     }
 
-            # skip if messages is empty (thus debug_message is empty)
-            if debug_message:
-                llm_prompt_logger.debug(debug_message)
-                resp = completion_unwrapped(*args, **kwargs)
-            else:
-                logger.debug('No completion messages!')
-                resp = {'choices': [{'message': {'content': ''}}]}
+            resp = completion_unwrapped(*args, **kwargs)
 
+            # log for evals or other scripts that need the raw completion
             if self.config.log_completions:
                 self.llm_completions.append(
                     {
                         'messages': messages,
                         'response': resp,
                         'timestamp': time.time(),
-                        'cost': self.completion_cost(resp),
+                        'cost': self._completion_cost(resp),
                     }
                 )
 
-            # log the response
             message_back = resp['choices'][0]['message']['content']
-            if message_back:
-                llm_response_logger.debug(message_back)
-
-                # post-process to log costs
-                self._post_completion(resp)
 
-            return resp
-
-        self._completion = wrapper  # type: ignore
+            # log the LLM response
+            self.log_response(message_back)
 
-        # Async version
-        self._async_completion = partial(
-            self._call_acompletion,
-            model=self.config.model,
-            api_key=self.config.api_key,
-            base_url=self.config.base_url,
-            api_version=self.config.api_version,
-            custom_llm_provider=self.config.custom_llm_provider,
-            max_tokens=self.config.max_output_tokens,
-            timeout=self.config.timeout,
-            temperature=self.config.temperature,
-            top_p=self.config.top_p,
-            drop_params=self.config.drop_params,
-            **(
-                {'safety_settings': gemini_safety_settings}
-                if gemini_safety_settings is not None
-                else {}
-            ),
-        )
-
-        async_completion_unwrapped = self._async_completion
-
-        @retry(
-            before_sleep=log_retry_attempt,
-            stop=stop_after_attempt(self.config.num_retries),
-            reraise=True,
-            retry=(
-                retry_if_exception_type(self.retry_exceptions)
-                & retry_if_not_exception_type(OperationCancelled)
-            ),
-            wait=custom_completion_wait,
-        )
-        async def async_completion_wrapper(*args, **kwargs):
-            """Async wrapper for the litellm acompletion function."""
-            # some callers might just send the messages directly
-            if 'messages' in kwargs:
-                messages = kwargs['messages']
-            else:
-                messages = args[1] if len(args) > 1 else []
+            # post-process the response
+            self._post_completion(resp)
 
-            # this serves to prevent empty messages and logging the messages
-            debug_message = self._get_debug_message(messages)
-
-            async def check_stopped():
-                while should_continue():
-                    if (
-                        hasattr(self.config, 'on_cancel_requested_fn')
-                        and self.config.on_cancel_requested_fn is not None
-                        and await self.config.on_cancel_requested_fn()
-                    ):
-                        raise UserCancelledError('LLM request cancelled by user')
-                    await asyncio.sleep(0.1)
-
-            stop_check_task = asyncio.create_task(check_stopped())
-
-            try:
-                # Directly call and await litellm_acompletion
-                if debug_message:
-                    llm_prompt_logger.debug(debug_message)
-                    resp = await async_completion_unwrapped(*args, **kwargs)
-                else:
-                    logger.debug('No completion messages!')
-                    resp = {'choices': [{'message': {'content': ''}}]}
-
-                # skip if messages is empty (thus debug_message is empty)
-                if debug_message:
-                    message_back = resp['choices'][0]['message']['content']
-                    llm_response_logger.debug(message_back)
-                else:
-                    resp = {'choices': [{'message': {'content': ''}}]}
-                self._post_completion(resp)
-
-                # We do not support streaming in this method, thus return resp
-                return resp
-
-            except UserCancelledError:
-                logger.info('LLM request cancelled by user.')
-                raise
-            except (
-                APIConnectionError,
-                ContentPolicyViolationError,
-                InternalServerError,
-                NotFoundError,
-                OpenAIError,
-                RateLimitError,
-                ServiceUnavailableError,
-            ) as e:
-                logger.error(f'Completion Error occurred:\n{e}')
-                raise
-
-            finally:
-                await asyncio.sleep(0.1)
-                stop_check_task.cancel()
-                try:
-                    await stop_check_task
-                except asyncio.CancelledError:
-                    pass
-
-        @retry(
-            before_sleep=log_retry_attempt,
-            stop=stop_after_attempt(self.config.num_retries),
-            reraise=True,
-            retry=(
-                retry_if_exception_type(self.retry_exceptions)
-                & retry_if_not_exception_type(OperationCancelled)
-            ),
-            wait=custom_completion_wait,
-        )
-        async def async_acompletion_stream_wrapper(*args, **kwargs):
-            """Async wrapper for the litellm acompletion with streaming function."""
-            # some callers might just send the messages directly
-            if 'messages' in kwargs:
-                messages = kwargs['messages']
-            else:
-                messages = args[1] if len(args) > 1 else []
-
-            # log the prompt
-            debug_message = ''
-            for message in messages:
-                debug_message += message_separator + message['content']
-            llm_prompt_logger.debug(debug_message)
-
-            try:
-                # Directly call and await litellm_acompletion
-                resp = await async_completion_unwrapped(*args, **kwargs)
-
-                # For streaming we iterate over the chunks
-                async for chunk in resp:
-                    # Check for cancellation before yielding the chunk
-                    if (
-                        hasattr(self.config, 'on_cancel_requested_fn')
-                        and self.config.on_cancel_requested_fn is not None
-                        and await self.config.on_cancel_requested_fn()
-                    ):
-                        raise UserCancelledError(
-                            'LLM request cancelled due to CANCELLED state'
-                        )
-                    # with streaming, it is "delta", not "message"!
-                    message_back = chunk['choices'][0]['delta']['content']
-                    llm_response_logger.debug(message_back)
-                    self._post_completion(chunk)
-
-                    yield chunk
-
-            except UserCancelledError:
-                logger.info('LLM request cancelled by user.')
-                raise
-            except (
-                APIConnectionError,
-                ContentPolicyViolationError,
-                InternalServerError,
-                NotFoundError,
-                OpenAIError,
-                RateLimitError,
-                ServiceUnavailableError,
-            ) as e:
-                logger.error(f'Completion Error occurred:\n{e}')
-                raise
-
-            finally:
-                if kwargs.get('stream', False):
-                    await asyncio.sleep(0.1)
-
-        self._async_completion = async_completion_wrapper  # type: ignore
-        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
-
-    def _get_debug_message(self, messages):
-        if not messages:
-            return ''
-
-        messages = messages if isinstance(messages, list) else [messages]
-        return message_separator.join(
-            self._format_message_content(msg) for msg in messages if msg['content']
-        )
-
-    def _format_message_content(self, message):
-        content = message['content']
-        if isinstance(content, list):
-            return self._format_list_content(content)
-        return str(content)
-
-    def _format_list_content(self, content_list):
-        return '\n'.join(
-            self._format_content_element(element) for element in content_list
-        )
-
-    def _format_content_element(self, element):
-        if isinstance(element, dict):
-            if 'text' in element:
-                return element['text']
-            if (
-                self.vision_is_active()
-                and 'image_url' in element
-                and 'url' in element['image_url']
-            ):
-                return element['image_url']['url']
-        return str(element)
+            return resp
 
-    async def _call_acompletion(self, *args, **kwargs):
-        """This is a wrapper for the litellm acompletion function which
-        makes it mockable for testing.
-        """
-        return await litellm.acompletion(*args, **kwargs)
+        self._completion = wrapper
 
     @property
     def completion(self):
@@ -499,22 +202,6 @@ class LLM:
         """
         return self._completion
 
-    @property
-    def async_completion(self):
-        """Decorator for the async litellm acompletion function.
-
-        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
-        """
-        return self._async_completion
-
-    @property
-    def async_streaming_completion(self):
-        """Decorator for the async litellm acompletion function with streaming.
-
-        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
-        """
-        return self._async_streaming_completion
-
     def vision_is_active(self):
         return not self.config.disable_vision and self._supports_vision()
 
@@ -540,14 +227,18 @@ class LLM:
         )
 
     def _post_completion(self, response) -> None:
-        """Post-process the completion response."""
+        """Post-process the completion response.
+
+        Logs the cost and usage stats of the completion call.
+        """
         try:
-            cur_cost = self.completion_cost(response)
+            cur_cost = self._completion_cost(response)
         except Exception:
             cur_cost = 0
 
         stats = ''
         if self.cost_metric_supported:
+            # keep track of the cost
             stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                 cur_cost,
                 self.metrics.accumulated_cost,
@@ -556,6 +247,7 @@ class LLM:
         usage = response.get('usage')
 
         if usage:
+            # keep track of the input and output tokens
             input_tokens = usage.get('prompt_tokens')
             output_tokens = usage.get('completion_tokens')
 
@@ -570,6 +262,7 @@ class LLM:
                     + '\n'
                 )
 
+            # read the prompt caching status as received from the provider
             model_extra = usage.get('model_extra', {})
 
             cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
@@ -586,6 +279,7 @@ class LLM:
                     'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
                 )
 
+        # log the stats
         if stats:
             logger.info(stats)
 
@@ -604,7 +298,7 @@ class LLM:
             # TODO: this is to limit logspam in case token count is not supported
             return 0
 
-    def is_local(self):
+    def _is_local(self):
         """Determines if the system is using a locally running LLM.
 
         Returns:
@@ -619,7 +313,7 @@ class LLM:
                 return True
         return False
 
-    def completion_cost(self, response):
+    def _completion_cost(self, response):
         """Calculate the cost of a completion response based on the model.  Local models are treated as free.
         Add the current cost into total cost in metrics.
 
@@ -644,7 +338,7 @@ class LLM:
             logger.info(f'Using custom cost per token: {cost_per_token}')
             extra_kwargs['custom_cost_per_token'] = cost_per_token
 
-        if not self.is_local():
+        if not self._is_local():
             try:
                 cost = litellm_completion_cost(
                     completion_response=response, **extra_kwargs

+ 53 - 0
openhands/llm/retry_mixin.py

@@ -0,0 +1,53 @@
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from openhands.core.exceptions import OperationCancelled
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.utils.shutdown_listener import should_exit


class RetryMixin:
    """Mixin class for retry logic."""

    def retry_decorator(self, **kwargs):
        """
        Build a tenacity retry decorator for LLM calls. This is used for 429
        errors, and a few other transient exceptions, in the LLM classes.

        Args:
            **kwargs: Keyword arguments to override default retry behavior.
                      Keys: num_retries, retry_exceptions, retry_min_wait,
                      retry_max_wait, retry_multiplier

        Returns:
            A retry decorator with the parameters customizable in configuration.
        """
        wait_strategy = wait_exponential(
            multiplier=kwargs.get('retry_multiplier'),
            min=kwargs.get('retry_min_wait'),
            max=kwargs.get('retry_max_wait'),
        )
        return retry(
            before_sleep=self.log_retry_attempt,
            stop=stop_after_attempt(kwargs.get('num_retries')),
            reraise=True,
            retry=retry_if_exception_type(kwargs.get('retry_exceptions')),
            wait=wait_strategy,
        )

    def log_retry_attempt(self, retry_state):
        """Log retry attempts; abort the retry loop on process shutdown."""
        if should_exit():
            # Raising here exits the @retry loop.
            raise OperationCancelled('Operation cancelled.')
        exception = retry_state.outcome.exception()
        logger.error(
            f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
            exc_info=False,
        )

+ 96 - 0
openhands/llm/streaming_llm.py

@@ -0,0 +1,96 @@
import asyncio
from functools import partial

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.async_llm import AsyncLLM


class StreamingLLM(AsyncLLM):
    """Streaming LLM class."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Bind the static model/provider options once, with streaming on.
        self._async_streaming_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            stream=True,  # Ensure streaming is enabled
        )

        streaming_unwrapped = self._async_streaming_completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=self.retry_exceptions,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_streaming_completion_wrapper(*args, **kwargs):
            """Stream chunks from litellm, logging each and honoring cancellation."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            try:
                # Directly call and await litellm_acompletion
                resp = await streaming_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    cancel_fn = getattr(self.config, 'on_cancel_requested_fn', None)
                    if cancel_fn is not None and await cancel_fn():
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta'].get('content', '')
                    if message_back:
                        self.log_response(message_back)
                    self._post_completion(chunk)

                    yield chunk

            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise

            finally:
                # sleep for 0.1 seconds to allow the stream to be flushed
                # NOTE(review): 'stream' is bound in the partial above, so this
                # check only sees a caller-supplied kwarg — confirm intent.
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_streaming_completion = async_streaming_completion_wrapper

    @property
    def async_streaming_completion(self):
        """Decorator for the async litellm acompletion function with streaming."""
        try:
            return self._async_streaming_completion
        except Exception as e:
            raise LLMResponseError(e)

+ 5 - 5
tests/integration/conftest.py

@@ -11,7 +11,7 @@ from http.server import HTTPServer, SimpleHTTPRequestHandler
 import pytest
 from litellm import completion
 
-from openhands.llm.llm import message_separator
+from openhands.llm.debug_mixin import MESSAGE_SEPARATOR
 
 script_dir = os.environ.get('SCRIPT_DIR')
 project_root = os.environ.get('PROJECT_ROOT')
@@ -81,19 +81,19 @@ def _format_messages(messages):
     message_str = ''
     for message in messages:
         if isinstance(message, str):
-            message_str += message_separator + message if message_str else message
+            message_str += MESSAGE_SEPARATOR + message if message_str else message
         elif isinstance(message, dict):
             if isinstance(message['content'], list):
                 for m in message['content']:
                     if isinstance(m, str):
-                        message_str += message_separator + m if message_str else m
+                        message_str += MESSAGE_SEPARATOR + m if message_str else m
                     elif isinstance(m, dict) and m['type'] == 'text':
                         message_str += (
-                            message_separator + m['text'] if message_str else m['text']
+                            MESSAGE_SEPARATOR + m['text'] if message_str else m['text']
                         )
             elif isinstance(message['content'], str):
                 message_str += (
-                    message_separator + message['content']
+                    MESSAGE_SEPARATOR + message['content']
                     if message_str
                     else message['content']
                 )

+ 10 - 8
tests/unit/test_acompletion.py

@@ -5,7 +5,9 @@ import pytest
 
 from openhands.core.config import load_app_config
 from openhands.core.exceptions import UserCancelledError
+from openhands.llm.async_llm import AsyncLLM
 from openhands.llm.llm import LLM
+from openhands.llm.streaming_llm import StreamingLLM
 
 config = load_app_config()
 
@@ -39,12 +41,12 @@ def mock_response():
 
 @pytest.mark.asyncio
 async def test_acompletion_non_streaming():
-    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
+    with patch.object(AsyncLLM, '_call_acompletion') as mock_call_acompletion:
         mock_response = {
             'choices': [{'message': {'content': 'This is a test message.'}}]
         }
         mock_call_acompletion.return_value = mock_response
-        test_llm = LLM(config=config.get_llm_config())
+        test_llm = AsyncLLM(config=config.get_llm_config())
         response = await test_llm.async_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
             stream=False,
@@ -56,9 +58,9 @@ async def test_acompletion_non_streaming():
 
 @pytest.mark.asyncio
 async def test_acompletion_streaming(mock_response):
-    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
+    with patch.object(StreamingLLM, '_call_acompletion') as mock_call_acompletion:
         mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
-        test_llm = LLM(config=config.get_llm_config())
+        test_llm = StreamingLLM(config=config.get_llm_config())
         async for chunk in test_llm.async_streaming_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
         ):
@@ -104,10 +106,10 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
         return {'choices': [{'message': {'content': 'This is a test message.'}}]}
 
     with patch.object(
-        LLM, '_call_acompletion', new_callable=AsyncMock
+        AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.side_effect = mock_acompletion
-        test_llm = LLM(config=config.get_llm_config())
+        test_llm = AsyncLLM(config=config.get_llm_config())
 
         async def cancel_after_delay():
             print(f'Starting cancel_after_delay with delay {cancel_delay}')
@@ -166,10 +168,10 @@ async def test_async_streaming_completion_with_user_cancellation(cancel_after_ch
             await asyncio.sleep(0.05)  # Simulate some delay between chunks
 
     with patch.object(
-        LLM, '_call_acompletion', new_callable=AsyncMock
+        AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.return_value = mock_acompletion()
-        test_llm = LLM(config=config.get_llm_config())
+        test_llm = StreamingLLM(config=config.get_llm_config())
 
         received_chunks = []
         with pytest.raises(UserCancelledError):

+ 4 - 4
tests/unit/test_llm.py

@@ -19,8 +19,8 @@ from openhands.llm.llm import LLM
 def mock_logger(monkeypatch):
     # suppress logging of completion data to file
     mock_logger = MagicMock()
-    monkeypatch.setattr('openhands.llm.llm.llm_prompt_logger', mock_logger)
-    monkeypatch.setattr('openhands.llm.llm.llm_response_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
     return mock_logger
 
 
@@ -197,8 +197,8 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
         mock_sleep.assert_called_once()
         wait_time = mock_sleep.call_args[0][0]
         assert (
-            60 <= wait_time <= 240
-        ), f'Expected wait time between 60 and 240 seconds, but got {wait_time}'
+            default_config.retry_min_wait <= wait_time <= default_config.retry_max_wait
+        ), f'Expected wait time between {default_config.retry_min_wait} and {default_config.retry_max_wait} seconds, but got {wait_time}'
 
 
 @patch('openhands.llm.llm.litellm_completion')