| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- import asyncio
- from unittest.mock import AsyncMock, MagicMock, Mock
- from uuid import uuid4
- import pytest
- from openhands.controller.agent import Agent
- from openhands.controller.agent_controller import AgentController
- from openhands.controller.state.state import State, TrafficControlState
- from openhands.core.config import AppConfig
- from openhands.core.main import run_controller
- from openhands.core.schema import AgentState
- from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
- from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
- from openhands.events.observation import (
- ErrorObservation,
- )
- from openhands.events.serialization import event_to_dict
- from openhands.llm import LLM
- from openhands.llm.metrics import Metrics
- from openhands.runtime.base import Runtime
- from openhands.storage import get_file_store
@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Create a temporary directory named after 'test_event_stream' and return its path as a string."""
    directory = tmp_path_factory.mktemp('test_event_stream')
    return str(directory)
@pytest.fixture(scope='function')
def event_loop():
    """Yield a newly created asyncio event loop and close it once the test is done."""
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    # Teardown: release the loop created for this test.
    fresh_loop.close()
@pytest.fixture
def mock_agent():
    """Return a ``MagicMock`` whose attribute surface is restricted to ``Agent``."""
    agent_double = MagicMock(spec=Agent)
    return agent_double
@pytest.fixture
def mock_event_stream():
    """Return a ``MagicMock`` specced on ``EventStream`` that reports 0 as its latest event id."""
    stream_double = MagicMock(spec=EventStream)
    # Make get_latest_event_id() return a plain integer rather than a nested mock.
    stream_double.get_latest_event_id.return_value = 0
    return stream_double
@pytest.fixture
def mock_status_callback():
    """Return an ``AsyncMock`` to be used as the controller's status callback."""
    callback_double = AsyncMock()
    return callback_double
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
    """Check that ``set_agent_state_to`` is reflected by ``get_agent_state``."""
    ctrl = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Switch to RUNNING, then to PAUSED, verifying the reported state after each change.
    for target_state in (AgentState.RUNNING, AgentState.PAUSED):
        await ctrl.set_agent_state_to(target_state)
        assert ctrl.get_agent_state() == target_state

    await ctrl.close()
@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
    """A ``MessageAction`` passed to ``on_event`` leaves a RUNNING controller in RUNNING."""
    controller_kwargs = {
        'agent': mock_agent,
        'event_stream': mock_event_stream,
        'max_iterations': 10,
        'sid': 'test',
        'confirmation_mode': False,
        'headless_mode': True,
    }
    ctrl = AgentController(**controller_kwargs)

    # Put the controller in RUNNING directly on its state object.
    ctrl.state.agent_state = AgentState.RUNNING

    incoming = MessageAction(content='Test message')
    await ctrl.on_event(incoming)

    assert ctrl.get_agent_state() == AgentState.RUNNING
    await ctrl.close()
@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
    """A ``ChangeAgentStateAction`` requesting PAUSED moves a RUNNING controller to PAUSED."""
    controller_kwargs = {
        'agent': mock_agent,
        'event_stream': mock_event_stream,
        'max_iterations': 10,
        'sid': 'test',
        'confirmation_mode': False,
        'headless_mode': True,
    }
    ctrl = AgentController(**controller_kwargs)

    # Start from RUNNING so the transition to PAUSED is observable.
    ctrl.state.agent_state = AgentState.RUNNING

    pause_request = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
    await ctrl.on_event(pause_request)

    assert ctrl.get_agent_state() == AgentState.PAUSED
    await ctrl.close()
