Pārlūkot izejas kodu

Fix issue #5559: The turn limit should be measured from the last user interaction (#5560)

Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
OpenHands 11 mēneši atpakaļ
vecāks
revīzija
4998b5de32

+ 1 - 2
evaluation/benchmarks/swe_bench/run_infer.py

@@ -9,7 +9,6 @@ import toml
 from datasets import load_dataset
 
 import openhands.agenthub
-
 from evaluation.utils.shared import (
     EvalException,
     EvalMetadata,
@@ -76,7 +75,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
         '4. Rerun your reproduce script and confirm that the error is fixed!\n'
         '5. Think about edgecases and make sure your fix handles them as well\n'
         "Your thinking should be thorough and so it's fine if it's very long.\n"
-        )
+    )
 
     if RUN_WITH_BROWSING:
         instruction += (

+ 16 - 0
openhands/controller/agent_controller.py

@@ -312,6 +312,20 @@ class AgentController:
                 str(action),
                 extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
             )
+            # Extend max iterations when the user sends a message (only in non-headless mode)
+            if self._initial_max_iterations is not None and not self.headless_mode:
+                self.state.max_iterations = (
+                    self.state.iteration + self._initial_max_iterations
+                )
+                if (
+                    self.state.traffic_control_state == TrafficControlState.THROTTLING
+                    or self.state.traffic_control_state == TrafficControlState.PAUSED
+                ):
+                    self.state.traffic_control_state = TrafficControlState.NORMAL
+                self.log(
+                    'debug',
+                    f'Extended max iterations to {self.state.max_iterations} after user message',
+                )
             if self.get_agent_state() != AgentState.RUNNING:
                 await self.set_agent_state_to(AgentState.RUNNING)
         elif action.source == EventSource.AGENT and action.wait_for_response:
@@ -342,6 +356,7 @@ class AgentController:
         elif (
             new_state == AgentState.RUNNING
             and self.state.agent_state == AgentState.PAUSED
+            # TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely?
             and self.state.traffic_control_state == TrafficControlState.THROTTLING
         ):
             # user intends to interrupt traffic control and let the task resume temporarily
@@ -351,6 +366,7 @@ class AgentController:
                 self.state.iteration is not None
                 and self.state.max_iterations is not None
                 and self._initial_max_iterations is not None
+                and not self.headless_mode
             ):
                 if self.state.iteration >= self.state.max_iterations:
                     self.state.max_iterations += self._initial_max_iterations

+ 2 - 1
pyproject.toml

@@ -100,6 +100,7 @@ reportlab = "*"
 [tool.coverage.run]
 concurrency = ["gevent"]
 
+
 [tool.poetry.group.runtime.dependencies]
 jupyterlab = "*"
 notebook = "*"
@@ -107,7 +108,6 @@ jupyter_kernel_gateway = "*"
 flake8 = "*"
 opencv-python = "*"
 
