test_agent_controller.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. import asyncio
  2. from unittest.mock import AsyncMock, MagicMock, Mock
  3. from uuid import uuid4
  4. import pytest
  5. from openhands.controller.agent import Agent
  6. from openhands.controller.agent_controller import AgentController
  7. from openhands.controller.state.state import State, TrafficControlState
  8. from openhands.core.config import AppConfig
  9. from openhands.core.main import run_controller
  10. from openhands.core.schema import AgentState
  11. from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
  12. from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
  13. from openhands.events.observation import (
  14. ErrorObservation,
  15. )
  16. from openhands.events.serialization import event_to_dict
  17. from openhands.llm import LLM
  18. from openhands.llm.metrics import Metrics
  19. from openhands.runtime.base import Runtime
  20. from openhands.storage import get_file_store
  21. from openhands.storage.memory import InMemoryFileStore
  22. @pytest.fixture
  23. def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
  24. return str(tmp_path_factory.mktemp('test_event_stream'))
  25. @pytest.fixture(scope='function')
  26. def event_loop():
  27. loop = asyncio.get_event_loop_policy().new_event_loop()
  28. yield loop
  29. loop.close()
  30. @pytest.fixture
  31. def mock_agent():
  32. return MagicMock(spec=Agent)
  33. @pytest.fixture
  34. def mock_event_stream():
  35. mock = MagicMock(spec=EventStream)
  36. mock.get_latest_event_id.return_value = 0
  37. return mock
  38. @pytest.fixture
  39. def mock_status_callback():
  40. return AsyncMock()
  41. @pytest.mark.asyncio
  42. async def test_set_agent_state(mock_agent, mock_event_stream):
  43. controller = AgentController(
  44. agent=mock_agent,
  45. event_stream=mock_event_stream,
  46. max_iterations=10,
  47. sid='test',
  48. confirmation_mode=False,
  49. headless_mode=True,
  50. )
  51. await controller.set_agent_state_to(AgentState.RUNNING)
  52. assert controller.get_agent_state() == AgentState.RUNNING
  53. await controller.set_agent_state_to(AgentState.PAUSED)
  54. assert controller.get_agent_state() == AgentState.PAUSED
  55. await controller.close()
  56. @pytest.mark.asyncio
  57. async def test_on_event_message_action(mock_agent, mock_event_stream):
  58. controller = AgentController(
  59. agent=mock_agent,
  60. event_stream=mock_event_stream,
  61. max_iterations=10,
  62. sid='test',
  63. confirmation_mode=False,
  64. headless_mode=True,
  65. )
  66. controller.state.agent_state = AgentState.RUNNING
  67. message_action = MessageAction(content='Test message')
  68. await controller.on_event(message_action)
  69. assert controller.get_agent_state() == AgentState.RUNNING
  70. await controller.close()
  71. @pytest.mark.asyncio
  72. async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
  73. controller = AgentController(
  74. agent=mock_agent,
  75. event_stream=mock_event_stream,
  76. max_iterations=10,
  77. sid='test',
  78. confirmation_mode=False,
  79. headless_mode=True,
  80. )
  81. controller.state.agent_state = AgentState.RUNNING
  82. change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
  83. await controller.on_event(change_state_action)
  84. assert controller.get_agent_state() == AgentState.PAUSED
  85. await controller.close()
  86. @pytest.mark.asyncio
  87. async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
  88. controller = AgentController(
  89. agent=mock_agent,
  90. event_stream=mock_event_stream,
  91. status_callback=mock_status_callback,
  92. max_iterations=10,
  93. sid='test',
  94. confirmation_mode=False,
  95. headless_mode=True,
  96. )
  97. error_message = 'Test error'
  98. await controller._react_to_exception(RuntimeError(error_message))
  99. controller.status_callback.assert_called_once()
  100. await controller.close()
  101. @pytest.mark.asyncio
  102. async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
  103. config = AppConfig()
  104. file_store = get_file_store(config.file_store, config.file_store_path)
  105. event_stream = EventStream(sid='test', file_store=file_store)
  106. agent = MagicMock(spec=Agent)
  107. agent = MagicMock(spec=Agent)
  108. def agent_step_fn(state):
  109. print(f'agent_step_fn received state: {state}')
  110. return CmdRunAction(command='ls')
  111. agent.step = agent_step_fn
  112. agent.llm = MagicMock(spec=LLM)
  113. agent.llm.metrics = Metrics()
  114. agent.llm.config = config.get_llm_config()
  115. runtime = MagicMock(spec=Runtime)
  116. async def on_event(event: Event):
  117. if isinstance(event, CmdRunAction):
  118. error_obs = ErrorObservation('You messed around with Jim')
  119. error_obs._cause = event.id
  120. event_stream.add_event(error_obs, EventSource.USER)
  121. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
  122. runtime.event_stream = event_stream
  123. state = await run_controller(
  124. config=config,
  125. initial_user_action=MessageAction(content='Test message'),
  126. runtime=runtime,
  127. sid='test',
  128. agent=agent,
  129. fake_user_response_fn=lambda _: 'repeat',
  130. )
  131. print(f'state: {state}')
  132. print(f'event_stream: {list(event_stream.get_events())}')
  133. assert state.iteration == 4
  134. assert state.agent_state == AgentState.ERROR
  135. assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
  136. assert len(list(event_stream.get_events())) == 11
  137. @pytest.mark.asyncio
  138. async def test_run_controller_stop_with_stuck():
  139. config = AppConfig()
  140. file_store = InMemoryFileStore({})
  141. event_stream = EventStream(sid='test', file_store=file_store)
  142. agent = MagicMock(spec=Agent)
  143. def agent_step_fn(state):
  144. print(f'agent_step_fn received state: {state}')
  145. return CmdRunAction(command='ls')
  146. agent.step = agent_step_fn
  147. agent.llm = MagicMock(spec=LLM)
  148. agent.llm.metrics = Metrics()
  149. agent.llm.config = config.get_llm_config()
  150. runtime = MagicMock(spec=Runtime)
  151. async def on_event(event: Event):
  152. if isinstance(event, CmdRunAction):
  153. non_fatal_error_obs = ErrorObservation(
  154. 'Non fatal error here to trigger loop'
  155. )
  156. non_fatal_error_obs._cause = event.id
  157. event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
  158. event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
  159. runtime.event_stream = event_stream
  160. state = await run_controller(
  161. config=config,
  162. initial_user_action=MessageAction(content='Test message'),
  163. runtime=runtime,
  164. sid='test',
  165. agent=agent,
  166. fake_user_response_fn=lambda _: 'repeat',
  167. )
  168. events = list(event_stream.get_events())
  169. print(f'state: {state}')
  170. for i, event in enumerate(events):
  171. print(f'event {i}: {event_to_dict(event)}')
  172. assert state.iteration == 4
  173. assert len(events) == 11
  174. # check the eventstream have 4 pairs of repeated actions and observations
  175. repeating_actions_and_observations = events[2:10]
  176. for action, observation in zip(
  177. repeating_actions_and_observations[0::2],
  178. repeating_actions_and_observations[1::2],
  179. ):
  180. action_dict = event_to_dict(action)
  181. observation_dict = event_to_dict(observation)
  182. assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
  183. assert (
  184. observation_dict['observation'] == 'error'
  185. and observation_dict['content'] == 'Non fatal error here to trigger loop'
  186. )
  187. last_event = event_to_dict(events[-1])
  188. assert last_event['extras']['agent_state'] == 'error'
  189. assert last_event['observation'] == 'agent_state_changed'
  190. assert state.agent_state == AgentState.ERROR
  191. assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
  192. @pytest.mark.asyncio
  193. @pytest.mark.parametrize(
  194. 'delegate_state',
  195. [
  196. AgentState.RUNNING,
  197. AgentState.FINISHED,
  198. AgentState.ERROR,
  199. AgentState.REJECTED,
  200. ],
  201. )
  202. async def test_delegate_step_different_states(
  203. mock_agent, mock_event_stream, delegate_state
  204. ):
  205. controller = AgentController(
  206. agent=mock_agent,
  207. event_stream=mock_event_stream,
  208. max_iterations=10,
  209. sid='test',
  210. confirmation_mode=False,
  211. headless_mode=True,
  212. )
  213. mock_delegate = AsyncMock()
  214. controller.delegate = mock_delegate
  215. mock_delegate.state.iteration = 5
  216. mock_delegate.state.outputs = {'result': 'test'}
  217. mock_delegate.agent.name = 'TestDelegate'
  218. mock_delegate.get_agent_state = Mock(return_value=delegate_state)
  219. mock_delegate._step = AsyncMock()
  220. mock_delegate.close = AsyncMock()
  221. await controller._delegate_step()
  222. mock_delegate._step.assert_called_once()
  223. if delegate_state == AgentState.RUNNING:
  224. assert controller.delegate is not None
  225. assert controller.state.iteration == 0
  226. mock_delegate.close.assert_not_called()
  227. else:
  228. assert controller.delegate is None
  229. assert controller.state.iteration == 5
  230. mock_delegate.close.assert_called_once()
  231. await controller.close()
  232. @pytest.mark.asyncio
  233. async def test_max_iterations_extension(mock_agent, mock_event_stream):
  234. # Test with headless_mode=False - should extend max_iterations
  235. initial_state = State(max_iterations=10)
  236. controller = AgentController(
  237. agent=mock_agent,
  238. event_stream=mock_event_stream,
  239. max_iterations=10,
  240. sid='test',
  241. confirmation_mode=False,
  242. headless_mode=False,
  243. initial_state=initial_state,
  244. )
  245. controller.state.agent_state = AgentState.RUNNING
  246. controller.state.iteration = 10
  247. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  248. # Trigger throttling by calling _step() when we hit max_iterations
  249. await controller._step()
  250. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  251. assert controller.state.agent_state == AgentState.ERROR
  252. # Simulate a new user message
  253. message_action = MessageAction(content='Test message')
  254. message_action._source = EventSource.USER
  255. await controller.on_event(message_action)
  256. # Max iterations should be extended to current iteration + initial max_iterations
  257. assert (
  258. controller.state.max_iterations == 20
  259. ) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
  260. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  261. assert controller.state.agent_state == AgentState.RUNNING
  262. # Close the controller to clean up
  263. await controller.close()
  264. # Test with headless_mode=True - should NOT extend max_iterations
  265. initial_state = State(max_iterations=10)
  266. controller = AgentController(
  267. agent=mock_agent,
  268. event_stream=mock_event_stream,
  269. max_iterations=10,
  270. sid='test',
  271. confirmation_mode=False,
  272. headless_mode=True,
  273. initial_state=initial_state,
  274. )
  275. controller.state.agent_state = AgentState.RUNNING
  276. controller.state.iteration = 10
  277. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  278. # Simulate a new user message
  279. message_action = MessageAction(content='Test message')
  280. message_action._source = EventSource.USER
  281. await controller.on_event(message_action)
  282. # Max iterations should NOT be extended in headless mode
  283. assert controller.state.max_iterations == 10 # Original value unchanged
  284. # Trigger throttling by calling _step() when we hit max_iterations
  285. await controller._step()
  286. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  287. assert controller.state.agent_state == AgentState.ERROR
  288. await controller.close()
  289. @pytest.mark.asyncio
  290. async def test_step_max_budget(mock_agent, mock_event_stream):
  291. controller = AgentController(
  292. agent=mock_agent,
  293. event_stream=mock_event_stream,
  294. max_iterations=10,
  295. max_budget_per_task=10,
  296. sid='test',
  297. confirmation_mode=False,
  298. headless_mode=False,
  299. )
  300. controller.state.agent_state = AgentState.RUNNING
  301. controller.state.metrics.accumulated_cost = 10.1
  302. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  303. await controller._step()
  304. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  305. assert controller.state.agent_state == AgentState.ERROR
  306. await controller.close()
  307. @pytest.mark.asyncio
  308. async def test_step_max_budget_headless(mock_agent, mock_event_stream):
  309. controller = AgentController(
  310. agent=mock_agent,
  311. event_stream=mock_event_stream,
  312. max_iterations=10,
  313. max_budget_per_task=10,
  314. sid='test',
  315. confirmation_mode=False,
  316. headless_mode=True,
  317. )
  318. controller.state.agent_state = AgentState.RUNNING
  319. controller.state.metrics.accumulated_cost = 10.1
  320. assert controller.state.traffic_control_state == TrafficControlState.NORMAL
  321. await controller._step()
  322. assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
  323. # In headless mode, throttling results in an error
  324. assert controller.state.agent_state == AgentState.ERROR
  325. await controller.close()