test_agent_controller.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import asyncio
  2. from unittest.mock import AsyncMock, MagicMock
  3. import pytest
  4. from openhands.controller.agent import Agent
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import TrafficControlState
  7. from openhands.core.exceptions import LLMMalformedActionError
  8. from openhands.core.schema import AgentState
  9. from openhands.events import EventStream
  10. from openhands.events.action import ChangeAgentStateAction, MessageAction
  11. @pytest.fixture
  12. def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
  13. return str(tmp_path_factory.mktemp('test_event_stream'))
  14. @pytest.fixture(scope='function')
  15. def event_loop():
  16. loop = asyncio.get_event_loop_policy().new_event_loop()
  17. yield loop
  18. loop.close()
  19. @pytest.fixture
  20. def mock_agent():
  21. return MagicMock(spec=Agent)
  22. @pytest.fixture
  23. def mock_event_stream():
  24. return MagicMock(spec=EventStream)
  25. @pytest.mark.asyncio
  26. async def test_set_agent_state(mock_agent, mock_event_stream):
  27. controller = AgentController(
  28. agent=mock_agent,
  29. event_stream=mock_event_stream,
  30. max_iterations=10,
  31. sid='test',
  32. confirmation_mode=False,
  33. headless_mode=True,
  34. )
  35. await controller.set_agent_state_to(AgentState.RUNNING)
  36. assert controller.get_agent_state() == AgentState.RUNNING
  37. await controller.set_agent_state_to(AgentState.PAUSED)
  38. assert controller.get_agent_state() == AgentState.PAUSED
  39. @pytest.mark.asyncio
  40. async def test_on_event_message_action(mock_agent, mock_event_stream):
  41. controller = AgentController(
  42. agent=mock_agent,
  43. event_stream=mock_event_stream,
  44. max_iterations=10,
  45. sid='test',
  46. confirmation_mode=False,
  47. headless_mode=True,
  48. )
  49. controller.state.agent_state = AgentState.RUNNING
  50. message_action = MessageAction(content='Test message')
  51. await controller.on_event(message_action)
  52. assert controller.get_agent_state() == AgentState.RUNNING
  53. @pytest.mark.asyncio
  54. async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
  55. controller = AgentController(
  56. agent=mock_agent,
  57. event_stream=mock_event_stream,
  58. max_iterations=10,
  59. sid='test',
  60. confirmation_mode=False,
  61. headless_mode=True,
  62. )
  63. controller.state.agent_state = AgentState.RUNNING
  64. change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
  65. await controller.on_event(change_state_action)
  66. assert controller.get_agent_state() == AgentState.PAUSED
  67. @pytest.mark.asyncio
  68. async def test_report_error(mock_agent, mock_event_stream):
  69. controller = AgentController(
  70. agent=mock_agent,
  71. event_stream=mock_event_stream,
  72. max_iterations=10,
  73. sid='test',
  74. confirmation_mode=False,
  75. headless_mode=True,
  76. )
  77. error_message = 'Test error'
  78. await controller.report_error(error_message)
  79. assert controller.state.last_error == error_message
  80. controller.event_stream.add_event.assert_called_once()
  81. @pytest.mark.asyncio
  82. async def test_step_with_exception(mock_agent, mock_event_stream):
  83. controller = AgentController(
  84. agent=mock_agent,
  85. event_stream=mock_event_stream,
  86. max_iterations=10,
  87. sid='test',
  88. confirmation_mode=False,
  89. headless_mode=True,
  90. )
  91. controller.state.agent_state = AgentState.RUNNING
  92. controller.report_error = AsyncMock()
  93. controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
  94. await controller._step()
  95. # Verify that report_error was called with the correct error message
  96. controller.report_error.assert_called_once_with('Malformed action')
  97. @pytest.mark.asyncio
  98. async def test_step_max_iterations(mock_agent, mock_event_stream):
  99. controller = AgentController(
  100. agent=mock_agent,
  101. event_stream=mock_event_stream,
  102. max_iterations=10,
  103. sid='test',
  104. confirmation_mode=False,
  105. headless_mode=False,
  106. )
  107. controller.state.agent_state = AgentState.RUNNING
  108. controller.state.iteration = 10
  109. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  110. await controller._step()
  111. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  112. assert controller.state.agent_state == AgentState.PAUSED
  113. @pytest.mark.asyncio
  114. async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
  115. controller = AgentController(
  116. agent=mock_agent,
  117. event_stream=mock_event_stream,
  118. max_iterations=10,
  119. sid='test',
  120. confirmation_mode=False,
  121. headless_mode=True,
  122. )
  123. controller.state.agent_state = AgentState.RUNNING
  124. controller.state.iteration = 10
  125. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  126. await controller._step()
  127. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  128. # In headless mode, throttling results in an error
  129. assert controller.state.agent_state == AgentState.ERROR
  130. @pytest.mark.asyncio
  131. async def test_step_max_budget(mock_agent, mock_event_stream):
  132. controller = AgentController(
  133. agent=mock_agent,
  134. event_stream=mock_event_stream,
  135. max_iterations=10,
  136. max_budget_per_task=10,
  137. sid='test',
  138. confirmation_mode=False,
  139. headless_mode=False,
  140. )
  141. controller.state.agent_state = AgentState.RUNNING
  142. controller.state.metrics.accumulated_cost = 10.1
  143. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  144. await controller._step()
  145. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  146. assert controller.state.agent_state == AgentState.PAUSED
  147. @pytest.mark.asyncio
  148. async def test_step_max_budget_headless(mock_agent, mock_event_stream):
  149. controller = AgentController(
  150. agent=mock_agent,
  151. event_stream=mock_event_stream,
  152. max_iterations=10,
  153. max_budget_per_task=10,
  154. sid='test',
  155. confirmation_mode=False,
  156. headless_mode=True,
  157. )
  158. controller.state.agent_state = AgentState.RUNNING
  159. controller.state.metrics.accumulated_cost = 10.1
  160. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  161. await controller._step()
  162. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  163. # In headless mode, throttling results in an error
  164. assert controller.state.agent_state == AgentState.ERROR