import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock
from uuid import uuid4

import pytest

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State, TrafficControlState
from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.observation import (
    ErrorObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store


@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Return a fresh temporary directory path (as a string)."""
    return str(tmp_path_factory.mktemp('test_event_stream'))


@pytest.fixture(scope='function')
def event_loop():
    """Provide a brand-new event loop per test and close it afterwards."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
def mock_agent():
    """A MagicMock constrained to the Agent interface."""
    return MagicMock(spec=Agent)


@pytest.fixture
def mock_event_stream():
    """A MagicMock EventStream whose latest event id is always 0."""
    mock = MagicMock(spec=EventStream)
    mock.get_latest_event_id.return_value = 0
    return mock


@pytest.fixture
def mock_status_callback():
    """An AsyncMock used as the controller's status callback."""
    return AsyncMock()


def _make_ls_agent(config: AppConfig) -> MagicMock:
    """Build a mock agent whose every step returns ``CmdRunAction(command='ls')``.

    Its LLM is a MagicMock with real ``Metrics`` and the LLM config taken from
    *config*. Because the agent always emits the same action, the
    ``run_controller`` tests below expect the run to end with
    'Agent got stuck in a loop'.
    """
    agent = MagicMock(spec=Agent)

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        return CmdRunAction(command='ls')

    agent.step = agent_step_fn
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()
    return agent


@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
    """set_agent_state_to() changes the value reported by get_agent_state()."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    await controller.set_agent_state_to(AgentState.RUNNING)
    assert controller.get_agent_state() == AgentState.RUNNING

    await controller.set_agent_state_to(AgentState.PAUSED)
    assert controller.get_agent_state() == AgentState.PAUSED
    await controller.close()


@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
    """A MessageAction leaves a RUNNING controller in the RUNNING state."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller.state.agent_state = AgentState.RUNNING
    message_action = MessageAction(content='Test message')
    await controller.on_event(message_action)
    assert controller.get_agent_state() == AgentState.RUNNING
    await controller.close()


@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
    """A ChangeAgentStateAction event moves the controller to the requested state."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller.state.agent_state = AgentState.RUNNING
    change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
    await controller.on_event(change_state_action)
    assert controller.get_agent_state() == AgentState.PAUSED
    await controller.close()


@pytest.mark.asyncio
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
    """_react_to_exception() reports through status_callback exactly once."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        status_callback=mock_status_callback,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    error_message = 'Test error'
    await controller._react_to_exception(RuntimeError(error_message))
    controller.status_callback.assert_called_once()
    await controller.close()


@pytest.mark.asyncio
async def test_run_controller_with_fatal_error():
    """Every command answered by an ErrorObservation ends the run in ERROR.

    The fake runtime replies to each CmdRunAction with the same
    ErrorObservation (sourced as USER). The test expects the controller to
    stop after 4 iterations with 'Agent got stuck in a loop'.
    """
    config = AppConfig()
    file_store = get_file_store(config.file_store, config.file_store_path)
    event_stream = EventStream(sid='test', file_store=file_store)

    agent = _make_ls_agent(config)
    runtime = MagicMock(spec=Runtime)

    async def on_event(event: Event):
        if isinstance(event, CmdRunAction):
            error_obs = ErrorObservation('You messed around with Jim')
            error_obs._cause = event.id
            event_stream.add_event(error_obs, EventSource.USER)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )
    print(f'state: {state}')
    print(f'event_stream: {list(event_stream.get_events())}')
    assert state.iteration == 4
    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'Agent got stuck in a loop'
    assert len(list(event_stream.get_events())) == 11


