test_agent_controller.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. import asyncio
  2. from unittest.mock import AsyncMock, MagicMock, Mock
  3. import pytest
  4. from openhands.controller.agent import Agent
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import TrafficControlState
  7. from openhands.core.config import AppConfig
  8. from openhands.core.exceptions import LLMMalformedActionError
  9. from openhands.core.main import run_controller
  10. from openhands.core.schema import AgentState
  11. from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
  12. from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
  13. from openhands.events.observation import (
  14. ErrorObservation,
  15. FatalErrorObservation,
  16. )
  17. from openhands.events.serialization import event_to_dict
  18. from openhands.llm import LLM
  19. from openhands.llm.metrics import Metrics
  20. from openhands.runtime.base import Runtime
  21. from openhands.storage import get_file_store
  22. @pytest.fixture
  23. def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
  24. return str(tmp_path_factory.mktemp('test_event_stream'))
  25. @pytest.fixture(scope='function')
  26. def event_loop():
  27. loop = asyncio.get_event_loop_policy().new_event_loop()
  28. yield loop
  29. loop.close()
  30. @pytest.fixture
  31. def mock_agent():
  32. return MagicMock(spec=Agent)
  33. @pytest.fixture
  34. def mock_event_stream():
  35. return MagicMock(spec=EventStream)
  36. @pytest.mark.asyncio
  37. async def test_set_agent_state(mock_agent, mock_event_stream):
  38. controller = AgentController(
  39. agent=mock_agent,
  40. event_stream=mock_event_stream,
  41. max_iterations=10,
  42. sid='test',
  43. confirmation_mode=False,
  44. headless_mode=True,
  45. )
  46. await controller.set_agent_state_to(AgentState.RUNNING)
  47. assert controller.get_agent_state() == AgentState.RUNNING
  48. await controller.set_agent_state_to(AgentState.PAUSED)
  49. assert controller.get_agent_state() == AgentState.PAUSED
  50. await controller.close()
  51. @pytest.mark.asyncio
  52. async def test_on_event_message_action(mock_agent, mock_event_stream):
  53. controller = AgentController(
  54. agent=mock_agent,
  55. event_stream=mock_event_stream,
  56. max_iterations=10,
  57. sid='test',
  58. confirmation_mode=False,
  59. headless_mode=True,
  60. )
  61. controller.state.agent_state = AgentState.RUNNING
  62. message_action = MessageAction(content='Test message')
  63. await controller.on_event(message_action)
  64. assert controller.get_agent_state() == AgentState.RUNNING
  65. await controller.close()
  66. @pytest.mark.asyncio
  67. async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
  68. controller = AgentController(
  69. agent=mock_agent,
  70. event_stream=mock_event_stream,
  71. max_iterations=10,
  72. sid='test',
  73. confirmation_mode=False,
  74. headless_mode=True,
  75. )
  76. controller.state.agent_state = AgentState.RUNNING
  77. change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
  78. await controller.on_event(change_state_action)
  79. assert controller.get_agent_state() == AgentState.PAUSED
  80. await controller.close()
  81. @pytest.mark.asyncio
  82. async def test_report_error(mock_agent, mock_event_stream):
  83. controller = AgentController(
  84. agent=mock_agent,
  85. event_stream=mock_event_stream,
  86. max_iterations=10,
  87. sid='test',
  88. confirmation_mode=False,
  89. headless_mode=True,
  90. )
  91. error_message = 'Test error'
  92. await controller.report_error(error_message)
  93. assert controller.state.last_error == error_message
  94. controller.event_stream.add_event.assert_called_once()
  95. await controller.close()
  96. @pytest.mark.asyncio
  97. async def test_step_with_exception(mock_agent, mock_event_stream):
  98. controller = AgentController(
  99. agent=mock_agent,
  100. event_stream=mock_event_stream,
  101. max_iterations=10,
  102. sid='test',
  103. confirmation_mode=False,
  104. headless_mode=True,
  105. )
  106. controller.state.agent_state = AgentState.RUNNING
  107. controller.report_error = AsyncMock()
  108. controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
  109. await controller._step()
  110. # Verify that report_error was called with the correct error message
  111. controller.report_error.assert_called_once_with('Malformed action')
  112. await controller.close()
  113. @pytest.mark.asyncio
  114. async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
  115. config = AppConfig()
  116. file_store = get_file_store(config.file_store, config.file_store_path)
  117. event_stream = EventStream(sid='test', file_store=file_store)
  118. agent = MagicMock(spec=Agent)
  119. # a random message to send to the runtime
  120. event = CmdRunAction(command='ls')
  121. agent.step.return_value = event
  122. agent.llm = MagicMock(spec=LLM)
  123. agent.llm.metrics = Metrics()
  124. agent.llm.config = config.get_llm_config()
  125. fatal_error_obs = FatalErrorObservation('Fatal error detected')
  126. fatal_error_obs._cause = event.id
  127. runtime = MagicMock(spec=Runtime)
  128. async def on_event(event: Event):
  129. if isinstance(event, CmdRunAction):
  130. await event_stream.async_add_event(fatal_error_obs, EventSource.USER)
  131. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
  132. runtime.event_stream = event_stream
  133. state = await run_controller(
  134. config=config,
  135. initial_user_action=MessageAction(content='Test message'),
  136. runtime=runtime,
  137. sid='test',
  138. agent=agent,
  139. fake_user_response_fn=lambda _: 'repeat',
  140. )
  141. print(f'state: {state}')
  142. print(f'event_stream: {list(event_stream.get_events())}')
  143. assert state.iteration == 1
  144. # it will first become AgentState.ERROR, then become AgentState.STOPPED
  145. # in side run_controller (since the while loop + sleep no longer loop)
  146. assert state.agent_state == AgentState.STOPPED
  147. assert (
  148. state.last_error
  149. == 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected'
  150. )
  151. assert len(list(event_stream.get_events())) == 5
  152. @pytest.mark.asyncio
  153. async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
  154. config = AppConfig()
  155. file_store = get_file_store(config.file_store, config.file_store_path)
  156. event_stream = EventStream(sid='test', file_store=file_store)
  157. agent = MagicMock(spec=Agent)
  158. # a random message to send to the runtime
  159. event = CmdRunAction(command='ls')
  160. def agent_step_fn(state):
  161. print(f'agent_step_fn received state: {state}')
  162. return event
  163. agent.step = agent_step_fn
  164. agent.llm = MagicMock(spec=LLM)
  165. agent.llm.metrics = Metrics()
  166. agent.llm.config = config.get_llm_config()
  167. runtime = MagicMock(spec=Runtime)
  168. async def on_event(event: Event):
  169. if isinstance(event, CmdRunAction):
  170. non_fatal_error_obs = ErrorObservation(
  171. 'Non fatal error here to trigger loop'
  172. )
  173. non_fatal_error_obs._cause = event.id
  174. await event_stream.async_add_event(
  175. non_fatal_error_obs, EventSource.ENVIRONMENT
  176. )
  177. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
  178. runtime.event_stream = event_stream
  179. state = await run_controller(
  180. config=config,
  181. initial_user_action=MessageAction(content='Test message'),
  182. runtime=runtime,
  183. sid='test',
  184. agent=agent,
  185. fake_user_response_fn=lambda _: 'repeat',
  186. )
  187. events = list(event_stream.get_events())
  188. print(f'state: {state}')
  189. for i, event in enumerate(events):
  190. print(f'event {i}: {event_to_dict(event)}')
  191. assert state.iteration == 4
  192. assert len(events) == 12
  193. # check the eventstream have 4 pairs of repeated actions and observations
  194. repeating_actions_and_observations = events[2:10]
  195. for action, observation in zip(
  196. repeating_actions_and_observations[0::2],
  197. repeating_actions_and_observations[1::2],
  198. ):
  199. action_dict = event_to_dict(action)
  200. observation_dict = event_to_dict(observation)
  201. assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
  202. assert (
  203. observation_dict['observation'] == 'error'
  204. and observation_dict['content'] == 'Non fatal error here to trigger loop'
  205. )
  206. last_event = event_to_dict(events[-1])
  207. assert last_event['extras']['agent_state'] == 'error'
  208. assert last_event['observation'] == 'agent_state_changed'
  209. # it will first become AgentState.ERROR, then become AgentState.STOPPED
  210. # in side run_controller (since the while loop + sleep no longer loop)
  211. assert state.agent_state == AgentState.STOPPED
  212. assert (
  213. state.last_error
  214. == 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
  215. )
  216. @pytest.mark.asyncio
  217. @pytest.mark.parametrize(
  218. 'delegate_state',
  219. [
  220. AgentState.RUNNING,
  221. AgentState.FINISHED,
  222. AgentState.ERROR,
  223. AgentState.REJECTED,
  224. ],
  225. )
  226. async def test_delegate_step_different_states(
  227. mock_agent, mock_event_stream, delegate_state
  228. ):
  229. controller = AgentController(
  230. agent=mock_agent,
  231. event_stream=mock_event_stream,
  232. max_iterations=10,
  233. sid='test',
  234. confirmation_mode=False,
  235. headless_mode=True,
  236. )
  237. mock_delegate = AsyncMock()
  238. controller.delegate = mock_delegate
  239. mock_delegate.state.iteration = 5
  240. mock_delegate.state.outputs = {'result': 'test'}
  241. mock_delegate.agent.name = 'TestDelegate'
  242. mock_delegate.get_agent_state = Mock(return_value=delegate_state)
  243. mock_delegate._step = AsyncMock()
  244. mock_delegate.close = AsyncMock()
  245. await controller._delegate_step()
  246. mock_delegate._step.assert_called_once()
  247. if delegate_state == AgentState.RUNNING:
  248. assert controller.delegate is not None
  249. assert controller.state.iteration == 0
  250. mock_delegate.close.assert_not_called()
  251. else:
  252. assert controller.delegate is None
  253. assert controller.state.iteration == 5
  254. mock_delegate.close.assert_called_once()
  255. await controller.close()
  256. @pytest.mark.asyncio
  257. async def test_step_max_iterations(mock_agent, mock_event_stream):
  258. controller = AgentController(
  259. agent=mock_agent,
  260. event_stream=mock_event_stream,
  261. max_iterations=10,
  262. sid='test',
  263. confirmation_mode=False,
  264. headless_mode=False,
  265. )
  266. controller.state.agent_state = AgentState.RUNNING
  267. controller.state.iteration = 10
  268. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  269. await controller._step()
  270. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  271. assert controller.state.agent_state == AgentState.PAUSED
  272. await controller.close()
  273. @pytest.mark.asyncio
  274. async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
  275. controller = AgentController(
  276. agent=mock_agent,
  277. event_stream=mock_event_stream,
  278. max_iterations=10,
  279. sid='test',
  280. confirmation_mode=False,
  281. headless_mode=True,
  282. )
  283. controller.state.agent_state = AgentState.RUNNING
  284. controller.state.iteration = 10
  285. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  286. await controller._step()
  287. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  288. # In headless mode, throttling results in an error
  289. assert controller.state.agent_state == AgentState.ERROR
  290. await controller.close()
  291. @pytest.mark.asyncio
  292. async def test_step_max_budget(mock_agent, mock_event_stream):
  293. controller = AgentController(
  294. agent=mock_agent,
  295. event_stream=mock_event_stream,
  296. max_iterations=10,
  297. max_budget_per_task=10,
  298. sid='test',
  299. confirmation_mode=False,
  300. headless_mode=False,
  301. )
  302. controller.state.agent_state = AgentState.RUNNING
  303. controller.state.metrics.accumulated_cost = 10.1
  304. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  305. await controller._step()
  306. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  307. assert controller.state.agent_state == AgentState.PAUSED
  308. await controller.close()
  309. @pytest.mark.asyncio
  310. async def test_step_max_budget_headless(mock_agent, mock_event_stream):
  311. controller = AgentController(
  312. agent=mock_agent,
  313. event_stream=mock_event_stream,
  314. max_iterations=10,
  315. max_budget_per_task=10,
  316. sid='test',
  317. confirmation_mode=False,
  318. headless_mode=True,
  319. )
  320. controller.state.agent_state = AgentState.RUNNING
  321. controller.state.metrics.accumulated_cost = 10.1
  322. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  323. await controller._step()
  324. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  325. # In headless mode, throttling results in an error
  326. assert controller.state.agent_state == AgentState.ERROR
  327. await controller.close()