test_agent_controller.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import asyncio
  2. from unittest.mock import AsyncMock, MagicMock, Mock
  3. import pytest
  4. from openhands.controller.agent import Agent
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import TrafficControlState
  7. from openhands.core.exceptions import LLMMalformedActionError
  8. from openhands.core.schema import AgentState
  9. from openhands.events import EventStream
  10. from openhands.events.action import ChangeAgentStateAction, MessageAction
  11. @pytest.fixture
  12. def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
  13. return str(tmp_path_factory.mktemp('test_event_stream'))
  14. @pytest.fixture(scope='function')
  15. def event_loop():
  16. loop = asyncio.get_event_loop_policy().new_event_loop()
  17. yield loop
  18. loop.close()
  19. @pytest.fixture
  20. def mock_agent():
  21. return MagicMock(spec=Agent)
  22. @pytest.fixture
  23. def mock_event_stream():
  24. return MagicMock(spec=EventStream)
  25. @pytest.mark.asyncio
  26. async def test_set_agent_state(mock_agent, mock_event_stream):
  27. controller = AgentController(
  28. agent=mock_agent,
  29. event_stream=mock_event_stream,
  30. max_iterations=10,
  31. sid='test',
  32. confirmation_mode=False,
  33. headless_mode=True,
  34. )
  35. await controller.set_agent_state_to(AgentState.RUNNING)
  36. assert controller.get_agent_state() == AgentState.RUNNING
  37. await controller.set_agent_state_to(AgentState.PAUSED)
  38. assert controller.get_agent_state() == AgentState.PAUSED
  39. await controller.close()
  40. @pytest.mark.asyncio
  41. async def test_on_event_message_action(mock_agent, mock_event_stream):
  42. controller = AgentController(
  43. agent=mock_agent,
  44. event_stream=mock_event_stream,
  45. max_iterations=10,
  46. sid='test',
  47. confirmation_mode=False,
  48. headless_mode=True,
  49. )
  50. controller.state.agent_state = AgentState.RUNNING
  51. message_action = MessageAction(content='Test message')
  52. await controller.on_event(message_action)
  53. assert controller.get_agent_state() == AgentState.RUNNING
  54. await controller.close()
  55. @pytest.mark.asyncio
  56. async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
  57. controller = AgentController(
  58. agent=mock_agent,
  59. event_stream=mock_event_stream,
  60. max_iterations=10,
  61. sid='test',
  62. confirmation_mode=False,
  63. headless_mode=True,
  64. )
  65. controller.state.agent_state = AgentState.RUNNING
  66. change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
  67. await controller.on_event(change_state_action)
  68. assert controller.get_agent_state() == AgentState.PAUSED
  69. await controller.close()
  70. @pytest.mark.asyncio
  71. async def test_report_error(mock_agent, mock_event_stream):
  72. controller = AgentController(
  73. agent=mock_agent,
  74. event_stream=mock_event_stream,
  75. max_iterations=10,
  76. sid='test',
  77. confirmation_mode=False,
  78. headless_mode=True,
  79. )
  80. error_message = 'Test error'
  81. await controller.report_error(error_message)
  82. assert controller.state.last_error == error_message
  83. controller.event_stream.add_event.assert_called_once()
  84. await controller.close()
  85. @pytest.mark.asyncio
  86. async def test_step_with_exception(mock_agent, mock_event_stream):
  87. controller = AgentController(
  88. agent=mock_agent,
  89. event_stream=mock_event_stream,
  90. max_iterations=10,
  91. sid='test',
  92. confirmation_mode=False,
  93. headless_mode=True,
  94. )
  95. controller.state.agent_state = AgentState.RUNNING
  96. controller.report_error = AsyncMock()
  97. controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
  98. await controller._step()
  99. # Verify that report_error was called with the correct error message
  100. controller.report_error.assert_called_once_with('Malformed action')
  101. await controller.close()
  102. @pytest.mark.asyncio
  103. @pytest.mark.parametrize(
  104. 'delegate_state',
  105. [
  106. AgentState.RUNNING,
  107. AgentState.FINISHED,
  108. AgentState.ERROR,
  109. AgentState.REJECTED,
  110. ],
  111. )
  112. async def test_delegate_step_different_states(
  113. mock_agent, mock_event_stream, delegate_state
  114. ):
  115. controller = AgentController(
  116. agent=mock_agent,
  117. event_stream=mock_event_stream,
  118. max_iterations=10,
  119. sid='test',
  120. confirmation_mode=False,
  121. headless_mode=True,
  122. )
  123. mock_delegate = AsyncMock()
  124. controller.delegate = mock_delegate
  125. mock_delegate.state.iteration = 5
  126. mock_delegate.state.outputs = {'result': 'test'}
  127. mock_delegate.agent.name = 'TestDelegate'
  128. mock_delegate.get_agent_state = Mock(return_value=delegate_state)
  129. mock_delegate._step = AsyncMock()
  130. mock_delegate.close = AsyncMock()
  131. await controller._delegate_step()
  132. mock_delegate._step.assert_called_once()
  133. if delegate_state == AgentState.RUNNING:
  134. assert controller.delegate is not None
  135. assert controller.state.iteration == 0
  136. mock_delegate.close.assert_not_called()
  137. else:
  138. assert controller.delegate is None
  139. assert controller.state.iteration == 5
  140. mock_delegate.close.assert_called_once()
  141. await controller.close()
  142. @pytest.mark.asyncio
  143. async def test_step_max_iterations(mock_agent, mock_event_stream):
  144. controller = AgentController(
  145. agent=mock_agent,
  146. event_stream=mock_event_stream,
  147. max_iterations=10,
  148. sid='test',
  149. confirmation_mode=False,
  150. headless_mode=False,
  151. )
  152. controller.state.agent_state = AgentState.RUNNING
  153. controller.state.iteration = 10
  154. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  155. await controller._step()
  156. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  157. assert controller.state.agent_state == AgentState.PAUSED
  158. await controller.close()
  159. @pytest.mark.asyncio
  160. async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
  161. controller = AgentController(
  162. agent=mock_agent,
  163. event_stream=mock_event_stream,
  164. max_iterations=10,
  165. sid='test',
  166. confirmation_mode=False,
  167. headless_mode=True,
  168. )
  169. controller.state.agent_state = AgentState.RUNNING
  170. controller.state.iteration = 10
  171. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  172. await controller._step()
  173. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  174. # In headless mode, throttling results in an error
  175. assert controller.state.agent_state == AgentState.ERROR
  176. await controller.close()
  177. @pytest.mark.asyncio
  178. async def test_step_max_budget(mock_agent, mock_event_stream):
  179. controller = AgentController(
  180. agent=mock_agent,
  181. event_stream=mock_event_stream,
  182. max_iterations=10,
  183. max_budget_per_task=10,
  184. sid='test',
  185. confirmation_mode=False,
  186. headless_mode=False,
  187. )
  188. controller.state.agent_state = AgentState.RUNNING
  189. controller.state.metrics.accumulated_cost = 10.1
  190. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  191. await controller._step()
  192. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  193. assert controller.state.agent_state == AgentState.PAUSED
  194. await controller.close()
  195. @pytest.mark.asyncio
  196. async def test_step_max_budget_headless(mock_agent, mock_event_stream):
  197. controller = AgentController(
  198. agent=mock_agent,
  199. event_stream=mock_event_stream,
  200. max_iterations=10,
  201. max_budget_per_task=10,
  202. sid='test',
  203. confirmation_mode=False,
  204. headless_mode=True,
  205. )
  206. controller.state.agent_state = AgentState.RUNNING
  207. controller.state.metrics.accumulated_cost = 10.1
  208. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  209. await controller._step()
  210. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  211. # In headless mode, throttling results in an error
  212. assert controller.state.agent_state == AgentState.ERROR
  213. await controller.close()