ソースを参照

fix(controller): make agent controller stop when encountering a fatal observation (#4573)

Xingyao Wang 1 年間 前
コミット
be3cbb045e

+ 3 - 3
openhands/controller/agent_controller.py

@@ -254,11 +254,11 @@ class AgentController:
             if self.state.agent_state == AgentState.ERROR:
             if self.state.agent_state == AgentState.ERROR:
                 self.state.metrics.merge(self.state.local_metrics)
                 self.state.metrics.merge(self.state.local_metrics)
         elif isinstance(observation, FatalErrorObservation):
         elif isinstance(observation, FatalErrorObservation):
-            await self.report_error(
-                'There was a fatal error during agent execution: ' + str(observation)
+            self.state.last_error = (
+                f'There was a fatal error during agent execution: {str(observation)}'
             )
             )
-            await self.set_agent_state_to(AgentState.ERROR)
             self.state.metrics.merge(self.state.local_metrics)
             self.state.metrics.merge(self.state.local_metrics)
+            await self.set_agent_state_to(AgentState.ERROR)
 
 
     async def _handle_message_action(self, action: MessageAction):
     async def _handle_message_action(self, action: MessageAction):
         """Handles message actions from the event stream.
         """Handles message actions from the event stream.

+ 56 - 2
tests/unit/test_agent_controller.py

@@ -6,10 +6,17 @@ import pytest
 from openhands.controller.agent import Agent
 from openhands.controller.agent import Agent
 from openhands.controller.agent_controller import AgentController
 from openhands.controller.agent_controller import AgentController
 from openhands.controller.state.state import TrafficControlState
 from openhands.controller.state.state import TrafficControlState
+from openhands.core.config import AppConfig
 from openhands.core.exceptions import LLMMalformedActionError
 from openhands.core.exceptions import LLMMalformedActionError
+from openhands.core.main import run_controller
 from openhands.core.schema import AgentState
 from openhands.core.schema import AgentState
-from openhands.events import EventStream
-from openhands.events.action import ChangeAgentStateAction, MessageAction
+from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
+from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
+from openhands.events.observation import FatalErrorObservation
+from openhands.llm import LLM
+from openhands.llm.metrics import Metrics
+from openhands.runtime.base import Runtime
+from openhands.storage import get_file_store
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -123,6 +130,53 @@ async def test_step_with_exception(mock_agent, mock_event_stream):
     await controller.close()
     await controller.close()
 
 
 
 
+@pytest.mark.asyncio
+async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
+    config = AppConfig()
+    file_store = get_file_store(config.file_store, config.file_store_path)
+    event_stream = EventStream(sid='test', file_store=file_store)
+
+    agent = MagicMock(spec=Agent)
+    # an arbitrary action to send to the runtime
+    event = CmdRunAction(command='ls')
+    agent.step.return_value = event
+    agent.llm = MagicMock(spec=LLM)
+    agent.llm.metrics = Metrics()
+    agent.llm.config = config.get_llm_config()
+
+    fatal_error_obs = FatalErrorObservation('Fatal error detected')
+    fatal_error_obs._cause = event.id
+
+    runtime = MagicMock(spec=Runtime)
+
+    async def on_event(event: Event):
+        if isinstance(event, CmdRunAction):
+            await event_stream.async_add_event(fatal_error_obs, EventSource.USER)
+
+    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
+    runtime.event_stream = event_stream
+
+    state = await run_controller(
+        config=config,
+        initial_user_action=MessageAction(content='Test message'),
+        runtime=runtime,
+        sid='test',
+        agent=agent,
+        fake_user_response_fn=lambda _: 'repeat',
+    )
+    print(f'state: {state}')
+    print(f'event_stream: {list(event_stream.get_events())}')
+    assert state.iteration == 1
+    # the state first becomes AgentState.ERROR, then becomes AgentState.STOPPED
+    # inside run_controller (since its while loop + sleep no longer keeps looping)
+    assert state.agent_state == AgentState.STOPPED
+    assert (
+        state.last_error
+        == 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected'
+    )
+    assert len(list(event_stream.get_events())) == 5
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     'delegate_state',
     'delegate_state',