
(fix) CodeActAgent/LLM: react on should_exit flag (user cancellation) (#3968)

tobitege 1 year ago
parent
commit
01462e11d7

+ 4 - 0
agenthub/codeact_agent/codeact_agent.py

@@ -5,6 +5,7 @@ from agenthub.codeact_agent.action_parser import CodeActResponseParser
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import ImageContent, Message, TextContent
 from openhands.events.action import (
@@ -211,8 +212,11 @@ class CodeActAgent(Agent):
                 'anthropic-beta': 'prompt-caching-2024-07-31',
             }
 
+        # TODO: move exception handling to agent_controller
         try:
             response = self.llm.completion(**params)
+        except OperationCancelled as e:
+            raise e
         except Exception as e:
             logger.error(f'{e}')
             error_message = '{}: {}'.format(type(e).__name__, str(e).split('\n')[0])
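
A minimal sketch (not part of this commit) of why the narrow except clause has to come before the broad one: without it, the generic handler would turn a cancellation into an ordinary error message and the agent would keep running. All names below are illustrative stand-ins.

class OperationCancelled(Exception):
    """Stand-in for openhands.core.exceptions.OperationCancelled."""

def call_llm_step(completion):
    try:
        return completion()
    except OperationCancelled:
        raise  # propagate the cancellation to the caller instead of swallowing it
    except Exception as e:
        # ordinary failures become an error message, as in the agent's handler
        return '{}: {}'.format(type(e).__name__, str(e).split('\n')[0])

def cancelled_completion():
    raise OperationCancelled('Operation was cancelled')

try:
    call_llm_step(cancelled_completion)
except OperationCancelled as e:
    print(f'cancelled: {e}')  # the cancellation reaches the caller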

+ 7 - 0
openhands/core/exceptions.py

@@ -77,3 +77,10 @@ class UserCancelledError(Exception):
 class MicroAgentValidationError(Exception):
     def __init__(self, message='Micro agent validation failed'):
         super().__init__(message)
+
+
+class OperationCancelled(Exception):
+    """Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
+
+    def __init__(self, message='Operation was cancelled'):
+        super().__init__(message)
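
A quick usage sketch of the new exception class (the import path matches this commit; the surrounding code is illustrative): the default message keeps a bare raise readable in logs.

from openhands.core.exceptions import OperationCancelled

try:
    raise OperationCancelled()  # falls back to the default message
except OperationCancelled as e:
    print(e)  # Operation was cancelled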

+ 41 - 13
openhands/llm/llm.py

@@ -24,15 +24,21 @@ from litellm.types.utils import CostPerToken
 from tenacity import (
     retry,
     retry_if_exception_type,
+    retry_if_not_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
 
-from openhands.core.exceptions import LLMResponseError, UserCancelledError
+from openhands.core.exceptions import (
+    LLMResponseError,
+    OperationCancelled,
+    UserCancelledError,
+)
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
+from openhands.runtime.utils.shutdown_listener import should_exit
 
 __all__ = ['LLM']
 
@@ -169,13 +175,18 @@ class LLM:
 
         completion_unwrapped = self._completion
 
-        def attempt_on_error(retry_state):
-            """Custom attempt function for litellm completion."""
+        def log_retry_attempt(retry_state):
+            """With before_sleep, this is called before `custom_completion_wait` and
+            ONLY if the retry is triggered by an exception."""
+            if should_exit():
+                raise OperationCancelled(
+                    'Operation cancelled.'
+                )  # exits the @retry loop
+            exception = retry_state.outcome.exception()
             logger.error(
-                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
+                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                 exc_info=False,
             )
-            return None
 
         def custom_completion_wait(retry_state):
             """Custom wait function for litellm completion."""
@@ -211,10 +222,13 @@ class LLM:
             return exponential_wait(retry_state)
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         def wrapper(*args, **kwargs):
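
The combined predicate relies on tenacity's support for composing retry conditions with &. A minimal sketch (exception names are stand-ins) showing that cancellation is excluded from retries while other errors still retry until the attempt limit:

from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_fixed,
)

class Cancelled(Exception):
    pass

attempts = {'n': 0}

@retry(
    retry=(
        retry_if_exception_type(Exception)        # broad allow-list, like retry_exceptions
        & retry_if_not_exception_type(Cancelled)  # but never retry a cancellation
    ),
    stop=stop_after_attempt(3),
    wait=wait_fixed(0),
    reraise=True,
)
def call_provider(exc):
    attempts['n'] += 1
    raise exc

try:
    call_provider(ValueError('transient'))
except ValueError:
    print(attempts['n'])  # 3: retried until attempts were exhausted

attempts['n'] = 0
try:
    call_provider(Cancelled('user cancelled'))
except Cancelled:
    print(attempts['n'])  # 1: cancellation is re-raised immediately, no retry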
@@ -278,10 +292,13 @@ class LLM:
         async_completion_unwrapped = self._async_completion
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_completion_wrapper(*args, **kwargs):
@@ -351,10 +368,13 @@ class LLM:
                     pass
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_acompletion_stream_wrapper(*args, **kwargs):
@@ -448,6 +468,9 @@ class LLM:
         return str(element)
 
     async def _call_acompletion(self, *args, **kwargs):
+        """This is a wrapper for the litellm acompletion function which
+        makes it mockable for testing.
+        """
         return await litellm.acompletion(*args, **kwargs)
 
     @property
@@ -528,10 +551,15 @@ class LLM:
             output_tokens = usage.get('completion_tokens')
 
             if input_tokens:
-                stats += 'Input tokens: ' + str(input_tokens) + '\n'
+                stats += 'Input tokens: ' + str(input_tokens)
 
             if output_tokens:
-                stats += 'Output tokens: ' + str(output_tokens) + '\n'
+                stats += (
+                    (' | ' if input_tokens else '')
+                    + 'Output tokens: '
+                    + str(output_tokens)
+                    + '\n'
+                )
 
             model_extra = usage.get('model_extra', {})
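
For reference, a tiny standalone sketch of the new one-line stats format; format_token_stats is a hypothetical helper that mirrors this hunk, not a function in the repo.

def format_token_stats(input_tokens, output_tokens):
    # Mirrors the updated logic: both counts share one line, separated by ' | '.
    stats = ''
    if input_tokens:
        stats += 'Input tokens: ' + str(input_tokens)
    if output_tokens:
        stats += (' | ' if input_tokens else '') + 'Output tokens: ' + str(output_tokens) + '\n'
    return stats

print(format_token_stats(120, 42), end='')  # Input tokens: 120 | Output tokens: 42
print(format_token_stats(0, 42), end='')    # Output tokens: 42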
 

+ 221 - 5
tests/unit/test_llm.py

@@ -1,15 +1,38 @@
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
+from litellm.exceptions import (
+    APIConnectionError,
+    ContentPolicyViolationError,
+    InternalServerError,
+    OpenAIError,
+    RateLimitError,
+)
 
 from openhands.core.config import LLMConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.metrics import Metrics
 from openhands.llm.llm import LLM
 
 
+@pytest.fixture(autouse=True)
+def mock_logger(monkeypatch):
+    # suppress logging of completion data to file
+    mock_logger = MagicMock()
+    monkeypatch.setattr('openhands.llm.llm.llm_prompt_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.llm.llm_response_logger', mock_logger)
+    return mock_logger
+
+
 @pytest.fixture
 def default_config():
-    return LLMConfig(model='gpt-4o', api_key='test_key')
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
 
 
 def test_llm_init_with_default_config(default_config):
@@ -64,7 +87,7 @@ def test_llm_init_with_metrics():
 
 
 def test_llm_reset():
-    llm = LLM(LLMConfig(model='gpt-3.5-turbo', api_key='test_key'))
+    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
     initial_metrics = llm.metrics
     llm.reset()
     assert llm.metrics is not initial_metrics
@@ -73,7 +96,7 @@ def test_llm_reset():
 
 @patch('openhands.llm.llm.litellm.get_model_info')
 def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
-    default_config.model = 'openrouter:gpt-3.5-turbo'
+    default_config.model = 'openrouter:gpt-4o-mini'
     mock_get_model_info.return_value = {
         'max_input_tokens': 7000,
         'max_output_tokens': 1500,
@@ -81,4 +104,197 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
     llm = LLM(default_config)
     assert llm.config.max_input_tokens == 7000
     assert llm.config.max_output_tokens == 1500
-    mock_get_model_info.assert_called_once_with('openrouter:gpt-3.5-turbo')
+    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
+
+
+# Tests involving completion and retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_mocked_logger(
+    mock_litellm_completion, default_config, mock_logger
+):
+    mock_litellm_completion.return_value = {
+        'choices': [{'message': {'content': 'Test response'}}]
+    }
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Test response'
+    assert mock_litellm_completion.call_count == 1
+
+    mock_logger.debug.assert_called()
+
+
+@pytest.mark.parametrize(
+    'exception_class,extra_args,expected_retries',
+    [
+        (
+            APIConnectionError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (
+            ContentPolicyViolationError,
+            {'model': 'test_model', 'llm_provider': 'test_provider'},
+            2,
+        ),
+        (
+            InternalServerError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (OpenAIError, {}, 2),
+        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
+    ],
+)
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_retries(
+    mock_litellm_completion,
+    default_config,
+    exception_class,
+    extra_args,
+    expected_retries,
+):
+    mock_litellm_completion.side_effect = [
+        exception_class('Test error message', **extra_args),
+        {'choices': [{'message': {'content': 'Retry successful'}}]},
+    ]
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Retry successful'
+    assert mock_litellm_completion.call_count == expected_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
+    with patch('time.sleep') as mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+
+        llm = LLM(config=default_config)
+        response = llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert response['choices'][0]['message']['content'] == 'Retry successful'
+        assert mock_litellm_completion.call_count == 2
+
+        mock_sleep.assert_called_once()
+        wait_time = mock_sleep.call_args[0][0]
+        assert (
+            60 <= wait_time <= 240
+        ), f'Expected wait time between 60 and 240 seconds, but got {wait_time}'
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_exhausts_retries(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = APIConnectionError(
+        'Persistent error', llm_provider='test_provider', model='test_model'
+    )
+
+    llm = LLM(config=default_config)
+    with pytest.raises(APIConnectionError):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == llm.config.num_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_operation_cancelled(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
+    def side_effect(*args, **kwargs):
+        raise KeyboardInterrupt('Simulated KeyboardInterrupt')
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        try:
+            llm.completion(
+                messages=[{'role': 'user', 'content': 'Hello!'}],
+                stream=False,
+            )
+        except KeyboardInterrupt:
+            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
+    global _should_exit
+
+    def side_effect(*args, **kwargs):
+        global _should_exit
+        _should_exit = True
+        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    result = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert mock_litellm_completion.call_count == 1
+    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
+    assert _should_exit
+
+    _should_exit = False
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
+    mock_response = {
+        'choices': [{'message': {'content': 'This is a mocked response.'}}]
+    }
+    mock_litellm_completion.return_value = mock_response
+
+    test_llm = LLM(config=default_config)
+    response = test_llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+        drop_params=True,
+    )
+
+    # Assertions
+    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
+    mock_litellm_completion.assert_called_once()
+
+    # Check if the correct arguments were passed to litellm_completion
+    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
+    assert call_args['model'] == default_config.model
+    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
+    assert not call_args['stream']