test_is_stuck.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. import logging
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from pytest import TempPathFactory
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import State
  7. from openhands.controller.stuck import StuckDetector
  8. from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
  9. from openhands.events.action.commands import IPythonRunCellAction
  10. from openhands.events.observation import (
  11. CmdOutputObservation,
  12. FileReadObservation,
  13. )
  14. from openhands.events.observation.commands import IPythonRunCellObservation
  15. from openhands.events.observation.empty import NullObservation
  16. from openhands.events.observation.error import ErrorObservation
  17. from openhands.events.stream import EventSource, EventStream
  18. from openhands.events.utils import get_pairs_from_events
  19. from openhands.memory.history import ShortTermHistory
  20. from openhands.storage import get_file_store
  def collect_events(stream):
      """Return every event yielded by ``stream.get_events()`` as a list.

      Args:
          stream: any object exposing a ``get_events()`` iterable (in this
              module, an ``EventStream``).

      Returns:
          A new list containing the events in iteration order.
      """
      # list() is the direct way to materialize an iterable; the previous
      # identity comprehension did the same work less clearly.
      return list(stream.get_events())
  # Configure the root logger at DEBUG level when this module is imported.
  logging.basicConfig(level=logging.DEBUG)

  # Trailer lines appended to the fake IPython observation contents used by the
  # syntax-error tests.
  jupyter_line_1 = '\n[Jupyter current working directory:'
  jupyter_line_2 = '\n[Jupyter Python interpreter:'

  # Code carried by every fake IPythonRunCellAction. It is spelled out piece by
  # piece so the exact text, including the embedded triple quote, is explicit.
  code_snippet = (
      '\n'
      'edit_file_by_replace(\n'
      "'book_store.py',\n"
      'to_replace="""def total(basket):\n'
      'if not basket:\n'
      'return 0\n'
  )
  @pytest.fixture
  def temp_dir(tmp_path_factory: TempPathFactory) -> str:
      """Return the path (as ``str``) of a new temp directory named after this module."""
      directory = tmp_path_factory.mktemp('test_is_stuck')
      return str(directory)
  @pytest.fixture
  def event_stream(temp_dir):
      """Yield an ``EventStream`` ('asdf') backed by a local file store in ``temp_dir``.

      On teardown (after each test) the stream is cleared.
      """
      stream = EventStream('asdf', get_file_store('local', temp_dir))
      yield stream
      stream.clear()
  class TestStuckDetector:
      """Scenario tests for ``StuckDetector.is_stuck()``.

      Each test appends actions and observations to the per-test ``event_stream``
      and checks whether the detector reports the agent as stuck. Observations
      are linked to their action via ``_cause = action._id`` after the action has
      been added. Event order is significant: several tests assert exact event
      and pair counts.
      """

      @pytest.fixture
      def stuck_detector(self, event_stream):
          """Return a StuckDetector whose state history reads from ``event_stream``."""
          state = State(inputs={}, max_iterations=50)
          state.history.set_event_stream(event_stream)
          return StuckDetector(state)

      def _impl_syntax_error_events(
          self,
          event_stream: EventStream,
          error_message: str,
          random_line: bool,
          incidents: int = 4,
      ):
          """Append ``incidents`` IPython action/observation pairs ending in a syntax error.

          Args:
              event_stream: stream to append the events to.
              error_message: error text placed after the caret line of each observation.
              random_line: if True, the reported line number ((i + 1) * 10) and the
                  number of trailing newlines (i + 1) differ per incident; if False,
                  every observation reports line '42' with no extra newlines.
              incidents: number of action/observation pairs to add.
          """
          for i in range(incidents):
              ipython_action = IPythonRunCellAction(code=code_snippet)
              event_stream.add_event(ipython_action, EventSource.AGENT)
              # Only vary the observation text between incidents when asked to.
              extra_number = (i + 1) * 10 if random_line else '42'
              extra_line = '\n' * (i + 1) if random_line else ''
              ipython_observation = IPythonRunCellObservation(
                  content=f' Cell In[1], line {extra_number}\n'
                  'to_replace="""def largest(min_factor, max_factor):\n ^\n'
                  f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
                  code=code_snippet,
              )
              ipython_observation._cause = ipython_action._id
              event_stream.add_event(ipython_observation, EventSource.USER)

      def _impl_unterminated_string_error_events(
          self, event_stream: EventStream, random_line: bool, incidents: int = 4
      ):
          """Append ``incidents`` IPython pairs ending in an unterminated-string SyntaxError.

          Args:
              event_stream: stream to append the events to.
              random_line: if True, the reported line number is (i + 1) * 10 per
                  incident; if False, every observation reports line '1'.
              incidents: number of action/observation pairs to add.
          """
          for i in range(incidents):
              ipython_action = IPythonRunCellAction(code=code_snippet)
              event_stream.add_event(ipython_action, EventSource.AGENT)
              line_number = (i + 1) * 10 if random_line else '1'
              ipython_observation = IPythonRunCellObservation(
                  content=f'print(" Cell In[1], line {line_number}\nhello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
                  + jupyter_line_1
                  + jupyter_line_2,
                  code=code_snippet,
              )
              ipython_observation._cause = ipython_action._id
              event_stream.add_event(ipython_observation, EventSource.USER)

      def test_history_too_short(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """A user message plus a single ``ls`` round-trip is not reported as stuck."""
          message_action = MessageAction(content='Hello', wait_for_response=False)
          message_action._source = EventSource.USER
          event_stream.add_event(message_action, EventSource.USER)
          # Link the observation only after its action has been added, so the
          # cause is the id assigned on add. This matches how every other test
          # links observations to actions.
          observation = NullObservation(content='')
          observation._cause = message_action._id
          event_stream.add_event(observation, EventSource.USER)

          cmd_action = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action, EventSource.AGENT)
          cmd_observation = CmdOutputObservation(
              command_id=1, command='ls', content='file1.txt\nfile2.txt'
          )
          cmd_observation._cause = cmd_action._id
          event_stream.add_event(cmd_observation, EventSource.USER)

          assert stuck_detector.is_stuck() is False

      def test_is_stuck_repeating_action_observation(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Four identical ``ls`` round-trips are stuck.

          Two or three identical round-trips are not stuck yet; the test checks
          that ``almost_stuck`` counts down 2 -> 1 -> 0 as the repeats accumulate.
          """
          message_action = MessageAction(content='Done', wait_for_response=False)
          message_action._source = EventSource.USER
          hello_action = MessageAction(content='Hello', wait_for_response=False)
          hello_observation = NullObservation('')
          # 2 events
          event_stream.add_event(hello_action, EventSource.USER)
          event_stream.add_event(hello_observation, EventSource.USER)

          cmd_action_1 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_1, EventSource.AGENT)
          cmd_observation_1 = CmdOutputObservation(
              content='', command='ls', command_id=cmd_action_1._id
          )
          cmd_observation_1._cause = cmd_action_1._id
          event_stream.add_event(cmd_observation_1, EventSource.USER)
          # 4 events

          cmd_action_2 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_2, EventSource.AGENT)
          cmd_observation_2 = CmdOutputObservation(
              content='', command='ls', command_id=cmd_action_2._id
          )
          cmd_observation_2._cause = cmd_action_2._id
          event_stream.add_event(cmd_observation_2, EventSource.USER)
          # 6 events

          # random user message just because we can
          message_null_observation = NullObservation(content='')
          event_stream.add_event(message_action, EventSource.USER)
          event_stream.add_event(message_null_observation, EventSource.USER)
          # 8 events

          assert stuck_detector.is_stuck() is False
          assert stuck_detector.state.almost_stuck == 2

          cmd_action_3 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_3, EventSource.AGENT)
          cmd_observation_3 = CmdOutputObservation(
              content='', command='ls', command_id=cmd_action_3._id
          )
          cmd_observation_3._cause = cmd_action_3._id
          event_stream.add_event(cmd_observation_3, EventSource.USER)
          # 10 events

          # The raw stream holds all 10 events; the history view and the
          # (action, observation) pairing see fewer.
          assert len(collect_events(event_stream)) == 10
          assert len(list(stuck_detector.state.history.get_events())) == 8
          assert (
              len(
                  get_pairs_from_events(
                      stuck_detector.state.history.get_events_as_list(
                          include_delegates=True
                      )
                  )
              )
              == 5
          )

          assert stuck_detector.is_stuck() is False
          assert stuck_detector.state.almost_stuck == 1

          cmd_action_4 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_4, EventSource.AGENT)
          cmd_observation_4 = CmdOutputObservation(
              content='', command='ls', command_id=cmd_action_4._id
          )
          cmd_observation_4._cause = cmd_action_4._id
          event_stream.add_event(cmd_observation_4, EventSource.USER)
          # 12 events

          assert len(collect_events(event_stream)) == 12
          assert len(list(stuck_detector.state.history.get_events())) == 10
          assert (
              len(
                  get_pairs_from_events(
                      stuck_detector.state.history.get_events_as_list(
                          include_delegates=True
                      )
                  )
              )
              == 6
          )

          with patch('logging.Logger.warning') as mock_warning:
              assert stuck_detector.is_stuck() is True
              assert stuck_detector.state.almost_stuck == 0
              mock_warning.assert_called_once_with('Action, Observation loop detected')

      def test_is_stuck_repeating_action_error(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Four identical failing commands are stuck even when the error texts differ."""
          # (action, error_observation), not necessarily the same error
          message_action = MessageAction(content='Done', wait_for_response=False)
          message_action._source = EventSource.USER

          hello_action = MessageAction(content='Hello', wait_for_response=False)
          hello_observation = NullObservation(content='')
          event_stream.add_event(hello_action, EventSource.USER)
          hello_observation._cause = hello_action._id
          event_stream.add_event(hello_observation, EventSource.USER)
          # 2 events

          cmd_action_1 = CmdRunAction(command='invalid_command')
          event_stream.add_event(cmd_action_1, EventSource.AGENT)
          error_observation_1 = ErrorObservation(content='Command not found')
          error_observation_1._cause = cmd_action_1._id
          event_stream.add_event(error_observation_1, EventSource.USER)
          # 4 events

          cmd_action_2 = CmdRunAction(command='invalid_command')
          event_stream.add_event(cmd_action_2, EventSource.AGENT)
          error_observation_2 = ErrorObservation(
              content='Command still not found or another error'
          )
          error_observation_2._cause = cmd_action_2._id
          event_stream.add_event(error_observation_2, EventSource.USER)
          # 6 events

          message_null_observation = NullObservation(content='')
          event_stream.add_event(message_action, EventSource.USER)
          event_stream.add_event(message_null_observation, EventSource.USER)
          # 8 events

          cmd_action_3 = CmdRunAction(command='invalid_command')
          event_stream.add_event(cmd_action_3, EventSource.AGENT)
          error_observation_3 = ErrorObservation(content='Different error')
          error_observation_3._cause = cmd_action_3._id
          event_stream.add_event(error_observation_3, EventSource.USER)
          # 10 events

          cmd_action_4 = CmdRunAction(command='invalid_command')
          event_stream.add_event(cmd_action_4, EventSource.AGENT)
          error_observation_4 = ErrorObservation(content='Command not found')
          error_observation_4._cause = cmd_action_4._id
          event_stream.add_event(error_observation_4, EventSource.USER)
          # 12 events

          with patch('logging.Logger.warning') as mock_warning:
              assert stuck_detector.is_stuck() is True
              mock_warning.assert_called_once_with(
                  'Action, ErrorObservation loop detected'
              )

      def test_is_stuck_invalid_syntax_error(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Four identical 'invalid syntax' errors at the same line are stuck."""
          self._impl_syntax_error_events(
              event_stream,
              error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
              random_line=False,
          )

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is True

      def test_is_not_stuck_invalid_syntax_error_random_lines(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """'invalid syntax' errors reported at changing lines are not stuck."""
          self._impl_syntax_error_events(
              event_stream,
              error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
              random_line=True,
          )

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is False

      def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Three 'invalid syntax' incidents are not stuck.

          NOTE(review): this uses ``random_line=True``, so the not-stuck result may
          come from the varying line numbers rather than the incident count. The
          unterminated-string counterpart uses ``random_line=False``. Confirm
          against StuckDetector whether this should match it.
          """
          self._impl_syntax_error_events(
              event_stream,
              error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
              random_line=True,
              incidents=3,
          )

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is False

      def test_is_stuck_incomplete_input_error(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Four identical 'incomplete input' errors at the same line are stuck."""
          self._impl_syntax_error_events(
              event_stream,
              error_message='SyntaxError: incomplete input',
              random_line=False,
          )

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is True

      def test_is_not_stuck_incomplete_input_error(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """'incomplete input' errors reported at changing lines are not stuck."""
          self._impl_syntax_error_events(
              event_stream,
              error_message='SyntaxError: incomplete input',
              random_line=True,
          )

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is False

      def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Unterminated-string errors at changing lines are not stuck."""
          self._impl_unterminated_string_error_events(event_stream, random_line=True)

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is False

      def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Three identical unterminated-string errors are not stuck."""
          self._impl_unterminated_string_error_events(
              event_stream, random_line=False, incidents=3
          )

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is False

      def test_is_stuck_ipython_unterminated_string_error(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Four identical unterminated-string errors at the same line are stuck."""
          self._impl_unterminated_string_error_events(event_stream, random_line=False)

          with patch('logging.Logger.warning'):
              assert stuck_detector.is_stuck() is True

      def test_is_not_stuck_ipython_syntax_error_not_at_end(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """Repeated unterminated-string errors are not stuck when the output varies.

          This guards against false positives: the "at line x" part changes
          between observations, and each observation has different trailing
          output after the error line.
          """
          ipython_action_1 = IPythonRunCellAction(code='print("hello')
          event_stream.add_event(ipython_action_1, EventSource.AGENT)
          ipython_observation_1 = IPythonRunCellObservation(
              content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output',
              code='print("hello',
          )
          ipython_observation_1._cause = ipython_action_1._id
          event_stream.add_event(ipython_observation_1, EventSource.USER)

          ipython_action_2 = IPythonRunCellAction(code='print("hello')
          event_stream.add_event(ipython_action_2, EventSource.AGENT)
          ipython_observation_2 = IPythonRunCellObservation(
              content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on',
              code='print("hello',
          )
          ipython_observation_2._cause = ipython_action_2._id
          event_stream.add_event(ipython_observation_2, EventSource.USER)

          ipython_action_3 = IPythonRunCellAction(code='print("hello')
          event_stream.add_event(ipython_action_3, EventSource.AGENT)
          ipython_observation_3 = IPythonRunCellObservation(
              content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough',
              code='print("hello',
          )
          ipython_observation_3._cause = ipython_action_3._id
          event_stream.add_event(ipython_observation_3, EventSource.USER)

          ipython_action_4 = IPythonRunCellAction(code='print("hello')
          event_stream.add_event(ipython_action_4, EventSource.AGENT)
          ipython_observation_4 = IPythonRunCellObservation(
              content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output',
              code='print("hello',
          )
          ipython_observation_4._cause = ipython_action_4._id
          event_stream.add_event(ipython_observation_4, EventSource.USER)

          with patch('logging.Logger.warning') as mock_warning:
              assert stuck_detector.is_stuck() is False
              mock_warning.assert_not_called()

      def test_is_stuck_repeating_action_observation_pattern(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """A repeating (ls, read file1.txt) cycle is stuck despite a user message in between."""
          message_action = MessageAction(content='Come on', wait_for_response=False)
          message_action._source = EventSource.USER
          event_stream.add_event(message_action, EventSource.USER)
          message_observation = NullObservation(content='')
          event_stream.add_event(message_observation, EventSource.USER)

          cmd_action_1 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_1, EventSource.AGENT)
          cmd_observation_1 = CmdOutputObservation(
              command_id=1, command='ls', content='file1.txt\nfile2.txt'
          )
          cmd_observation_1._cause = cmd_action_1._id
          event_stream.add_event(cmd_observation_1, EventSource.USER)

          read_action_1 = FileReadAction(path='file1.txt')
          event_stream.add_event(read_action_1, EventSource.AGENT)
          read_observation_1 = FileReadObservation(
              content='File content', path='file1.txt'
          )
          read_observation_1._cause = read_action_1._id
          event_stream.add_event(read_observation_1, EventSource.USER)

          cmd_action_2 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_2, EventSource.AGENT)
          cmd_observation_2 = CmdOutputObservation(
              command_id=2, command='ls', content='file1.txt\nfile2.txt'
          )
          cmd_observation_2._cause = cmd_action_2._id
          event_stream.add_event(cmd_observation_2, EventSource.USER)

          read_action_2 = FileReadAction(path='file1.txt')
          event_stream.add_event(read_action_2, EventSource.AGENT)
          read_observation_2 = FileReadObservation(
              content='File content', path='file1.txt'
          )
          read_observation_2._cause = read_action_2._id
          event_stream.add_event(read_observation_2, EventSource.USER)

          # one more message to break the pattern
          # (the same MessageAction object is added to the stream a second time)
          message_null_observation = NullObservation(content='')
          event_stream.add_event(message_action, EventSource.USER)
          event_stream.add_event(message_null_observation, EventSource.USER)

          cmd_action_3 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_3, EventSource.AGENT)
          cmd_observation_3 = CmdOutputObservation(
              command_id=3, command='ls', content='file1.txt\nfile2.txt'
          )
          cmd_observation_3._cause = cmd_action_3._id
          event_stream.add_event(cmd_observation_3, EventSource.USER)

          read_action_3 = FileReadAction(path='file1.txt')
          event_stream.add_event(read_action_3, EventSource.AGENT)
          read_observation_3 = FileReadObservation(
              content='File content', path='file1.txt'
          )
          read_observation_3._cause = read_action_3._id
          event_stream.add_event(read_observation_3, EventSource.USER)

          with patch('logging.Logger.warning') as mock_warning:
              assert stuck_detector.is_stuck() is True
              mock_warning.assert_called_once_with('Action, Observation pattern detected')

      def test_is_stuck_not_stuck(
          self, stuck_detector: StuckDetector, event_stream: EventStream
      ):
          """A varied sequence (ls, read file1, pwd, read file2, pwd, read file2) is not stuck."""
          message_action = MessageAction(content='Done', wait_for_response=False)
          message_action._source = EventSource.USER

          hello_action = MessageAction(content='Hello', wait_for_response=False)
          event_stream.add_event(hello_action, EventSource.USER)
          hello_observation = NullObservation(content='')
          hello_observation._cause = hello_action._id
          event_stream.add_event(hello_observation, EventSource.USER)

          cmd_action_1 = CmdRunAction(command='ls')
          event_stream.add_event(cmd_action_1, EventSource.AGENT)
          cmd_observation_1 = CmdOutputObservation(
              command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
          )
          cmd_observation_1._cause = cmd_action_1._id
          event_stream.add_event(cmd_observation_1, EventSource.USER)

          read_action_1 = FileReadAction(path='file1.txt')
          event_stream.add_event(read_action_1, EventSource.AGENT)
          read_observation_1 = FileReadObservation(
              content='File content', path='file1.txt'
          )
          read_observation_1._cause = read_action_1._id
          event_stream.add_event(read_observation_1, EventSource.USER)

          cmd_action_2 = CmdRunAction(command='pwd')
          event_stream.add_event(cmd_action_2, EventSource.AGENT)
          cmd_observation_2 = CmdOutputObservation(
              command_id=2, command='pwd', content='/home/user'
          )
          cmd_observation_2._cause = cmd_action_2._id
          event_stream.add_event(cmd_observation_2, EventSource.USER)

          read_action_2 = FileReadAction(path='file2.txt')
          event_stream.add_event(read_action_2, EventSource.AGENT)
          read_observation_2 = FileReadObservation(
              content='Another file content', path='file2.txt'
          )
          read_observation_2._cause = read_action_2._id
          event_stream.add_event(read_observation_2, EventSource.USER)

          message_null_observation = NullObservation(content='')
          event_stream.add_event(message_action, EventSource.USER)
          event_stream.add_event(message_null_observation, EventSource.USER)

          cmd_action_3 = CmdRunAction(command='pwd')
          event_stream.add_event(cmd_action_3, EventSource.AGENT)
          cmd_observation_3 = CmdOutputObservation(
              command_id=cmd_action_3.id, command='pwd', content='/home/user'
          )
          cmd_observation_3._cause = cmd_action_3._id
          event_stream.add_event(cmd_observation_3, EventSource.USER)

          read_action_3 = FileReadAction(path='file2.txt')
          event_stream.add_event(read_action_3, EventSource.AGENT)
          read_observation_3 = FileReadObservation(
              content='Another file content', path='file2.txt'
          )
          read_observation_3._cause = read_action_3._id
          event_stream.add_event(read_observation_3, EventSource.USER)

          assert stuck_detector.is_stuck() is False

      def test_is_stuck_monologue(self, stuck_detector, event_stream):
          """Repeated identical agent messages are stuck until an observation interrupts them."""
          # Add events to the event stream
          message_action_1 = MessageAction(content='Hi there!')
          event_stream.add_event(message_action_1, EventSource.USER)
          message_action_1._source = EventSource.USER

          message_action_2 = MessageAction(content='Hi there!')
          event_stream.add_event(message_action_2, EventSource.AGENT)
          message_action_2._source = EventSource.AGENT

          message_action_3 = MessageAction(content='How are you?')
          event_stream.add_event(message_action_3, EventSource.USER)
          message_action_3._source = EventSource.USER

          cmd_kill_action = CmdRunAction(
              command='echo 42', thought="I'm not stuck, he's stuck"
          )
          event_stream.add_event(cmd_kill_action, EventSource.AGENT)

          message_action_4 = MessageAction(content="I'm doing well, thanks for asking.")
          event_stream.add_event(message_action_4, EventSource.AGENT)
          message_action_4._source = EventSource.AGENT

          message_action_5 = MessageAction(content="I'm doing well, thanks for asking.")
          event_stream.add_event(message_action_5, EventSource.AGENT)
          message_action_5._source = EventSource.AGENT

          message_action_6 = MessageAction(content="I'm doing well, thanks for asking.")
          event_stream.add_event(message_action_6, EventSource.AGENT)
          message_action_6._source = EventSource.AGENT

          # Three identical agent messages in a row: the monologue is detected.
          assert stuck_detector.is_stuck()

          # Add an observation event between the repeated message actions
          cmd_output_observation = CmdOutputObservation(
              content='OK, I was stuck, but no more.',
              command_id=42,
              command='storybook',
              exit_code=0,
          )
          cmd_output_observation._cause = cmd_kill_action._id
          event_stream.add_event(cmd_output_observation, EventSource.USER)

          message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
          event_stream.add_event(message_action_7, EventSource.AGENT)
          message_action_7._source = EventSource.AGENT

          message_action_8 = MessageAction(content="I'm doing well, thanks for asking.")
          event_stream.add_event(message_action_8, EventSource.AGENT)
          message_action_8._source = EventSource.AGENT

          with patch('logging.Logger.warning'):
              assert not stuck_detector.is_stuck()
  class TestAgentController:
      """Tests for ``AgentController._is_stuck`` run against a spec'd mock controller."""

      @pytest.fixture
      def controller(self):
          """Return a ``Mock(spec=AgentController)`` wired to the real ``_is_stuck``.

          The unbound ``AgentController._is_stuck`` is bound to the mock, so the
          genuine method body runs against mocked attributes. The mock starts
          with no delegate and a state whose history is a fresh
          ``ShortTermHistory``.
          """
          mock_controller = Mock(spec=AgentController)
          real_is_stuck = AgentController._is_stuck
          mock_controller._is_stuck = real_is_stuck.__get__(
              mock_controller, AgentController
          )
          mock_controller.delegate = None
          mock_controller.state = Mock()
          mock_controller.state.history = ShortTermHistory()
          return mock_controller

      def test_is_stuck_delegate_stuck(self, controller: AgentController):
          """The controller reports stuck when its delegate's ``_is_stuck`` returns True."""
          stuck_delegate = Mock()
          stuck_delegate._is_stuck.return_value = True
          controller.delegate = stuck_delegate

          assert controller._is_stuck() is True