@pytest.mark.asyncio
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
    """``_react_to_exception`` should invoke the controller's status callback exactly once."""
    ctrl = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        status_callback=mock_status_callback,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    failure = RuntimeError('Test error')
    await ctrl._react_to_exception(failure)

    # The callback is checked through the controller attribute, as wired at construction.
    ctrl.status_callback.assert_called_once()
    await ctrl.close()
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
    """Run ``run_controller`` with an agent that always issues ``ls`` and gets an error back.

    The agent's ``step`` always returns ``CmdRunAction(command='ls')``; a
    subscriber registered as the runtime answers each such action with an
    ``ErrorObservation`` added with ``EventSource.USER``. The test expects the
    returned state to be in ``AgentState.ERROR`` after 4 iterations, with the
    stuck-in-a-loop error message and 11 events in the stream.

    Note: the ``mock_agent`` and ``mock_event_stream`` fixtures are requested
    but not used; this test builds its own agent mock and a real ``EventStream``.
    """
    config = AppConfig()
    file_store = get_file_store(config.file_store, config.file_store_path)
    event_stream = EventStream(sid='test', file_store=file_store)

    # Build the agent mock once (the original constructed it twice in a row).
    agent = MagicMock(spec=Agent)

    def agent_step_fn(state):
        # Always return the same command so consecutive steps are identical.
        print(f'agent_step_fn received state: {state}')
        return CmdRunAction(command='ls')

    agent.step = agent_step_fn
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    runtime = MagicMock(spec=Runtime)

    async def on_event(event: Event):
        # Stand in for the runtime: reply to each command with an error observation
        # whose cause points back at the triggering action.
        if isinstance(event, CmdRunAction):
            error_obs = ErrorObservation('You messed around with Jim')
            error_obs._cause = event.id
            event_stream.add_event(error_obs, EventSource.USER)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )
    print(f'state: {state}')
    print(f'event_stream: {list(event_stream.get_events())}')

    assert state.iteration == 4
    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'Agent got stuck in a loop'
    assert len(list(event_stream.get_events())) == 11
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
    """Run ``run_controller`` with a repeating action/error pattern and inspect the event stream.

    The agent always returns ``CmdRunAction(command='ls')`` and the subscriber
    registered as the runtime answers each one with the same ``ErrorObservation``
    (source ``EventSource.ENVIRONMENT``). The test expects 4 iterations, 11
    events, four identical action/observation pairs at indices 2..9, and a
    final agent-state-changed observation reporting the error state.
    """
    config = AppConfig()
    store = get_file_store(config.file_store, config.file_store_path)
    event_stream = EventStream(sid='test', file_store=store)

    def agent_step_fn(state):
        # Always return the same command so every step is identical.
        print(f'agent_step_fn received state: {state}')
        return CmdRunAction(command='ls')

    llm_double = MagicMock(spec=LLM)
    agent = MagicMock(spec=Agent)
    agent.step = agent_step_fn
    agent.llm = llm_double
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    async def on_event(event: Event):
        # Only commands get a reply; every other event is ignored.
        if not isinstance(event, CmdRunAction):
            return
        non_fatal_error_obs = ErrorObservation('Non fatal error here to trigger loop')
        non_fatal_error_obs._cause = event.id
        event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)

    runtime = MagicMock(spec=Runtime)
    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )

    events = list(event_stream.get_events())
    print(f'state: {state}')
    for i, event in enumerate(events):
        print(f'event {i}: {event_to_dict(event)}')

    assert state.iteration == 4
    assert len(events) == 11

    # The event stream should hold 4 repeated (action, observation) pairs at indices 2..9.
    repeated_span = events[2:10]
    repeated_actions = repeated_span[0::2]
    repeated_observations = repeated_span[1::2]
    for action, observation in zip(repeated_actions, repeated_observations):
        action_dict = event_to_dict(action)
        observation_dict = event_to_dict(observation)

        assert action_dict['action'] == 'run'
        assert action_dict['args']['command'] == 'ls'

        assert observation_dict['observation'] == 'error'
        assert observation_dict['content'] == 'Non fatal error here to trigger loop'

    # The last event must be the state change into the error state.
    final_event = event_to_dict(events[-1])
    assert final_event['extras']['agent_state'] == 'error'
    assert final_event['observation'] == 'agent_state_changed'

    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'Agent got stuck in a loop'
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_agent, mock_event_stream, delegate_state
):
    """Exercise ``_delegate_step`` with a mocked delegate reporting each parametrized state.

    With a RUNNING delegate, the delegate is expected to stay attached, the
    parent's iteration to remain 0 and ``close`` not to be called. For the
    other states, the delegate is expected to be detached and closed, and the
    parent's iteration to equal the delegate's (5).
    """
    controller_kwargs = {
        'agent': mock_agent,
        'event_stream': mock_event_stream,
        'max_iterations': 10,
        'sid': 'test',
        'confirmation_mode': False,
        'headless_mode': True,
    }
    parent = AgentController(**controller_kwargs)

    # Attach a mocked delegate and configure what it reports.
    mock_delegate = AsyncMock()
    parent.delegate = mock_delegate
    mock_delegate.state.iteration = 5
    mock_delegate.state.outputs = {'result': 'test'}
    mock_delegate.agent.name = 'TestDelegate'
    # get_agent_state is set to a synchronous Mock returning the parametrized state.
    mock_delegate.get_agent_state = Mock(return_value=delegate_state)
    mock_delegate._step = AsyncMock()
    mock_delegate.close = AsyncMock()

    await parent._delegate_step()

    mock_delegate._step.assert_called_once()

    if delegate_state != AgentState.RUNNING:
        assert parent.delegate is None
        assert parent.state.iteration == 5
        mock_delegate.close.assert_called_once()
    else:
        assert parent.delegate is not None
        assert parent.state.iteration == 0
        mock_delegate.close.assert_not_called()

    await parent.close()