-
 [build-system]
 build-backend = "poetry.core.masonry.api"
 requires = [
@@ -130,6 +130,7 @@ ignore = ["D1"]
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
+
 [tool.poetry.group.evaluation.dependencies]
 streamlit = "*"
 whatthepatch = "*"

+ 38 - 7
tests/unit/test_agent_controller.py

@@ -6,7 +6,7 @@ import pytest
 
 from openhands.controller.agent import Agent
 from openhands.controller.agent_controller import AgentController
-from openhands.controller.state.state import TrafficControlState
+from openhands.controller.state.state import State, TrafficControlState
 from openhands.core.config import AppConfig
 from openhands.core.main import run_controller
 from openhands.core.schema import AgentState
@@ -41,7 +41,9 @@ def mock_agent():
 
 @pytest.fixture
 def mock_event_stream():
-    return MagicMock(spec=EventStream)
+    mock = MagicMock(spec=EventStream)
+    mock.get_latest_event_id.return_value = 0
+    return mock
 
 
 @pytest.fixture
@@ -278,7 +280,9 @@ async def test_delegate_step_different_states(
 
 
 @pytest.mark.asyncio
-async def test_step_max_iterations(mock_agent, mock_event_stream):
+async def test_max_iterations_extension(mock_agent, mock_event_stream):
+    # Test with headless_mode=False - should extend max_iterations
+    initial_state = State(max_iterations=10)
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
@@ -286,18 +290,34 @@ async def test_step_max_iterations(mock_agent, mock_event_stream):
         sid='test',
         confirmation_mode=False,
         headless_mode=False,
+        initial_state=initial_state,
     )
     controller.state.agent_state = AgentState.RUNNING
     controller.state.iteration = 10
     assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+
+    # Trigger throttling by calling _step() when we hit max_iterations
     await controller._step()
     assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
     assert controller.state.agent_state == AgentState.ERROR
-    await controller.close()
 
+    # Simulate a new user message
+    message_action = MessageAction(content='Test message')
+    message_action._source = EventSource.USER
+    await controller.on_event(message_action)
+
+    # Max iterations should be extended to current iteration + initial max_iterations
+    assert (
+        controller.state.max_iterations == 20
+    )  # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
+    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+    assert controller.state.agent_state == AgentState.RUNNING
 
-@pytest.mark.asyncio
-async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
+    # Close the controller to clean up
+    await controller.close()
+
+    # Test with headless_mode=True - should NOT extend max_iterations
+    initial_state = State(max_iterations=10)
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
@@ -305,13 +325,24 @@ async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
         sid='test',
         confirmation_mode=False,
         headless_mode=True,
+        initial_state=initial_state,
     )
     controller.state.agent_state = AgentState.RUNNING
     controller.state.iteration = 10
     assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+
+    # Simulate a new user message
+    message_action = MessageAction(content='Test message')
+    message_action._source = EventSource.USER
+    await controller.on_event(message_action)
+
+    # Max iterations should NOT be extended in headless mode
+    assert controller.state.max_iterations == 10  # Original value unchanged
+
+    # Trigger throttling by calling _step() when we hit max_iterations
     await controller._step()
+
     assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
-    # In headless mode, throttling results in an error
     assert controller.state.agent_state == AgentState.ERROR
     await controller.close()
 

+ 62 - 0
tests/unit/test_iteration_limit.py

@@ -0,0 +1,62 @@
+import asyncio
+
+import pytest
+
+from openhands.controller.agent_controller import AgentController
+from openhands.core.schema import AgentState
+from openhands.events import EventStream
+from openhands.events.action import MessageAction
+from openhands.events.event import EventSource
+
+
+class DummyAgent:
+    def __init__(self):
+        self.name = 'dummy'
+        self.llm = type(
+            'DummyLLM',
+            (),
+            {'metrics': type('DummyMetrics', (), {'merge': lambda x: None})()},
+        )()
+
+    def reset(self):
+        pass
+
+
+@pytest.mark.asyncio
+async def test_iteration_limit_extends_on_user_message():
+    # Initialize test components
+    from openhands.storage.memory import InMemoryFileStore
+
+    file_store = InMemoryFileStore()
+    event_stream = EventStream(sid='test', file_store=file_store)
+    agent = DummyAgent()
+    initial_max_iterations = 100
+    controller = AgentController(
+        agent=agent,
+        event_stream=event_stream,
+        max_iterations=initial_max_iterations,
+        sid='test',
+        headless_mode=False,
+    )
+
+    # Set initial state
+    await controller.set_agent_state_to(AgentState.RUNNING)
+    controller.state.iteration = 90  # Close to the limit
+    assert controller.state.max_iterations == initial_max_iterations
+
+    # Simulate user message
+    user_message = MessageAction('test message', EventSource.USER)
+    event_stream.add_event(user_message, EventSource.USER)
+    await asyncio.sleep(0.1)  # Give time for event to be processed
+
+    # Verify max_iterations was extended
+    assert controller.state.max_iterations == 90 + initial_max_iterations
+
+    # Simulate more iterations and another user message
+    controller.state.iteration = 180  # Close to new limit
+    user_message2 = MessageAction('another message', EventSource.USER)
+    event_stream.add_event(user_message2, EventSource.USER)
+    await asyncio.sleep(0.1)  # Give time for event to be processed
+
+    # Verify max_iterations was extended again
+    assert controller.state.max_iterations == 180 + initial_max_iterations