test_is_stuck.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. import logging
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from pytest import TempPathFactory
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import State
  7. from openhands.controller.stuck import StuckDetector
  8. from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
  9. from openhands.events.action.commands import IPythonRunCellAction
  10. from openhands.events.observation import (
  11. CmdOutputObservation,
  12. FileReadObservation,
  13. )
  14. from openhands.events.observation.commands import IPythonRunCellObservation
  15. from openhands.events.observation.empty import NullObservation
  16. from openhands.events.observation.error import ErrorObservation
  17. from openhands.events.stream import EventSource, EventStream
  18. from openhands.memory.history import ShortTermHistory
  19. from openhands.storage import get_file_store
  20. def collect_events(stream):
  21. return [event for event in stream.get_events()]
# Configure root logging at DEBUG level for this test module.
logging.basicConfig(level=logging.DEBUG)
# Trailer fragments appended to the simulated IPython observation contents
# built by the helper methods in this file.
jupyter_line_1 = '\n[Jupyter current working directory:'
jupyter_line_2 = '\n[Jupyter Python interpreter:'
# Deliberately unfinished code (the escaped triple quote is never closed inside
# the snippet); used as the `code` of the simulated IPython actions/observations.
code_snippet = """
edit_file_by_replace(
'book_store.py',
to_replace=\"""def total(basket):
if not basket:
return 0
"""
  32. @pytest.fixture
  33. def temp_dir(tmp_path_factory: TempPathFactory) -> str:
  34. return str(tmp_path_factory.mktemp('test_is_stuck'))
  35. @pytest.fixture
  36. def event_stream(temp_dir):
  37. file_store = get_file_store('local', temp_dir)
  38. event_stream = EventStream('asdf', file_store)
  39. yield event_stream
  40. # clear after each test
  41. event_stream.clear()
  42. class TestStuckDetector:
  43. @pytest.fixture
  44. def stuck_detector(self, event_stream):
  45. state = State(inputs={}, max_iterations=50)
  46. state.history.set_event_stream(event_stream)
  47. return StuckDetector(state)
  48. def _impl_syntax_error_events(
  49. self,
  50. event_stream: EventStream,
  51. error_message: str,
  52. random_line: bool,
  53. incidents: int = 4,
  54. ):
  55. for i in range(incidents):
  56. ipython_action = IPythonRunCellAction(code=code_snippet)
  57. event_stream.add_event(ipython_action, EventSource.AGENT)
  58. extra_number = (i + 1) * 10 if random_line else '42'
  59. extra_line = '\n' * (i + 1) if random_line else ''
  60. ipython_observation = IPythonRunCellObservation(
  61. content=f' Cell In[1], line {extra_number}\n'
  62. 'to_replace="""def largest(min_factor, max_factor):\n ^\n'
  63. f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
  64. code=code_snippet,
  65. )
  66. ipython_observation._cause = ipython_action._id
  67. event_stream.add_event(ipython_observation, EventSource.USER)
  68. def _impl_unterminated_string_error_events(
  69. self, event_stream: EventStream, random_line: bool, incidents: int = 4
  70. ):
  71. for i in range(incidents):
  72. ipython_action = IPythonRunCellAction(code=code_snippet)
  73. event_stream.add_event(ipython_action, EventSource.AGENT)
  74. line_number = (i + 1) * 10 if random_line else '1'
  75. ipython_observation = IPythonRunCellObservation(
  76. content=f'print(" Cell In[1], line {line_number}\nhello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
  77. + jupyter_line_1
  78. + jupyter_line_2,
  79. code=code_snippet,
  80. )
  81. ipython_observation._cause = ipython_action._id
  82. event_stream.add_event(ipython_observation, EventSource.USER)
  83. def test_history_too_short(
  84. self, stuck_detector: StuckDetector, event_stream: EventStream
  85. ):
  86. message_action = MessageAction(content='Hello', wait_for_response=False)
  87. message_action._source = EventSource.USER
  88. observation = NullObservation(content='')
  89. observation._cause = message_action.id
  90. event_stream.add_event(message_action, EventSource.USER)
  91. event_stream.add_event(observation, EventSource.USER)
  92. cmd_action = CmdRunAction(command='ls')
  93. event_stream.add_event(cmd_action, EventSource.AGENT)
  94. cmd_observation = CmdOutputObservation(
  95. command_id=1, command='ls', content='file1.txt\nfile2.txt'
  96. )
  97. cmd_observation._cause = cmd_action._id
  98. event_stream.add_event(cmd_observation, EventSource.USER)
  99. # stuck_detector.state.history.set_event_stream(event_stream)
  100. assert stuck_detector.is_stuck() is False
  101. def test_is_stuck_repeating_action_observation(
  102. self, stuck_detector: StuckDetector, event_stream: EventStream
  103. ):
  104. message_action = MessageAction(content='Done', wait_for_response=False)
  105. message_action._source = EventSource.USER
  106. hello_action = MessageAction(content='Hello', wait_for_response=False)
  107. hello_observation = NullObservation('')
  108. # 2 events
  109. event_stream.add_event(hello_action, EventSource.USER)
  110. event_stream.add_event(hello_observation, EventSource.USER)
  111. cmd_action_1 = CmdRunAction(command='ls')
  112. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  113. cmd_observation_1 = CmdOutputObservation(
  114. content='', command='ls', command_id=cmd_action_1._id
  115. )
  116. cmd_observation_1._cause = cmd_action_1._id
  117. event_stream.add_event(cmd_observation_1, EventSource.USER)
  118. # 4 events
  119. cmd_action_2 = CmdRunAction(command='ls')
  120. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  121. cmd_observation_2 = CmdOutputObservation(
  122. content='', command='ls', command_id=cmd_action_2._id
  123. )
  124. cmd_observation_2._cause = cmd_action_2._id
  125. event_stream.add_event(cmd_observation_2, EventSource.USER)
  126. # 6 events
  127. # random user message just because we can
  128. message_null_observation = NullObservation(content='')
  129. event_stream.add_event(message_action, EventSource.USER)
  130. event_stream.add_event(message_null_observation, EventSource.USER)
  131. # 8 events
  132. assert stuck_detector.is_stuck() is False
  133. assert stuck_detector.state.almost_stuck == 2
  134. cmd_action_3 = CmdRunAction(command='ls')
  135. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  136. cmd_observation_3 = CmdOutputObservation(
  137. content='', command='ls', command_id=cmd_action_3._id
  138. )
  139. cmd_observation_3._cause = cmd_action_3._id
  140. event_stream.add_event(cmd_observation_3, EventSource.USER)
  141. # 10 events
  142. assert len(collect_events(event_stream)) == 10
  143. assert len(list(stuck_detector.state.history.get_events())) == 8
  144. assert len(stuck_detector.state.history.get_pairs()) == 5
  145. assert stuck_detector.is_stuck() is False
  146. assert stuck_detector.state.almost_stuck == 1
  147. cmd_action_4 = CmdRunAction(command='ls')
  148. event_stream.add_event(cmd_action_4, EventSource.AGENT)
  149. cmd_observation_4 = CmdOutputObservation(
  150. content='', command='ls', command_id=cmd_action_4._id
  151. )
  152. cmd_observation_4._cause = cmd_action_4._id
  153. event_stream.add_event(cmd_observation_4, EventSource.USER)
  154. # 12 events
  155. assert len(collect_events(event_stream)) == 12
  156. assert len(list(stuck_detector.state.history.get_events())) == 10
  157. assert len(stuck_detector.state.history.get_pairs()) == 6
  158. with patch('logging.Logger.warning') as mock_warning:
  159. assert stuck_detector.is_stuck() is True
  160. assert stuck_detector.state.almost_stuck == 0
  161. mock_warning.assert_called_once_with('Action, Observation loop detected')
  162. def test_is_stuck_repeating_action_error(
  163. self, stuck_detector: StuckDetector, event_stream: EventStream
  164. ):
  165. # (action, error_observation), not necessarily the same error
  166. message_action = MessageAction(content='Done', wait_for_response=False)
  167. message_action._source = EventSource.USER
  168. hello_action = MessageAction(content='Hello', wait_for_response=False)
  169. hello_observation = NullObservation(content='')
  170. event_stream.add_event(hello_action, EventSource.USER)
  171. hello_observation._cause = hello_action._id
  172. event_stream.add_event(hello_observation, EventSource.USER)
  173. # 2 events
  174. cmd_action_1 = CmdRunAction(command='invalid_command')
  175. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  176. error_observation_1 = ErrorObservation(content='Command not found')
  177. error_observation_1._cause = cmd_action_1._id
  178. event_stream.add_event(error_observation_1, EventSource.USER)
  179. # 4 events
  180. cmd_action_2 = CmdRunAction(command='invalid_command')
  181. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  182. error_observation_2 = ErrorObservation(
  183. content='Command still not found or another error'
  184. )
  185. error_observation_2._cause = cmd_action_2._id
  186. event_stream.add_event(error_observation_2, EventSource.USER)
  187. # 6 events
  188. message_null_observation = NullObservation(content='')
  189. event_stream.add_event(message_action, EventSource.USER)
  190. event_stream.add_event(message_null_observation, EventSource.USER)
  191. # 8 events
  192. cmd_action_3 = CmdRunAction(command='invalid_command')
  193. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  194. error_observation_3 = ErrorObservation(content='Different error')
  195. error_observation_3._cause = cmd_action_3._id
  196. event_stream.add_event(error_observation_3, EventSource.USER)
  197. # 10 events
  198. cmd_action_4 = CmdRunAction(command='invalid_command')
  199. event_stream.add_event(cmd_action_4, EventSource.AGENT)
  200. error_observation_4 = ErrorObservation(content='Command not found')
  201. error_observation_4._cause = cmd_action_4._id
  202. event_stream.add_event(error_observation_4, EventSource.USER)
  203. # 12 events
  204. with patch('logging.Logger.warning') as mock_warning:
  205. assert stuck_detector.is_stuck() is True
  206. mock_warning.assert_called_once_with(
  207. 'Action, ErrorObservation loop detected'
  208. )
  209. def test_is_stuck_invalid_syntax_error(
  210. self, stuck_detector: StuckDetector, event_stream: EventStream
  211. ):
  212. self._impl_syntax_error_events(
  213. event_stream,
  214. error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  215. random_line=False,
  216. )
  217. with patch('logging.Logger.warning'):
  218. assert stuck_detector.is_stuck() is True
  219. def test_is_not_stuck_invalid_syntax_error_random_lines(
  220. self, stuck_detector: StuckDetector, event_stream: EventStream
  221. ):
  222. self._impl_syntax_error_events(
  223. event_stream,
  224. error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  225. random_line=True,
  226. )
  227. with patch('logging.Logger.warning'):
  228. assert stuck_detector.is_stuck() is False
  229. def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
  230. self, stuck_detector: StuckDetector, event_stream: EventStream
  231. ):
  232. self._impl_syntax_error_events(
  233. event_stream,
  234. error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  235. random_line=True,
  236. incidents=3,
  237. )
  238. with patch('logging.Logger.warning'):
  239. assert stuck_detector.is_stuck() is False
  240. def test_is_stuck_incomplete_input_error(
  241. self, stuck_detector: StuckDetector, event_stream: EventStream
  242. ):
  243. self._impl_syntax_error_events(
  244. event_stream,
  245. error_message='SyntaxError: incomplete input',
  246. random_line=False,
  247. )
  248. with patch('logging.Logger.warning'):
  249. assert stuck_detector.is_stuck() is True
  250. def test_is_not_stuck_incomplete_input_error(
  251. self, stuck_detector: StuckDetector, event_stream: EventStream
  252. ):
  253. self._impl_syntax_error_events(
  254. event_stream,
  255. error_message='SyntaxError: incomplete input',
  256. random_line=True,
  257. )
  258. with patch('logging.Logger.warning'):
  259. assert stuck_detector.is_stuck() is False
  260. def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
  261. self, stuck_detector: StuckDetector, event_stream: EventStream
  262. ):
  263. self._impl_unterminated_string_error_events(event_stream, random_line=True)
  264. with patch('logging.Logger.warning'):
  265. assert stuck_detector.is_stuck() is False
  266. def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
  267. self, stuck_detector: StuckDetector, event_stream: EventStream
  268. ):
  269. self._impl_unterminated_string_error_events(
  270. event_stream, random_line=False, incidents=3
  271. )
  272. with patch('logging.Logger.warning'):
  273. assert stuck_detector.is_stuck() is False
  274. def test_is_stuck_ipython_unterminated_string_error(
  275. self, stuck_detector: StuckDetector, event_stream: EventStream
  276. ):
  277. self._impl_unterminated_string_error_events(event_stream, random_line=False)
  278. with patch('logging.Logger.warning'):
  279. assert stuck_detector.is_stuck() is True
  280. def test_is_not_stuck_ipython_syntax_error_not_at_end(
  281. self, stuck_detector: StuckDetector, event_stream: EventStream
  282. ):
  283. # this test is to make sure we don't get false positives
  284. # since the "at line x" is changing in between!
  285. ipython_action_1 = IPythonRunCellAction(code='print("hello')
  286. event_stream.add_event(ipython_action_1, EventSource.AGENT)
  287. ipython_observation_1 = IPythonRunCellObservation(
  288. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output',
  289. code='print("hello',
  290. )
  291. ipython_observation_1._cause = ipython_action_1._id
  292. event_stream.add_event(ipython_observation_1, EventSource.USER)
  293. ipython_action_2 = IPythonRunCellAction(code='print("hello')
  294. event_stream.add_event(ipython_action_2, EventSource.AGENT)
  295. ipython_observation_2 = IPythonRunCellObservation(
  296. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on',
  297. code='print("hello',
  298. )
  299. ipython_observation_2._cause = ipython_action_2._id
  300. event_stream.add_event(ipython_observation_2, EventSource.USER)
  301. ipython_action_3 = IPythonRunCellAction(code='print("hello')
  302. event_stream.add_event(ipython_action_3, EventSource.AGENT)
  303. ipython_observation_3 = IPythonRunCellObservation(
  304. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough',
  305. code='print("hello',
  306. )
  307. ipython_observation_3._cause = ipython_action_3._id
  308. event_stream.add_event(ipython_observation_3, EventSource.USER)
  309. ipython_action_4 = IPythonRunCellAction(code='print("hello')
  310. event_stream.add_event(ipython_action_4, EventSource.AGENT)
  311. ipython_observation_4 = IPythonRunCellObservation(
  312. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output',
  313. code='print("hello',
  314. )
  315. ipython_observation_4._cause = ipython_action_4._id
  316. event_stream.add_event(ipython_observation_4, EventSource.USER)
  317. with patch('logging.Logger.warning') as mock_warning:
  318. assert stuck_detector.is_stuck() is False
  319. mock_warning.assert_not_called()
  320. def test_is_stuck_repeating_action_observation_pattern(
  321. self, stuck_detector: StuckDetector, event_stream: EventStream
  322. ):
  323. message_action = MessageAction(content='Come on', wait_for_response=False)
  324. message_action._source = EventSource.USER
  325. event_stream.add_event(message_action, EventSource.USER)
  326. message_observation = NullObservation(content='')
  327. event_stream.add_event(message_observation, EventSource.USER)
  328. cmd_action_1 = CmdRunAction(command='ls')
  329. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  330. cmd_observation_1 = CmdOutputObservation(
  331. command_id=1, command='ls', content='file1.txt\nfile2.txt'
  332. )
  333. cmd_observation_1._cause = cmd_action_1._id
  334. event_stream.add_event(cmd_observation_1, EventSource.USER)
  335. read_action_1 = FileReadAction(path='file1.txt')
  336. event_stream.add_event(read_action_1, EventSource.AGENT)
  337. read_observation_1 = FileReadObservation(
  338. content='File content', path='file1.txt'
  339. )
  340. read_observation_1._cause = read_action_1._id
  341. event_stream.add_event(read_observation_1, EventSource.USER)
  342. cmd_action_2 = CmdRunAction(command='ls')
  343. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  344. cmd_observation_2 = CmdOutputObservation(
  345. command_id=2, command='ls', content='file1.txt\nfile2.txt'
  346. )
  347. cmd_observation_2._cause = cmd_action_2._id
  348. event_stream.add_event(cmd_observation_2, EventSource.USER)
  349. read_action_2 = FileReadAction(path='file1.txt')
  350. event_stream.add_event(read_action_2, EventSource.AGENT)
  351. read_observation_2 = FileReadObservation(
  352. content='File content', path='file1.txt'
  353. )
  354. read_observation_2._cause = read_action_2._id
  355. event_stream.add_event(read_observation_2, EventSource.USER)
  356. # one more message to break the pattern
  357. message_null_observation = NullObservation(content='')
  358. event_stream.add_event(message_action, EventSource.USER)
  359. event_stream.add_event(message_null_observation, EventSource.USER)
  360. cmd_action_3 = CmdRunAction(command='ls')
  361. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  362. cmd_observation_3 = CmdOutputObservation(
  363. command_id=3, command='ls', content='file1.txt\nfile2.txt'
  364. )
  365. cmd_observation_3._cause = cmd_action_3._id
  366. event_stream.add_event(cmd_observation_3, EventSource.USER)
  367. read_action_3 = FileReadAction(path='file1.txt')
  368. event_stream.add_event(read_action_3, EventSource.AGENT)
  369. read_observation_3 = FileReadObservation(
  370. content='File content', path='file1.txt'
  371. )
  372. read_observation_3._cause = read_action_3._id
  373. event_stream.add_event(read_observation_3, EventSource.USER)
  374. with patch('logging.Logger.warning') as mock_warning:
  375. assert stuck_detector.is_stuck() is True
  376. mock_warning.assert_called_once_with('Action, Observation pattern detected')
  377. def test_is_stuck_not_stuck(
  378. self, stuck_detector: StuckDetector, event_stream: EventStream
  379. ):
  380. message_action = MessageAction(content='Done', wait_for_response=False)
  381. message_action._source = EventSource.USER
  382. hello_action = MessageAction(content='Hello', wait_for_response=False)
  383. event_stream.add_event(hello_action, EventSource.USER)
  384. hello_observation = NullObservation(content='')
  385. hello_observation._cause = hello_action._id
  386. event_stream.add_event(hello_observation, EventSource.USER)
  387. cmd_action_1 = CmdRunAction(command='ls')
  388. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  389. cmd_observation_1 = CmdOutputObservation(
  390. command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
  391. )
  392. cmd_observation_1._cause = cmd_action_1._id
  393. event_stream.add_event(cmd_observation_1, EventSource.USER)
  394. read_action_1 = FileReadAction(path='file1.txt')
  395. event_stream.add_event(read_action_1, EventSource.AGENT)
  396. read_observation_1 = FileReadObservation(
  397. content='File content', path='file1.txt'
  398. )
  399. read_observation_1._cause = read_action_1._id
  400. event_stream.add_event(read_observation_1, EventSource.USER)
  401. cmd_action_2 = CmdRunAction(command='pwd')
  402. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  403. cmd_observation_2 = CmdOutputObservation(
  404. command_id=2, command='pwd', content='/home/user'
  405. )
  406. cmd_observation_2._cause = cmd_action_2._id
  407. event_stream.add_event(cmd_observation_2, EventSource.USER)
  408. read_action_2 = FileReadAction(path='file2.txt')
  409. event_stream.add_event(read_action_2, EventSource.AGENT)
  410. read_observation_2 = FileReadObservation(
  411. content='Another file content', path='file2.txt'
  412. )
  413. read_observation_2._cause = read_action_2._id
  414. event_stream.add_event(read_observation_2, EventSource.USER)
  415. message_null_observation = NullObservation(content='')
  416. event_stream.add_event(message_action, EventSource.USER)
  417. event_stream.add_event(message_null_observation, EventSource.USER)
  418. cmd_action_3 = CmdRunAction(command='pwd')
  419. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  420. cmd_observation_3 = CmdOutputObservation(
  421. command_id=cmd_action_3.id, command='pwd', content='/home/user'
  422. )
  423. cmd_observation_3._cause = cmd_action_3._id
  424. event_stream.add_event(cmd_observation_3, EventSource.USER)
  425. read_action_3 = FileReadAction(path='file2.txt')
  426. event_stream.add_event(read_action_3, EventSource.AGENT)
  427. read_observation_3 = FileReadObservation(
  428. content='Another file content', path='file2.txt'
  429. )
  430. read_observation_3._cause = read_action_3._id
  431. event_stream.add_event(read_observation_3, EventSource.USER)
  432. assert stuck_detector.is_stuck() is False
  433. def test_is_stuck_monologue(self, stuck_detector, event_stream):
  434. # Add events to the event stream
  435. message_action_1 = MessageAction(content='Hi there!')
  436. event_stream.add_event(message_action_1, EventSource.USER)
  437. message_action_1._source = EventSource.USER
  438. message_action_2 = MessageAction(content='Hi there!')
  439. event_stream.add_event(message_action_2, EventSource.AGENT)
  440. message_action_2._source = EventSource.AGENT
  441. message_action_3 = MessageAction(content='How are you?')
  442. event_stream.add_event(message_action_3, EventSource.USER)
  443. message_action_3._source = EventSource.USER
  444. cmd_kill_action = CmdRunAction(
  445. command='echo 42', thought="I'm not stuck, he's stuck"
  446. )
  447. event_stream.add_event(cmd_kill_action, EventSource.AGENT)
  448. message_action_4 = MessageAction(content="I'm doing well, thanks for asking.")
  449. event_stream.add_event(message_action_4, EventSource.AGENT)
  450. message_action_4._source = EventSource.AGENT
  451. message_action_5 = MessageAction(content="I'm doing well, thanks for asking.")
  452. event_stream.add_event(message_action_5, EventSource.AGENT)
  453. message_action_5._source = EventSource.AGENT
  454. message_action_6 = MessageAction(content="I'm doing well, thanks for asking.")
  455. event_stream.add_event(message_action_6, EventSource.AGENT)
  456. message_action_6._source = EventSource.AGENT
  457. assert stuck_detector.is_stuck()
  458. # Add an observation event between the repeated message actions
  459. cmd_output_observation = CmdOutputObservation(
  460. content='OK, I was stuck, but no more.',
  461. command_id=42,
  462. command='storybook',
  463. exit_code=0,
  464. )
  465. cmd_output_observation._cause = cmd_kill_action._id
  466. event_stream.add_event(cmd_output_observation, EventSource.USER)
  467. message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
  468. event_stream.add_event(message_action_7, EventSource.AGENT)
  469. message_action_7._source = EventSource.AGENT
  470. message_action_8 = MessageAction(content="I'm doing well, thanks for asking.")
  471. event_stream.add_event(message_action_8, EventSource.AGENT)
  472. message_action_8._source = EventSource.AGENT
  473. with patch('logging.Logger.warning'):
  474. assert not stuck_detector.is_stuck()
  475. class TestAgentController:
  476. @pytest.fixture
  477. def controller(self):
  478. controller = Mock(spec=AgentController)
  479. controller._is_stuck = AgentController._is_stuck.__get__(
  480. controller, AgentController
  481. )
  482. controller.delegate = None
  483. controller.state = Mock()
  484. controller.state.history = ShortTermHistory()
  485. return controller
  486. def test_is_stuck_delegate_stuck(self, controller: AgentController):
  487. controller.delegate = Mock()
  488. controller.delegate._is_stuck.return_value = True
  489. assert controller._is_stuck() is True