|
|
@@ -6,10 +6,17 @@ import pytest
|
|
|
from openhands.controller.agent import Agent
|
|
|
from openhands.controller.agent_controller import AgentController
|
|
|
from openhands.controller.state.state import TrafficControlState
|
|
|
+from openhands.core.config import AppConfig
|
|
|
from openhands.core.exceptions import LLMMalformedActionError
|
|
|
+from openhands.core.main import run_controller
|
|
|
from openhands.core.schema import AgentState
|
|
|
-from openhands.events import EventStream
|
|
|
-from openhands.events.action import ChangeAgentStateAction, MessageAction
|
|
|
+from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
|
|
+from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
|
|
+from openhands.events.observation import FatalErrorObservation
|
|
|
+from openhands.llm import LLM
|
|
|
+from openhands.llm.metrics import Metrics
|
|
|
+from openhands.runtime.base import Runtime
|
|
|
+from openhands.storage import get_file_store
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
@@ -123,6 +130,53 @@ async def test_step_with_exception(mock_agent, mock_event_stream):
|
|
|
await controller.close()
|
|
|
|
|
|
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
|
|
+ config = AppConfig()
|
|
|
+ file_store = get_file_store(config.file_store, config.file_store_path)
|
|
|
+ event_stream = EventStream(sid='test', file_store=file_store)
|
|
|
+
|
|
|
+ agent = MagicMock(spec=Agent)
|
|
|
+ # a random message to send to the runtime
|
|
|
+ event = CmdRunAction(command='ls')
|
|
|
+ agent.step.return_value = event
|
|
|
+ agent.llm = MagicMock(spec=LLM)
|
|
|
+ agent.llm.metrics = Metrics()
|
|
|
+ agent.llm.config = config.get_llm_config()
|
|
|
+
|
|
|
+ fatal_error_obs = FatalErrorObservation('Fatal error detected')
|
|
|
+ fatal_error_obs._cause = event.id
|
|
|
+
|
|
|
+ runtime = MagicMock(spec=Runtime)
|
|
|
+
|
|
|
+ async def on_event(event: Event):
|
|
|
+ if isinstance(event, CmdRunAction):
|
|
|
+ await event_stream.async_add_event(fatal_error_obs, EventSource.USER)
|
|
|
+
|
|
|
+ event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
|
|
|
+ runtime.event_stream = event_stream
|
|
|
+
|
|
|
+ state = await run_controller(
|
|
|
+ config=config,
|
|
|
+ initial_user_action=MessageAction(content='Test message'),
|
|
|
+ runtime=runtime,
|
|
|
+ sid='test',
|
|
|
+ agent=agent,
|
|
|
+ fake_user_response_fn=lambda _: 'repeat',
|
|
|
+ )
|
|
|
+ print(f'state: {state}')
|
|
|
+ print(f'event_stream: {list(event_stream.get_events())}')
|
|
|
+ assert state.iteration == 1
|
|
|
+ # it will first become AgentState.ERROR, then become AgentState.STOPPED
|
|
|
+ # in side run_controller (since the while loop + sleep no longer loop)
|
|
|
+ assert state.agent_state == AgentState.STOPPED
|
|
|
+ assert (
|
|
|
+ state.last_error
|
|
|
+ == 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected'
|
|
|
+ )
|
|
|
+ assert len(list(event_stream.get_events())) == 5
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
@pytest.mark.parametrize(
|
|
|
'delegate_state',
|