Browse Source

Add tests for agent controller (#3357)

* Add tests for agent controller

* Remove dead code

* Remove dead code
Graham Neubig 1 year ago
parent
commit
50b1256c49
1 changed file with 194 additions and 0 deletions
  1. 194 0
      tests/unit/test_agent_controller.py

+ 194 - 0
tests/unit/test_agent_controller.py

@@ -0,0 +1,194 @@
+import asyncio
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from opendevin.controller.agent import Agent
+from opendevin.controller.agent_controller import AgentController
+from opendevin.controller.state.state import TrafficControlState
+from opendevin.core.exceptions import LLMMalformedActionError
+from opendevin.core.schema import AgentState
+from opendevin.events import EventStream
+from opendevin.events.action import ChangeAgentStateAction, MessageAction
+
+
+@pytest.fixture
+def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
+    return str(tmp_path_factory.mktemp('test_event_stream'))
+
+
+@pytest.fixture(scope='function')
+def event_loop():
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture
+def mock_agent():
+    return MagicMock(spec=Agent)
+
+
+@pytest.fixture
+def mock_event_stream():
+    return MagicMock(spec=EventStream)
+
+
+@pytest.mark.asyncio
+async def test_set_agent_state(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    await controller.set_agent_state_to(AgentState.RUNNING)
+    assert controller.get_agent_state() == AgentState.RUNNING
+
+    await controller.set_agent_state_to(AgentState.PAUSED)
+    assert controller.get_agent_state() == AgentState.PAUSED
+
+
+@pytest.mark.asyncio
+async def test_on_event_message_action(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    message_action = MessageAction(content='Test message')
+    await controller.on_event(message_action)
+    assert controller.get_agent_state() == AgentState.RUNNING
+
+
+@pytest.mark.asyncio
+async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
+    await controller.on_event(change_state_action)
+    assert controller.get_agent_state() == AgentState.PAUSED
+
+
+@pytest.mark.asyncio
+async def test_report_error(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    error_message = 'Test error'
+    await controller.report_error(error_message)
+    assert controller.state.last_error == error_message
+    controller.event_stream.add_event.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_step_with_exception(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    controller.report_error = AsyncMock()
+    controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
+    await controller._step()
+
+    # Verify that report_error was called with the correct error message
+    controller.report_error.assert_called_once_with('Malformed action')
+
+
+@pytest.mark.asyncio
+async def test_step_max_iterations(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=False,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    controller.state.iteration = 10
+    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+    await controller._step()
+    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
+    assert controller.state.agent_state == AgentState.PAUSED
+
+
+@pytest.mark.asyncio
+async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    controller.state.iteration = 10
+    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+    await controller._step()
+    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
+    # In headless mode, throttling results in an error
+    assert controller.state.agent_state == AgentState.ERROR
+
+
+@pytest.mark.asyncio
+async def test_step_max_budget(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        max_budget_per_task=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=False,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    controller.state.metrics.accumulated_cost = 10.1
+    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+    await controller._step()
+    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
+    assert controller.state.agent_state == AgentState.PAUSED
+
+
+@pytest.mark.asyncio
+async def test_step_max_budget_headless(mock_agent, mock_event_stream):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        max_iterations=10,
+        max_budget_per_task=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller.state.agent_state = AgentState.RUNNING
+    controller.state.metrics.accumulated_cost = 10.1
+    assert controller.state.traffic_control_state == TrafficControlState.NORMAL
+    await controller._step()
+    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
+    # In headless mode, throttling results in an error
+    assert controller.state.agent_state == AgentState.ERROR