test_agent_controller.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. import asyncio
  2. from unittest.mock import AsyncMock, MagicMock, Mock
  3. from uuid import uuid4
  4. import pytest
  5. from openhands.controller.agent import Agent
  6. from openhands.controller.agent_controller import AgentController
  7. from openhands.controller.state.state import State, TrafficControlState
  8. from openhands.core.config import AppConfig
  9. from openhands.core.main import run_controller
  10. from openhands.core.schema import AgentState
  11. from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
  12. from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
  13. from openhands.events.observation import (
  14. ErrorObservation,
  15. )
  16. from openhands.events.serialization import event_to_dict
  17. from openhands.llm import LLM
  18. from openhands.llm.metrics import Metrics
  19. from openhands.runtime.base import Runtime
  20. from openhands.storage import get_file_store
  21. @pytest.fixture
  22. def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
  23. return str(tmp_path_factory.mktemp('test_event_stream'))
  24. @pytest.fixture(scope='function')
  25. def event_loop():
  26. loop = asyncio.get_event_loop_policy().new_event_loop()
  27. yield loop
  28. loop.close()
  29. @pytest.fixture
  30. def mock_agent():
  31. return MagicMock(spec=Agent)
  32. @pytest.fixture
  33. def mock_event_stream():
  34. mock = MagicMock(spec=EventStream)
  35. mock.get_latest_event_id.return_value = 0
  36. return mock
  37. @pytest.fixture
  38. def mock_status_callback():
  39. return AsyncMock()
  40. @pytest.mark.asyncio
  41. async def test_set_agent_state(mock_agent, mock_event_stream):
  42. controller = AgentController(
  43. agent=mock_agent,
  44. event_stream=mock_event_stream,
  45. max_iterations=10,
  46. sid='test',
  47. confirmation_mode=False,
  48. headless_mode=True,
  49. )
  50. await controller.set_agent_state_to(AgentState.RUNNING)
  51. assert controller.get_agent_state() == AgentState.RUNNING
  52. await controller.set_agent_state_to(AgentState.PAUSED)
  53. assert controller.get_agent_state() == AgentState.PAUSED
  54. await controller.close()
  55. @pytest.mark.asyncio
  56. async def test_on_event_message_action(mock_agent, mock_event_stream):
  57. controller = AgentController(
  58. agent=mock_agent,
  59. event_stream=mock_event_stream,
  60. max_iterations=10,
  61. sid='test',
  62. confirmation_mode=False,
  63. headless_mode=True,
  64. )
  65. controller.state.agent_state = AgentState.RUNNING
  66. message_action = MessageAction(content='Test message')
  67. await controller.on_event(message_action)
  68. assert controller.get_agent_state() == AgentState.RUNNING
  69. await controller.close()
  70. @pytest.mark.asyncio
  71. async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
  72. controller = AgentController(
  73. agent=mock_agent,
  74. event_stream=mock_event_stream,
  75. max_iterations=10,
  76. sid='test',
  77. confirmation_mode=False,
  78. headless_mode=True,
  79. )
  80. controller.state.agent_state = AgentState.RUNNING
  81. change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
  82. await controller.on_event(change_state_action)
  83. assert controller.get_agent_state() == AgentState.PAUSED
  84. await controller.close()
  85. @pytest.mark.asyncio
  86. async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
  87. controller = AgentController(
  88. agent=mock_agent,
  89. event_stream=mock_event_stream,
  90. status_callback=mock_status_callback,
  91. max_iterations=10,
  92. sid='test',
  93. confirmation_mode=False,
  94. headless_mode=True,
  95. )
  96. error_message = 'Test error'
  97. await controller._react_to_exception(RuntimeError(error_message))
  98. controller.status_callback.assert_called_once()
  99. await controller.close()
  100. @pytest.mark.asyncio
  101. async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
  102. config = AppConfig()
  103. file_store = get_file_store(config.file_store, config.file_store_path)
  104. event_stream = EventStream(sid='test', file_store=file_store)
  105. agent = MagicMock(spec=Agent)
  106. agent = MagicMock(spec=Agent)
  107. def agent_step_fn(state):
  108. print(f'agent_step_fn received state: {state}')
  109. return CmdRunAction(command='ls')
  110. agent.step = agent_step_fn
  111. agent.llm = MagicMock(spec=LLM)
  112. agent.llm.metrics = Metrics()
  113. agent.llm.config = config.get_llm_config()
  114. runtime = MagicMock(spec=Runtime)
  115. async def on_event(event: Event):
  116. if isinstance(event, CmdRunAction):
  117. error_obs = ErrorObservation('You messed around with Jim')
  118. error_obs._cause = event.id
  119. event_stream.add_event(error_obs, EventSource.USER)
  120. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
  121. runtime.event_stream = event_stream
  122. state = await run_controller(
  123. config=config,
  124. initial_user_action=MessageAction(content='Test message'),
  125. runtime=runtime,
  126. sid='test',
  127. agent=agent,
  128. fake_user_response_fn=lambda _: 'repeat',
  129. )
  130. print(f'state: {state}')
  131. print(f'event_stream: {list(event_stream.get_events())}')
  132. assert state.iteration == 4
  133. assert state.agent_state == AgentState.ERROR
  134. assert state.last_error == 'Agent got stuck in a loop'
  135. assert len(list(event_stream.get_events())) == 11
  136. @pytest.mark.asyncio
  137. async def test_run_controller_stop_with_stuck():
  138. config = AppConfig()
  139. file_store = get_file_store(config.file_store, config.file_store_path)
  140. event_stream = EventStream(sid='test', file_store=file_store)
  141. agent = MagicMock(spec=Agent)
  142. def agent_step_fn(state):
  143. print(f'agent_step_fn received state: {state}')
  144. return CmdRunAction(command='ls')
  145. agent.step = agent_step_fn
  146. agent.llm = MagicMock(spec=LLM)
  147. agent.llm.metrics = Metrics()
  148. agent.llm.config = config.get_llm_config()
  149. runtime = MagicMock(spec=Runtime)
  150. async def on_event(event: Event):
  151. if isinstance(event, CmdRunAction):
  152. non_fatal_error_obs = ErrorObservation(
  153. 'Non fatal error here to trigger loop'
  154. )
  155. non_fatal_error_obs._cause = event.id
  156. event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
  157. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
  158. runtime.event_stream = event_stream
  159. state = await run_controller(
  160. config=config,
  161. initial_user_action=MessageAction(content='Test message'),
  162. runtime=runtime,
  163. sid='test',
  164. agent=agent,
  165. fake_user_response_fn=lambda _: 'repeat',
  166. )
  167. events = list(event_stream.get_events())
  168. print(f'state: {state}')
  169. for i, event in enumerate(events):
  170. print(f'event {i}: {event_to_dict(event)}')
  171. assert state.iteration == 4
  172. assert len(events) == 11
  173. # check the eventstream have 4 pairs of repeated actions and observations
  174. repeating_actions_and_observations = events[2:10]
  175. for action, observation in zip(
  176. repeating_actions_and_observations[0::2],
  177. repeating_actions_and_observations[1::2],
  178. ):
  179. action_dict = event_to_dict(action)
  180. observation_dict = event_to_dict(observation)
  181. assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
  182. assert (
  183. observation_dict['observation'] == 'error'
  184. and observation_dict['content'] == 'Non fatal error here to trigger loop'
  185. )
  186. last_event = event_to_dict(events[-1])
  187. assert last_event['extras']['agent_state'] == 'error'
  188. assert last_event['observation'] == 'agent_state_changed'
  189. assert state.agent_state == AgentState.ERROR
  190. assert state.last_error == 'Agent got stuck in a loop'
  191. @pytest.mark.asyncio
  192. @pytest.mark.parametrize(
  193. 'delegate_state',
  194. [
  195. AgentState.RUNNING,
  196. AgentState.FINISHED,
  197. AgentState.ERROR,
  198. AgentState.REJECTED,
  199. ],
  200. )
  201. async def test_delegate_step_different_states(
  202. mock_agent, mock_event_stream, delegate_state
  203. ):
  204. controller = AgentController(
  205. agent=mock_agent,
  206. event_stream=mock_event_stream,
  207. max_iterations=10,
  208. sid='test',
  209. confirmation_mode=False,
  210. headless_mode=True,
  211. )
  212. mock_delegate = AsyncMock()
  213. controller.delegate = mock_delegate
  214. mock_delegate.state.iteration = 5
  215. mock_delegate.state.outputs = {'result': 'test'}
  216. mock_delegate.agent.name = 'TestDelegate'
  217. mock_delegate.get_agent_state = Mock(return_value=delegate_state)
  218. mock_delegate._step = AsyncMock()
  219. mock_delegate.close = AsyncMock()
  220. await controller._delegate_step()
  221. mock_delegate._step.assert_called_once()
  222. if delegate_state == AgentState.RUNNING:
  223. assert controller.delegate is not None
  224. assert controller.state.iteration == 0
  225. mock_delegate.close.assert_not_called()
  226. else:
  227. assert controller.delegate is None
  228. assert controller.state.iteration == 5
  229. mock_delegate.close.assert_called_once()
  230. await controller.close()
  231. @pytest.mark.asyncio
  232. async def test_max_iterations_extension(mock_agent, mock_event_stream):
  233. # Test with headless_mode=False - should extend max_iterations
  234. initial_state = State(max_iterations=10)
  235. controller = AgentController(
  236. agent=mock_agent,
  237. event_stream=mock_event_stream,
  238. max_iterations=10,
  239. sid='test',
  240. confirmation_mode=False,
  241. headless_mode=False,
  242. initial_state=initial_state,
  243. )
  244. controller.state.agent_state = AgentState.RUNNING
  245. controller.state.iteration = 10
  246. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  247. # Trigger throttling by calling _step() when we hit max_iterations
  248. await controller._step()
  249. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  250. assert controller.state.agent_state == AgentState.ERROR
  251. # Simulate a new user message
  252. message_action = MessageAction(content='Test message')
  253. message_action._source = EventSource.USER
  254. await controller.on_event(message_action)
  255. # Max iterations should be extended to current iteration + initial max_iterations
  256. assert (
  257. controller.state.max_iterations == 20
  258. ) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
  259. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  260. assert controller.state.agent_state == AgentState.RUNNING
  261. # Close the controller to clean up
  262. await controller.close()
  263. # Test with headless_mode=True - should NOT extend max_iterations
  264. initial_state = State(max_iterations=10)
  265. controller = AgentController(
  266. agent=mock_agent,
  267. event_stream=mock_event_stream,
  268. max_iterations=10,
  269. sid='test',
  270. confirmation_mode=False,
  271. headless_mode=True,
  272. initial_state=initial_state,
  273. )
  274. controller.state.agent_state = AgentState.RUNNING
  275. controller.state.iteration = 10
  276. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  277. # Simulate a new user message
  278. message_action = MessageAction(content='Test message')
  279. message_action._source = EventSource.USER
  280. await controller.on_event(message_action)
  281. # Max iterations should NOT be extended in headless mode
  282. assert controller.state.max_iterations == 10 # Original value unchanged
  283. # Trigger throttling by calling _step() when we hit max_iterations
  284. await controller._step()
  285. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  286. assert controller.state.agent_state == AgentState.ERROR
  287. await controller.close()
  288. @pytest.mark.asyncio
  289. async def test_step_max_budget(mock_agent, mock_event_stream):
  290. controller = AgentController(
  291. agent=mock_agent,
  292. event_stream=mock_event_stream,
  293. max_iterations=10,
  294. max_budget_per_task=10,
  295. sid='test',
  296. confirmation_mode=False,
  297. headless_mode=False,
  298. )
  299. controller.state.agent_state = AgentState.RUNNING
  300. controller.state.metrics.accumulated_cost = 10.1
  301. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  302. await controller._step()
  303. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  304. assert controller.state.agent_state == AgentState.ERROR
  305. await controller.close()
  306. @pytest.mark.asyncio
  307. async def test_step_max_budget_headless(mock_agent, mock_event_stream):
  308. controller = AgentController(
  309. agent=mock_agent,
  310. event_stream=mock_event_stream,
  311. max_iterations=10,
  312. max_budget_per_task=10,
  313. sid='test',
  314. confirmation_mode=False,
  315. headless_mode=True,
  316. )
  317. controller.state.agent_state = AgentState.RUNNING
  318. controller.state.metrics.accumulated_cost = 10.1
  319. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  320. await controller._step()
  321. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  322. # In headless mode, throttling results in an error
  323. assert controller.state.agent_state == AgentState.ERROR
  324. await controller.close()