Bladeren bron

fix(controller): stop when run into loop (#4579)

Xingyao Wang 1 jaar geleden
bovenliggende
commit
98d4884ced
2 gewijzigde bestanden met toevoegingen van 88 en 6 verwijderingen
  1. 11 5
      openhands/controller/agent_controller.py
  2. 77 1
      tests/unit/test_agent_controller.py

+ 11 - 5
openhands/controller/agent_controller.py

@@ -403,9 +403,20 @@ class AgentController:
             return
 
         if self._pending_action:
+            logger.debug(
+                f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration} awaiting pending action to get executed: {self._pending_action}'
+            )
             await asyncio.sleep(1)
             return
 
+        # check if agent got stuck before taking any action
+        if self._is_stuck():
+            # This needs to go BEFORE report_error to sync metrics
+            self.event_stream.add_event(
+                FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
+            )
+            return
+
         if self.delegate is not None:
             assert self.delegate != self
             if self.delegate.get_agent_state() == AgentState.PAUSED:
@@ -467,11 +478,6 @@ class AgentController:
         await self.update_state_after_step()
         logger.info(action, extra={'msg_type': 'ACTION'})
 
-        if self._is_stuck():
-            # This need to go BEFORE report_error to sync metrics
-            await self.set_agent_state_to(AgentState.ERROR)
-            await self.report_error('Agent got stuck in a loop')
-
     async def _delegate_step(self):
         """Executes a single step of the delegate agent."""
         logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')

+ 77 - 1
tests/unit/test_agent_controller.py

@@ -12,7 +12,11 @@ from openhands.core.main import run_controller
 from openhands.core.schema import AgentState
 from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
 from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
-from openhands.events.observation import FatalErrorObservation
+from openhands.events.observation import (
+    ErrorObservation,
+    FatalErrorObservation,
+)
+from openhands.events.serialization import event_to_dict
 from openhands.llm import LLM
 from openhands.llm.metrics import Metrics
 from openhands.runtime.base import Runtime
@@ -177,6 +181,78 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
     assert len(list(event_stream.get_events())) == 5
 
 
+@pytest.mark.asyncio
+async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
+    config = AppConfig()
+    file_store = get_file_store(config.file_store, config.file_store_path)
+    event_stream = EventStream(sid='test', file_store=file_store)
+
+    agent = MagicMock(spec=Agent)
+    # a random message to send to the runtime
+    event = CmdRunAction(command='ls')
+
+    def agent_step_fn(state):
+        print(f'agent_step_fn received state: {state}')
+        return event
+
+    agent.step = agent_step_fn
+    agent.llm = MagicMock(spec=LLM)
+    agent.llm.metrics = Metrics()
+    agent.llm.config = config.get_llm_config()
+    runtime = MagicMock(spec=Runtime)
+
+    async def on_event(event: Event):
+        if isinstance(event, CmdRunAction):
+            non_fatal_error_obs = ErrorObservation(
+                'Non fatal error here to trigger loop'
+            )
+            non_fatal_error_obs._cause = event.id
+            await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
+
+    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
+    runtime.event_stream = event_stream
+
+    state = await run_controller(
+        config=config,
+        initial_user_action=MessageAction(content='Test message'),
+        runtime=runtime,
+        sid='test',
+        agent=agent,
+        fake_user_response_fn=lambda _: 'repeat',
+    )
+    events = list(event_stream.get_events())
+    print(f'state: {state}')
+    for i, event in enumerate(events):
+        print(f'event {i}: {event_to_dict(event)}')
+
+    assert state.iteration == 4
+    assert len(events) == 12
+    # check that the event stream has 4 pairs of repeated actions and observations
+    repeating_actions_and_observations = events[2:10]
+    for action, observation in zip(
+        repeating_actions_and_observations[0::2],
+        repeating_actions_and_observations[1::2],
+    ):
+        action_dict = event_to_dict(action)
+        observation_dict = event_to_dict(observation)
+        assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
+        assert (
+            observation_dict['observation'] == 'error'
+            and observation_dict['content'] == 'Non fatal error here to trigger loop'
+        )
+    last_event = event_to_dict(events[-1])
+    assert last_event['extras']['agent_state'] == 'error'
+    assert last_event['observation'] == 'agent_state_changed'
+
+    # the state first becomes AgentState.ERROR, then becomes AgentState.STOPPED
+    # inside run_controller (since the while loop + sleep no longer loops)
+    assert state.agent_state == AgentState.STOPPED
+    assert (
+        state.last_error
+        == 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
+    )
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     'delegate_state',