Просмотр исходного кода

fix(controller): stop when run into loop (#4579)

Xingyao Wang 1 год назад
Родитель
Коммит
98d4884ced
2 изменённых файла: 88 добавлений и 6 удалений
  1. 11 5
      openhands/controller/agent_controller.py
  2. 77 1
      tests/unit/test_agent_controller.py

+ 11 - 5
openhands/controller/agent_controller.py

@@ -403,9 +403,20 @@ class AgentController:
             return
             return
 
 
         if self._pending_action:
         if self._pending_action:
+            logger.debug(
+                f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration} awaiting pending action to get executed: {self._pending_action}'
+            )
             await asyncio.sleep(1)
             await asyncio.sleep(1)
             return
             return
 
 
+        # check if agent got stuck before taking any action
+        if self._is_stuck():
+            # This needs to go BEFORE report_error to sync metrics
+            self.event_stream.add_event(
+                FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
+            )
+            return
+
         if self.delegate is not None:
         if self.delegate is not None:
             assert self.delegate != self
             assert self.delegate != self
             if self.delegate.get_agent_state() == AgentState.PAUSED:
             if self.delegate.get_agent_state() == AgentState.PAUSED:
@@ -467,11 +478,6 @@ class AgentController:
         await self.update_state_after_step()
         await self.update_state_after_step()
         logger.info(action, extra={'msg_type': 'ACTION'})
         logger.info(action, extra={'msg_type': 'ACTION'})
 
 
-        if self._is_stuck():
-            # This need to go BEFORE report_error to sync metrics
-            await self.set_agent_state_to(AgentState.ERROR)
-            await self.report_error('Agent got stuck in a loop')
-
     async def _delegate_step(self):
     async def _delegate_step(self):
         """Executes a single step of the delegate agent."""
         """Executes a single step of the delegate agent."""
         logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
         logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')

+ 77 - 1
tests/unit/test_agent_controller.py

@@ -12,7 +12,11 @@ from openhands.core.main import run_controller
 from openhands.core.schema import AgentState
 from openhands.core.schema import AgentState
 from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
 from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
 from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
 from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
-from openhands.events.observation import FatalErrorObservation
+from openhands.events.observation import (
+    ErrorObservation,
+    FatalErrorObservation,
+)
+from openhands.events.serialization import event_to_dict
 from openhands.llm import LLM
 from openhands.llm import LLM
 from openhands.llm.metrics import Metrics
 from openhands.llm.metrics import Metrics
 from openhands.runtime.base import Runtime
 from openhands.runtime.base import Runtime
@@ -177,6 +181,78 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
     assert len(list(event_stream.get_events())) == 5
     assert len(list(event_stream.get_events())) == 5
 
 
 
 
+@pytest.mark.asyncio
+async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
+    config = AppConfig()
+    file_store = get_file_store(config.file_store, config.file_store_path)
+    event_stream = EventStream(sid='test', file_store=file_store)
+
+    agent = MagicMock(spec=Agent)
+    # a random message to send to the runtime
+    event = CmdRunAction(command='ls')
+
+    def agent_step_fn(state):
+        print(f'agent_step_fn received state: {state}')
+        return event
+
+    agent.step = agent_step_fn
+    agent.llm = MagicMock(spec=LLM)
+    agent.llm.metrics = Metrics()
+    agent.llm.config = config.get_llm_config()
+    runtime = MagicMock(spec=Runtime)
+
+    async def on_event(event: Event):
+        if isinstance(event, CmdRunAction):
+            non_fatal_error_obs = ErrorObservation(
+                'Non fatal error here to trigger loop'
+            )
+            non_fatal_error_obs._cause = event.id
+            await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
+
+    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
+    runtime.event_stream = event_stream
+
+    state = await run_controller(
+        config=config,
+        initial_user_action=MessageAction(content='Test message'),
+        runtime=runtime,
+        sid='test',
+        agent=agent,
+        fake_user_response_fn=lambda _: 'repeat',
+    )
+    events = list(event_stream.get_events())
+    print(f'state: {state}')
+    for i, event in enumerate(events):
+        print(f'event {i}: {event_to_dict(event)}')
+
+    assert state.iteration == 4
+    assert len(events) == 12
+    # check that the event stream has 4 pairs of repeated actions and observations
+    repeating_actions_and_observations = events[2:10]
+    for action, observation in zip(
+        repeating_actions_and_observations[0::2],
+        repeating_actions_and_observations[1::2],
+    ):
+        action_dict = event_to_dict(action)
+        observation_dict = event_to_dict(observation)
+        assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
+        assert (
+            observation_dict['observation'] == 'error'
+            and observation_dict['content'] == 'Non fatal error here to trigger loop'
+        )
+    last_event = event_to_dict(events[-1])
+    assert last_event['extras']['agent_state'] == 'error'
+    assert last_event['observation'] == 'agent_state_changed'
+
+    # the state first becomes AgentState.ERROR, then becomes AgentState.STOPPED
+    # inside run_controller (since the while loop + sleep no longer loops)
+    assert state.agent_state == AgentState.STOPPED
+    assert (
+        state.last_error
+        == 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
+    )
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     'delegate_state',
     'delegate_state',