@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
    """Check how a user message affects ``max_iterations`` with and without headless mode.

    Scenario 1 (``headless_mode=False``): after throttling at the iteration
    limit, a user message is expected to raise ``max_iterations`` to 20 and
    restore NORMAL traffic control and the RUNNING state.
    Scenario 2 (``headless_mode=True``): a user message is expected to leave
    ``max_iterations`` at 10, and the following ``_step`` to throttle into ERROR.
    """
    shared_kwargs = dict(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
    )

    # Scenario 1: headless_mode=False, so max_iterations should be extended.
    interactive = AgentController(
        headless_mode=False,
        initial_state=State(max_iterations=10),
        **shared_kwargs,
    )
    interactive.state.agent_state = AgentState.RUNNING
    interactive.state.iteration = 10
    assert interactive.state.traffic_control_state == TrafficControlState.NORMAL

    # Calling _step() once the iteration limit has been reached triggers throttling.
    await interactive._step()
    assert interactive.state.traffic_control_state == TrafficControlState.THROTTLING
    assert interactive.state.agent_state == AgentState.ERROR

    # Simulate a new message coming from the user.
    user_message = MessageAction(content='Test message')
    user_message._source = EventSource.USER
    await interactive.on_event(user_message)

    # Expected limit: current iteration (still 10, since the throttled _step() should
    # not have been executed) plus the initial max_iterations (10).
    assert interactive.state.max_iterations == 20
    assert interactive.state.traffic_control_state == TrafficControlState.NORMAL
    assert interactive.state.agent_state == AgentState.RUNNING

    # Clean up the first controller before building the second.
    await interactive.close()

    # Scenario 2: headless_mode=True, so max_iterations should NOT be extended.
    headless = AgentController(
        headless_mode=True,
        initial_state=State(max_iterations=10),
        **shared_kwargs,
    )
    headless.state.agent_state = AgentState.RUNNING
    headless.state.iteration = 10
    assert headless.state.traffic_control_state == TrafficControlState.NORMAL

    # Simulate a new message coming from the user.
    user_message = MessageAction(content='Test message')
    user_message._source = EventSource.USER
    await headless.on_event(user_message)

    # In headless mode the original limit stays unchanged.
    assert headless.state.max_iterations == 10

    # Calling _step() at the iteration limit triggers throttling.
    await headless._step()
    assert headless.state.traffic_control_state == TrafficControlState.THROTTLING
    assert headless.state.agent_state == AgentState.ERROR

    await headless.close()
@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
    """With ``headless_mode=False``, a ``_step`` over budget should throttle and end in ERROR."""
    controller_kwargs = {
        'agent': mock_agent,
        'event_stream': mock_event_stream,
        'max_iterations': 10,
        'max_budget_per_task': 10,
        'sid': 'test',
        'confirmation_mode': False,
        'headless_mode': False,
    }
    ctrl = AgentController(**controller_kwargs)

    ctrl.state.agent_state = AgentState.RUNNING
    # Accumulated cost is set just above the max_budget_per_task of 10.
    ctrl.state.metrics.accumulated_cost = 10.1
    assert ctrl.state.traffic_control_state == TrafficControlState.NORMAL

    await ctrl._step()

    assert ctrl.state.traffic_control_state == TrafficControlState.THROTTLING
    assert ctrl.state.agent_state == AgentState.ERROR
    await ctrl.close()
@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
    """With ``headless_mode=True``, a ``_step`` over budget should throttle and end in ERROR."""
    controller_kwargs = {
        'agent': mock_agent,
        'event_stream': mock_event_stream,
        'max_iterations': 10,
        'max_budget_per_task': 10,
        'sid': 'test',
        'confirmation_mode': False,
        'headless_mode': True,
    }
    ctrl = AgentController(**controller_kwargs)

    ctrl.state.agent_state = AgentState.RUNNING
    # Accumulated cost is set just above the max_budget_per_task of 10.
    ctrl.state.metrics.accumulated_cost = 10.1
    assert ctrl.state.traffic_control_state == TrafficControlState.NORMAL

    await ctrl._step()

    assert ctrl.state.traffic_control_state == TrafficControlState.THROTTLING
    # In headless mode, throttling results in an error state.
    assert ctrl.state.agent_state == AgentState.ERROR
    await ctrl.close()
|