test_agent_controller.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import asyncio
  2. from unittest.mock import AsyncMock, MagicMock, Mock
  3. import pytest
  4. from openhands.controller.agent import Agent
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import TrafficControlState
  7. from openhands.core.config import AppConfig
  8. from openhands.core.exceptions import LLMMalformedActionError
  9. from openhands.core.main import run_controller
  10. from openhands.core.schema import AgentState
  11. from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
  12. from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
  13. from openhands.events.observation import FatalErrorObservation
  14. from openhands.llm import LLM
  15. from openhands.llm.metrics import Metrics
  16. from openhands.runtime.base import Runtime
  17. from openhands.storage import get_file_store
  18. @pytest.fixture
  19. def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
  20. return str(tmp_path_factory.mktemp('test_event_stream'))
  21. @pytest.fixture(scope='function')
  22. def event_loop():
  23. loop = asyncio.get_event_loop_policy().new_event_loop()
  24. yield loop
  25. loop.close()
  26. @pytest.fixture
  27. def mock_agent():
  28. return MagicMock(spec=Agent)
  29. @pytest.fixture
  30. def mock_event_stream():
  31. return MagicMock(spec=EventStream)
  32. @pytest.mark.asyncio
  33. async def test_set_agent_state(mock_agent, mock_event_stream):
  34. controller = AgentController(
  35. agent=mock_agent,
  36. event_stream=mock_event_stream,
  37. max_iterations=10,
  38. sid='test',
  39. confirmation_mode=False,
  40. headless_mode=True,
  41. )
  42. await controller.set_agent_state_to(AgentState.RUNNING)
  43. assert controller.get_agent_state() == AgentState.RUNNING
  44. await controller.set_agent_state_to(AgentState.PAUSED)
  45. assert controller.get_agent_state() == AgentState.PAUSED
  46. await controller.close()
  47. @pytest.mark.asyncio
  48. async def test_on_event_message_action(mock_agent, mock_event_stream):
  49. controller = AgentController(
  50. agent=mock_agent,
  51. event_stream=mock_event_stream,
  52. max_iterations=10,
  53. sid='test',
  54. confirmation_mode=False,
  55. headless_mode=True,
  56. )
  57. controller.state.agent_state = AgentState.RUNNING
  58. message_action = MessageAction(content='Test message')
  59. await controller.on_event(message_action)
  60. assert controller.get_agent_state() == AgentState.RUNNING
  61. await controller.close()
  62. @pytest.mark.asyncio
  63. async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
  64. controller = AgentController(
  65. agent=mock_agent,
  66. event_stream=mock_event_stream,
  67. max_iterations=10,
  68. sid='test',
  69. confirmation_mode=False,
  70. headless_mode=True,
  71. )
  72. controller.state.agent_state = AgentState.RUNNING
  73. change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
  74. await controller.on_event(change_state_action)
  75. assert controller.get_agent_state() == AgentState.PAUSED
  76. await controller.close()
  77. @pytest.mark.asyncio
  78. async def test_report_error(mock_agent, mock_event_stream):
  79. controller = AgentController(
  80. agent=mock_agent,
  81. event_stream=mock_event_stream,
  82. max_iterations=10,
  83. sid='test',
  84. confirmation_mode=False,
  85. headless_mode=True,
  86. )
  87. error_message = 'Test error'
  88. await controller.report_error(error_message)
  89. assert controller.state.last_error == error_message
  90. controller.event_stream.add_event.assert_called_once()
  91. await controller.close()
  92. @pytest.mark.asyncio
  93. async def test_step_with_exception(mock_agent, mock_event_stream):
  94. controller = AgentController(
  95. agent=mock_agent,
  96. event_stream=mock_event_stream,
  97. max_iterations=10,
  98. sid='test',
  99. confirmation_mode=False,
  100. headless_mode=True,
  101. )
  102. controller.state.agent_state = AgentState.RUNNING
  103. controller.report_error = AsyncMock()
  104. controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
  105. await controller._step()
  106. # Verify that report_error was called with the correct error message
  107. controller.report_error.assert_called_once_with('Malformed action')
  108. await controller.close()
  109. @pytest.mark.asyncio
  110. async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
  111. config = AppConfig()
  112. file_store = get_file_store(config.file_store, config.file_store_path)
  113. event_stream = EventStream(sid='test', file_store=file_store)
  114. agent = MagicMock(spec=Agent)
  115. # a random message to send to the runtime
  116. event = CmdRunAction(command='ls')
  117. agent.step.return_value = event
  118. agent.llm = MagicMock(spec=LLM)
  119. agent.llm.metrics = Metrics()
  120. agent.llm.config = config.get_llm_config()
  121. fatal_error_obs = FatalErrorObservation('Fatal error detected')
  122. fatal_error_obs._cause = event.id
  123. runtime = MagicMock(spec=Runtime)
  124. async def on_event(event: Event):
  125. if isinstance(event, CmdRunAction):
  126. await event_stream.async_add_event(fatal_error_obs, EventSource.USER)
  127. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
  128. runtime.event_stream = event_stream
  129. state = await run_controller(
  130. config=config,
  131. initial_user_action=MessageAction(content='Test message'),
  132. runtime=runtime,
  133. sid='test',
  134. agent=agent,
  135. fake_user_response_fn=lambda _: 'repeat',
  136. )
  137. print(f'state: {state}')
  138. print(f'event_stream: {list(event_stream.get_events())}')
  139. assert state.iteration == 1
  140. # it will first become AgentState.ERROR, then become AgentState.STOPPED
  141. # in side run_controller (since the while loop + sleep no longer loop)
  142. assert state.agent_state == AgentState.STOPPED
  143. assert (
  144. state.last_error
  145. == 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected'
  146. )
  147. assert len(list(event_stream.get_events())) == 5
  148. @pytest.mark.asyncio
  149. @pytest.mark.parametrize(
  150. 'delegate_state',
  151. [
  152. AgentState.RUNNING,
  153. AgentState.FINISHED,
  154. AgentState.ERROR,
  155. AgentState.REJECTED,
  156. ],
  157. )
  158. async def test_delegate_step_different_states(
  159. mock_agent, mock_event_stream, delegate_state
  160. ):
  161. controller = AgentController(
  162. agent=mock_agent,
  163. event_stream=mock_event_stream,
  164. max_iterations=10,
  165. sid='test',
  166. confirmation_mode=False,
  167. headless_mode=True,
  168. )
  169. mock_delegate = AsyncMock()
  170. controller.delegate = mock_delegate
  171. mock_delegate.state.iteration = 5
  172. mock_delegate.state.outputs = {'result': 'test'}
  173. mock_delegate.agent.name = 'TestDelegate'
  174. mock_delegate.get_agent_state = Mock(return_value=delegate_state)
  175. mock_delegate._step = AsyncMock()
  176. mock_delegate.close = AsyncMock()
  177. await controller._delegate_step()
  178. mock_delegate._step.assert_called_once()
  179. if delegate_state == AgentState.RUNNING:
  180. assert controller.delegate is not None
  181. assert controller.state.iteration == 0
  182. mock_delegate.close.assert_not_called()
  183. else:
  184. assert controller.delegate is None
  185. assert controller.state.iteration == 5
  186. mock_delegate.close.assert_called_once()
  187. await controller.close()
  188. @pytest.mark.asyncio
  189. async def test_step_max_iterations(mock_agent, mock_event_stream):
  190. controller = AgentController(
  191. agent=mock_agent,
  192. event_stream=mock_event_stream,
  193. max_iterations=10,
  194. sid='test',
  195. confirmation_mode=False,
  196. headless_mode=False,
  197. )
  198. controller.state.agent_state = AgentState.RUNNING
  199. controller.state.iteration = 10
  200. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  201. await controller._step()
  202. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  203. assert controller.state.agent_state == AgentState.PAUSED
  204. await controller.close()
  205. @pytest.mark.asyncio
  206. async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
  207. controller = AgentController(
  208. agent=mock_agent,
  209. event_stream=mock_event_stream,
  210. max_iterations=10,
  211. sid='test',
  212. confirmation_mode=False,
  213. headless_mode=True,
  214. )
  215. controller.state.agent_state = AgentState.RUNNING
  216. controller.state.iteration = 10
  217. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  218. await controller._step()
  219. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  220. # In headless mode, throttling results in an error
  221. assert controller.state.agent_state == AgentState.ERROR
  222. await controller.close()
  223. @pytest.mark.asyncio
  224. async def test_step_max_budget(mock_agent, mock_event_stream):
  225. controller = AgentController(
  226. agent=mock_agent,
  227. event_stream=mock_event_stream,
  228. max_iterations=10,
  229. max_budget_per_task=10,
  230. sid='test',
  231. confirmation_mode=False,
  232. headless_mode=False,
  233. )
  234. controller.state.agent_state = AgentState.RUNNING
  235. controller.state.metrics.accumulated_cost = 10.1
  236. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  237. await controller._step()
  238. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  239. assert controller.state.agent_state == AgentState.PAUSED
  240. await controller.close()
  241. @pytest.mark.asyncio
  242. async def test_step_max_budget_headless(mock_agent, mock_event_stream):
  243. controller = AgentController(
  244. agent=mock_agent,
  245. event_stream=mock_event_stream,
  246. max_iterations=10,
  247. max_budget_per_task=10,
  248. sid='test',
  249. confirmation_mode=False,
  250. headless_mode=True,
  251. )
  252. controller.state.agent_state = AgentState.RUNNING
  253. controller.state.metrics.accumulated_cost = 10.1
  254. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  255. await controller._step()
  256. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  257. # In headless mode, throttling results in an error
  258. assert controller.state.agent_state == AgentState.ERROR
  259. await controller.close()