|
|
@@ -12,7 +12,11 @@ from openhands.core.main import run_controller
|
|
|
from openhands.core.schema import AgentState
|
|
|
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
|
|
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
|
|
-from openhands.events.observation import FatalErrorObservation
|
|
|
+from openhands.events.observation import (
|
|
|
+ ErrorObservation,
|
|
|
+ FatalErrorObservation,
|
|
|
+)
|
|
|
+from openhands.events.serialization import event_to_dict
|
|
|
from openhands.llm import LLM
|
|
|
from openhands.llm.metrics import Metrics
|
|
|
from openhands.runtime.base import Runtime
|
|
|
@@ -177,6 +181,78 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
|
|
assert len(list(event_stream.get_events())) == 5
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
    """Check that run_controller halts when the agent repeats one failing action.

    The agent stub returns the same ``CmdRunAction('ls')`` on every step, and
    a RUNTIME subscriber answers each such action with the same non-fatal
    ``ErrorObservation``. The test then verifies the iteration count, the
    total number of events, the four repeated action/observation pairs, the
    final agent-state-changed event and the error recorded on the returned
    state.

    The ``mock_agent`` and ``mock_event_stream`` fixtures are requested but
    not referenced in this body; the test builds its own agent and stream.
    """
    error_text = 'Non fatal error here to trigger loop'
    expected_last_error = (
        'There was a fatal error during agent execution: **FatalErrorObservation**\n'
        'Agent got stuck in a loop'
    )

    config = AppConfig()
    event_stream = EventStream(
        sid='test',
        file_store=get_file_store(config.file_store, config.file_store_path),
    )

    # The one action the agent keeps emitting on every step.
    repeated_action = CmdRunAction(command='ls')

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        return repeated_action

    agent = MagicMock(spec=Agent)
    agent.step = agent_step_fn
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    async def on_event(event: Event):
        # Only command actions get a reply; everything else is ignored.
        if not isinstance(event, CmdRunAction):
            return
        error_reply = ErrorObservation(error_text)
        error_reply._cause = event.id
        await event_stream.async_add_event(error_reply, EventSource.USER)

    runtime = MagicMock(spec=Runtime)
    runtime.event_stream = event_stream
    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)

    def always_repeat(_):
        return 'repeat'

    final_state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=always_repeat,
    )

    events = list(event_stream.get_events())
    print(f'state: {final_state}')
    for index, recorded in enumerate(events):
        print(f'event {index}: {event_to_dict(recorded)}')

    assert final_state.iteration == 4
    assert len(events) == 12

    # Events 2..9 of the stream must be four identical
    # (action, observation) pairs.
    repeated_window = iter(events[2:10])
    for action, observation in zip(repeated_window, repeated_window):
        action_dict = event_to_dict(action)
        observation_dict = event_to_dict(observation)
        assert action_dict['action'] == 'run'
        assert action_dict['args']['command'] == 'ls'
        assert observation_dict['observation'] == 'error'
        assert observation_dict['content'] == error_text

    final_event = event_to_dict(events[-1])
    assert final_event['extras']['agent_state'] == 'error'
    assert final_event['observation'] == 'agent_state_changed'

    # Per the original author's note: the agent first becomes
    # AgentState.ERROR and then AgentState.STOPPED inside run_controller,
    # once its wait loop stops iterating.
    assert final_state.agent_state == AgentState.STOPPED
    assert final_state.last_error == expected_last_error
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
@pytest.mark.parametrize(
|
|
|
'delegate_state',
|