@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
    """Repeated (action, non-fatal error) pairs make the controller stop as stuck.

    The fake runtime answers each CmdRunAction with an ENVIRONMENT-sourced
    ErrorObservation. The event stream must contain 4 identical
    action/observation pairs, followed by a final agent_state_changed event
    reporting 'error'.
    """
    config = AppConfig()
    file_store = get_file_store(config.file_store, config.file_store_path)
    event_stream = EventStream(sid='test', file_store=file_store)

    agent = _make_ls_agent(config)
    runtime = MagicMock(spec=Runtime)

    async def on_event(event: Event):
        if isinstance(event, CmdRunAction):
            non_fatal_error_obs = ErrorObservation(
                'Non fatal error here to trigger loop'
            )
            non_fatal_error_obs._cause = event.id
            event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )
    events = list(event_stream.get_events())
    print(f'state: {state}')
    for i, event in enumerate(events):
        print(f'event {i}: {event_to_dict(event)}')

    assert state.iteration == 4
    assert len(events) == 11
    # check the eventstream have 4 pairs of repeated actions and observations
    repeating_actions_and_observations = events[2:10]
    for action, observation in zip(
        repeating_actions_and_observations[0::2],
        repeating_actions_and_observations[1::2],
    ):
        action_dict = event_to_dict(action)
        observation_dict = event_to_dict(observation)
        assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
        assert (
            observation_dict['observation'] == 'error'
            and observation_dict['content'] == 'Non fatal error here to trigger loop'
        )
    last_event = event_to_dict(events[-1])
    assert last_event['extras']['agent_state'] == 'error'
    assert last_event['observation'] == 'agent_state_changed'

    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'Agent got stuck in a loop'


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_agent, mock_event_stream, delegate_state
):
    """_delegate_step() keeps a RUNNING delegate and tears down any other one.

    For a non-RUNNING delegate, the parent adopts the delegate's iteration
    count, clears ``controller.delegate`` and closes the delegate.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    mock_delegate = AsyncMock()
    controller.delegate = mock_delegate

    mock_delegate.state.iteration = 5
    mock_delegate.state.outputs = {'result': 'test'}
    mock_delegate.agent.name = 'TestDelegate'

    mock_delegate.get_agent_state = Mock(return_value=delegate_state)
    mock_delegate._step = AsyncMock()
    mock_delegate.close = AsyncMock()

    await controller._delegate_step()

    mock_delegate._step.assert_called_once()

    if delegate_state == AgentState.RUNNING:
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        mock_delegate.close.assert_not_called()
    else:
        assert controller.delegate is None
        assert controller.state.iteration == 5
        mock_delegate.close.assert_called_once()

    await controller.close()


@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
    """A user message extends max_iterations only when not in headless mode."""
    # Test with headless_mode=False - should extend max_iterations
    initial_state = State(max_iterations=10)
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=False,
        initial_state=initial_state,
    )
    controller.state.agent_state = AgentState.RUNNING
    controller.state.iteration = 10
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    # Trigger throttling by calling _step() when we hit max_iterations
    await controller._step()
    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    assert controller.state.agent_state == AgentState.ERROR

    # Simulate a new user message
    message_action = MessageAction(content='Test message')
    message_action._source = EventSource.USER
    await controller.on_event(message_action)

    # Max iterations should be extended to current iteration + initial max_iterations
    assert (
        controller.state.max_iterations == 20
    )  # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
    assert controller.state.agent_state == AgentState.RUNNING

    # Close the controller to clean up
    await controller.close()

    # Test with headless_mode=True - should NOT extend max_iterations
    initial_state = State(max_iterations=10)
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
        initial_state=initial_state,
    )
    controller.state.agent_state = AgentState.RUNNING
    controller.state.iteration = 10
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    # Simulate a new user message
    message_action = MessageAction(content='Test message')
    message_action._source = EventSource.USER
    await controller.on_event(message_action)

    # Max iterations should NOT be extended in headless mode
    assert controller.state.max_iterations == 10  # Original value unchanged

    # Trigger throttling by calling _step() when we hit max_iterations
    await controller._step()
    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    assert controller.state.agent_state == AgentState.ERROR
    await controller.close()


@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
    """Exceeding max_budget_per_task throttles the controller into ERROR."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING
    # Just above the budget of 10.
    controller.state.metrics.accumulated_cost = 10.1
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
    await controller._step()
    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    assert controller.state.agent_state == AgentState.ERROR
    await controller.close()


@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
    """Exceeding max_budget_per_task in headless mode also ends in ERROR."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller.state.agent_state = AgentState.RUNNING
    # Just above the budget of 10.
    controller.state.metrics.accumulated_cost = 10.1
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
    await controller._step()
    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    # In headless mode, throttling results in an error
    assert controller.state.agent_state == AgentState.ERROR
    await controller.close()