|
|
@@ -0,0 +1,194 @@
|
|
|
+import asyncio
|
|
|
+from unittest.mock import AsyncMock, MagicMock
|
|
|
+
|
|
|
+import pytest
|
|
|
+
|
|
|
+from opendevin.controller.agent import Agent
|
|
|
+from opendevin.controller.agent_controller import AgentController
|
|
|
+from opendevin.controller.state.state import TrafficControlState
|
|
|
+from opendevin.core.exceptions import LLMMalformedActionError
|
|
|
+from opendevin.core.schema import AgentState
|
|
|
+from opendevin.events import EventStream
|
|
|
+from opendevin.events.action import ChangeAgentStateAction, MessageAction
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
|
|
|
+ return str(tmp_path_factory.mktemp('test_event_stream'))
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture(scope='function')
|
|
|
+def event_loop():
|
|
|
+ loop = asyncio.get_event_loop_policy().new_event_loop()
|
|
|
+ yield loop
|
|
|
+ loop.close()
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def mock_agent():
|
|
|
+ return MagicMock(spec=Agent)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def mock_event_stream():
|
|
|
+ return MagicMock(spec=EventStream)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_set_agent_state(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ await controller.set_agent_state_to(AgentState.RUNNING)
|
|
|
+ assert controller.get_agent_state() == AgentState.RUNNING
|
|
|
+
|
|
|
+ await controller.set_agent_state_to(AgentState.PAUSED)
|
|
|
+ assert controller.get_agent_state() == AgentState.PAUSED
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_on_event_message_action(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ message_action = MessageAction(content='Test message')
|
|
|
+ await controller.on_event(message_action)
|
|
|
+ assert controller.get_agent_state() == AgentState.RUNNING
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
|
|
|
+ await controller.on_event(change_state_action)
|
|
|
+ assert controller.get_agent_state() == AgentState.PAUSED
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_report_error(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ error_message = 'Test error'
|
|
|
+ await controller.report_error(error_message)
|
|
|
+ assert controller.state.last_error == error_message
|
|
|
+ controller.event_stream.add_event.assert_called_once()
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_step_with_exception(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ controller.report_error = AsyncMock()
|
|
|
+ controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
|
|
|
+ await controller._step()
|
|
|
+
|
|
|
+ # Verify that report_error was called with the correct error message
|
|
|
+ controller.report_error.assert_called_once_with('Malformed action')
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_step_max_iterations(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=False,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ controller.state.iteration = 10
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+ await controller._step()
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
|
|
+ assert controller.state.agent_state == AgentState.PAUSED
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ controller.state.iteration = 10
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+ await controller._step()
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
|
|
+ # In headless mode, throttling results in an error
|
|
|
+ assert controller.state.agent_state == AgentState.ERROR
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_step_max_budget(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ max_budget_per_task=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=False,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ controller.state.metrics.accumulated_cost = 10.1
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+ await controller._step()
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
|
|
+ assert controller.state.agent_state == AgentState.PAUSED
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
|
|
+ controller = AgentController(
|
|
|
+ agent=mock_agent,
|
|
|
+ event_stream=mock_event_stream,
|
|
|
+ max_iterations=10,
|
|
|
+ max_budget_per_task=10,
|
|
|
+ sid='test',
|
|
|
+ confirmation_mode=False,
|
|
|
+ headless_mode=True,
|
|
|
+ )
|
|
|
+ controller.state.agent_state = AgentState.RUNNING
|
|
|
+ controller.state.metrics.accumulated_cost = 10.1
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+ await controller._step()
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
|
|
+ # In headless mode, throttling results in an error
|
|
|
+ assert controller.state.agent_state == AgentState.ERROR
|