test_is_stuck.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. import logging
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from pytest import TempPathFactory
  5. from openhands.controller.agent_controller import AgentController
  6. from openhands.controller.state.state import State
  7. from openhands.controller.stuck import StuckDetector
  8. from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
  9. from openhands.events.action.commands import IPythonRunCellAction
  10. from openhands.events.observation import (
  11. CmdOutputObservation,
  12. FileReadObservation,
  13. )
  14. from openhands.events.observation.commands import IPythonRunCellObservation
  15. from openhands.events.observation.empty import NullObservation
  16. from openhands.events.observation.error import ErrorObservation
  17. from openhands.events.stream import EventSource, EventStream
  18. from openhands.memory.history import ShortTermHistory
  19. from openhands.storage import get_file_store
  20. def collect_events(stream):
  21. return [event for event in stream.get_events()]
  22. logging.basicConfig(level=logging.DEBUG)
  23. @pytest.fixture
  24. def temp_dir(tmp_path_factory: TempPathFactory) -> str:
  25. return str(tmp_path_factory.mktemp('test_is_stuck'))
  26. @pytest.fixture
  27. def event_stream(temp_dir):
  28. file_store = get_file_store('local', temp_dir)
  29. event_stream = EventStream('asdf', file_store)
  30. yield event_stream
  31. # clear after each test
  32. event_stream.clear()
class TestStuckDetector:
    """Tests for ``StuckDetector.is_stuck`` over events appended to an EventStream.

    Each test appends actions and observations to the ``event_stream`` fixture in
    a precise order and then asks the detector whether the agent is stuck. The
    order of ``add_event`` calls matters: some tests assert exact event counts.
    """

    @pytest.fixture
    def stuck_detector(self, event_stream: EventStream) -> StuckDetector:
        """Build a StuckDetector whose state history is wired to ``event_stream``."""
        state = State(inputs={}, max_iterations=50)
        # Attach the history to the test's event stream so the events each test
        # appends are what the detector inspects.
        state.history.set_event_stream(event_stream)
        return StuckDetector(state)

    def test_history_too_short(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """A history of one message pair and one command pair is not reported as stuck."""
        message_action = MessageAction(content='Hello', wait_for_response=False)
        message_action._source = EventSource.USER
        observation = NullObservation(content='')
        # NOTE(review): read before message_action is added to the stream, unlike
        # the other tests which set _cause after add_event — confirm intended.
        observation._cause = message_action.id
        event_stream.add_event(message_action, EventSource.USER)
        event_stream.add_event(observation, EventSource.USER)

        cmd_action = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action, EventSource.AGENT)
        cmd_observation = CmdOutputObservation(
            command_id=1, command='ls', content='file1.txt\nfile2.txt'
        )
        cmd_observation._cause = cmd_action._id
        event_stream.add_event(cmd_observation, EventSource.USER)

        assert stuck_detector.is_stuck() is False

    def test_is_stuck_repeating_action_observation(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """Four identical ``ls`` action/observation pairs are flagged as a loop.

        Also checks ``state.almost_stuck`` as the repetition builds up, and the
        difference between raw stream events and history events.
        """
        message_action = MessageAction(content='Done', wait_for_response=False)
        message_action._source = EventSource.USER

        hello_action = MessageAction(content='Hello', wait_for_response=False)
        hello_observation = NullObservation('')
        event_stream.add_event(hello_action, EventSource.USER)
        event_stream.add_event(hello_observation, EventSource.USER)
        # 2 events

        cmd_action_1 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_1, EventSource.AGENT)
        cmd_observation_1 = CmdOutputObservation(
            content='', command='ls', command_id=cmd_action_1._id
        )
        cmd_observation_1._cause = cmd_action_1._id
        event_stream.add_event(cmd_observation_1, EventSource.USER)
        # 4 events

        cmd_action_2 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_2, EventSource.AGENT)
        cmd_observation_2 = CmdOutputObservation(
            content='', command='ls', command_id=cmd_action_2._id
        )
        cmd_observation_2._cause = cmd_action_2._id
        event_stream.add_event(cmd_observation_2, EventSource.USER)
        # 6 events

        # An interleaved user message. The loop is still detected below, so it
        # does not reset the repetition tracking.
        message_null_observation = NullObservation(content='')
        event_stream.add_event(message_action, EventSource.USER)
        event_stream.add_event(message_null_observation, EventSource.USER)
        # 8 events

        # Two repeated pairs: not stuck yet. almost_stuck appears to count down
        # toward detection (2 here, 1 after the third pair, 0 once stuck).
        assert stuck_detector.is_stuck() is False
        assert stuck_detector.state.almost_stuck == 2

        cmd_action_3 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_3, EventSource.AGENT)
        cmd_observation_3 = CmdOutputObservation(
            content='', command='ls', command_id=cmd_action_3._id
        )
        cmd_observation_3._cause = cmd_action_3._id
        event_stream.add_event(cmd_observation_3, EventSource.USER)
        # 10 events

        # The raw stream has every event; the history exposes two fewer
        # (presumably the NullObservations are filtered out — confirm in
        # ShortTermHistory).
        assert len(collect_events(event_stream)) == 10
        assert len(list(stuck_detector.state.history.get_events())) == 8
        assert len(stuck_detector.state.history.get_pairs()) == 5
        assert stuck_detector.is_stuck() is False
        assert stuck_detector.state.almost_stuck == 1

        cmd_action_4 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_4, EventSource.AGENT)
        cmd_observation_4 = CmdOutputObservation(
            content='', command='ls', command_id=cmd_action_4._id
        )
        cmd_observation_4._cause = cmd_action_4._id
        event_stream.add_event(cmd_observation_4, EventSource.USER)
        # 12 events

        assert len(collect_events(event_stream)) == 12
        assert len(list(stuck_detector.state.history.get_events())) == 10
        assert len(stuck_detector.state.history.get_pairs()) == 6

        # Patch Logger.warning on the class so the detector's warning is captured.
        with patch('logging.Logger.warning') as mock_warning:
            assert stuck_detector.is_stuck() is True
            assert stuck_detector.state.almost_stuck == 0
            mock_warning.assert_called_once_with('Action, Observation loop detected')

    def test_is_stuck_repeating_action_error(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """Four identical failing commands answered by ErrorObservations form a loop.

        The error messages differ between repetitions, and a user message sits
        between the second and third pair; the loop is still reported.
        """
        # (action, error_observation), not necessarily the same error
        message_action = MessageAction(content='Done', wait_for_response=False)
        message_action._source = EventSource.USER

        hello_action = MessageAction(content='Hello', wait_for_response=False)
        hello_observation = NullObservation(content='')
        event_stream.add_event(hello_action, EventSource.USER)
        hello_observation._cause = hello_action._id
        event_stream.add_event(hello_observation, EventSource.USER)
        # 2 events

        cmd_action_1 = CmdRunAction(command='invalid_command')
        event_stream.add_event(cmd_action_1, EventSource.AGENT)
        error_observation_1 = ErrorObservation(content='Command not found')
        error_observation_1._cause = cmd_action_1._id
        event_stream.add_event(error_observation_1, EventSource.USER)
        # 4 events

        cmd_action_2 = CmdRunAction(command='invalid_command')
        event_stream.add_event(cmd_action_2, EventSource.AGENT)
        error_observation_2 = ErrorObservation(
            content='Command still not found or another error'
        )
        error_observation_2._cause = cmd_action_2._id
        event_stream.add_event(error_observation_2, EventSource.USER)
        # 6 events

        message_null_observation = NullObservation(content='')
        event_stream.add_event(message_action, EventSource.USER)
        event_stream.add_event(message_null_observation, EventSource.USER)
        # 8 events

        cmd_action_3 = CmdRunAction(command='invalid_command')
        event_stream.add_event(cmd_action_3, EventSource.AGENT)
        error_observation_3 = ErrorObservation(content='Different error')
        error_observation_3._cause = cmd_action_3._id
        event_stream.add_event(error_observation_3, EventSource.USER)
        # 10 events

        cmd_action_4 = CmdRunAction(command='invalid_command')
        event_stream.add_event(cmd_action_4, EventSource.AGENT)
        error_observation_4 = ErrorObservation(content='Command not found')
        error_observation_4._cause = cmd_action_4._id
        event_stream.add_event(error_observation_4, EventSource.USER)
        # 12 events

        with patch('logging.Logger.warning') as mock_warning:
            assert stuck_detector.is_stuck() is True
            mock_warning.assert_called_once_with(
                'Action, ErrorObservation loop detected'
            )

    def test_is_stuck_ipython_syntax_error(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """Four identical cells that each END with the same SyntaxError are flagged.

        The reported line number in the error differs between repetitions.
        """
        ipython_action_1 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_1, EventSource.AGENT)
        ipython_observation_1 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
            code='print("hello',
        )
        ipython_observation_1._cause = ipython_action_1._id
        event_stream.add_event(ipython_observation_1, EventSource.USER)

        ipython_action_2 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_2, EventSource.AGENT)
        ipython_observation_2 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
            code='print("hello',
        )
        ipython_observation_2._cause = ipython_action_2._id
        event_stream.add_event(ipython_observation_2, EventSource.USER)

        ipython_action_3 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_3, EventSource.AGENT)
        ipython_observation_3 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)',
            code='print("hello',
        )
        ipython_observation_3._cause = ipython_action_3._id
        event_stream.add_event(ipython_observation_3, EventSource.USER)

        ipython_action_4 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_4, EventSource.AGENT)
        ipython_observation_4 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)',
            code='print("hello',
        )
        ipython_observation_4._cause = ipython_action_4._id
        event_stream.add_event(ipython_observation_4, EventSource.USER)

        # Sanity-check the fixtures themselves: the SyntaxError marker must be in
        # the last 100 characters, with fewer than 10 characters after it.
        # NOTE(review): presumably mirrors the detector's own check — confirm.
        last_observations = [
            ipython_observation_1,
            ipython_observation_2,
            ipython_observation_3,
            ipython_observation_4,
        ]
        for observation in last_observations:
            has_string = (
                observation.content[-100:].find(
                    'SyntaxError: unterminated string literal (detected at line'
                )
                != -1
            )
            assert has_string

            string_is_last = (
                len(
                    observation.content.split(
                        'SyntaxError: unterminated string literal (detected at line'
                    )[-1]
                )
                < 10
            )
            assert string_is_last

        with patch('logging.Logger.warning') as mock_warning:
            assert stuck_detector.is_stuck() is True
            mock_warning.assert_called_once_with(
                'Action, IPythonRunCellObservation loop detected'
            )

    def test_is_stuck_ipython_syntax_error_not_at_end(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """Repeated SyntaxErrors followed by further output are NOT flagged."""
        ipython_action_1 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_1, EventSource.AGENT)
        ipython_observation_1 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output',
            code='print("hello',
        )
        ipython_observation_1._cause = ipython_action_1._id
        event_stream.add_event(ipython_observation_1, EventSource.USER)

        ipython_action_2 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_2, EventSource.AGENT)
        ipython_observation_2 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on',
            code='print("hello',
        )
        ipython_observation_2._cause = ipython_action_2._id
        event_stream.add_event(ipython_observation_2, EventSource.USER)

        ipython_action_3 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_3, EventSource.AGENT)
        ipython_observation_3 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough',
            code='print("hello',
        )
        ipython_observation_3._cause = ipython_action_3._id
        event_stream.add_event(ipython_observation_3, EventSource.USER)

        ipython_action_4 = IPythonRunCellAction(code='print("hello')
        event_stream.add_event(ipython_action_4, EventSource.AGENT)
        ipython_observation_4 = IPythonRunCellObservation(
            content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output',
            code='print("hello',
        )
        ipython_observation_4._cause = ipython_action_4._id
        event_stream.add_event(ipython_observation_4, EventSource.USER)

        with patch('logging.Logger.warning') as mock_warning:
            assert stuck_detector.is_stuck() is False
            mock_warning.assert_not_called()

    def test_is_stuck_repeating_action_observation_pattern(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """The alternating (``ls``, read ``file1.txt``) sequence repeated three times is flagged."""
        message_action = MessageAction(content='Come on', wait_for_response=False)
        message_action._source = EventSource.USER
        event_stream.add_event(message_action, EventSource.USER)
        message_observation = NullObservation(content='')
        event_stream.add_event(message_observation, EventSource.USER)

        cmd_action_1 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_1, EventSource.AGENT)
        cmd_observation_1 = CmdOutputObservation(
            command_id=1, command='ls', content='file1.txt\nfile2.txt'
        )
        cmd_observation_1._cause = cmd_action_1._id
        event_stream.add_event(cmd_observation_1, EventSource.USER)

        read_action_1 = FileReadAction(path='file1.txt')
        event_stream.add_event(read_action_1, EventSource.AGENT)
        read_observation_1 = FileReadObservation(
            content='File content', path='file1.txt'
        )
        read_observation_1._cause = read_action_1._id
        event_stream.add_event(read_observation_1, EventSource.USER)

        cmd_action_2 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_2, EventSource.AGENT)
        cmd_observation_2 = CmdOutputObservation(
            command_id=2, command='ls', content='file1.txt\nfile2.txt'
        )
        cmd_observation_2._cause = cmd_action_2._id
        event_stream.add_event(cmd_observation_2, EventSource.USER)

        read_action_2 = FileReadAction(path='file1.txt')
        event_stream.add_event(read_action_2, EventSource.AGENT)
        read_observation_2 = FileReadObservation(
            content='File content', path='file1.txt'
        )
        read_observation_2._cause = read_action_2._id
        event_stream.add_event(read_observation_2, EventSource.USER)

        # An interleaved user message (the same message_action object is added a
        # second time). It does NOT break the pattern: stuck is still asserted below.
        message_null_observation = NullObservation(content='')
        event_stream.add_event(message_action, EventSource.USER)
        event_stream.add_event(message_null_observation, EventSource.USER)

        cmd_action_3 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_3, EventSource.AGENT)
        cmd_observation_3 = CmdOutputObservation(
            command_id=3, command='ls', content='file1.txt\nfile2.txt'
        )
        cmd_observation_3._cause = cmd_action_3._id
        event_stream.add_event(cmd_observation_3, EventSource.USER)

        read_action_3 = FileReadAction(path='file1.txt')
        event_stream.add_event(read_action_3, EventSource.AGENT)
        read_observation_3 = FileReadObservation(
            content='File content', path='file1.txt'
        )
        read_observation_3._cause = read_action_3._id
        event_stream.add_event(read_observation_3, EventSource.USER)

        with patch('logging.Logger.warning') as mock_warning:
            assert stuck_detector.is_stuck() is True
            mock_warning.assert_called_once_with('Action, Observation pattern detected')

    def test_is_stuck_not_stuck(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """Varied commands with only one repetition of (``pwd``, read ``file2.txt``) are not flagged."""
        message_action = MessageAction(content='Done', wait_for_response=False)
        message_action._source = EventSource.USER

        hello_action = MessageAction(content='Hello', wait_for_response=False)
        event_stream.add_event(hello_action, EventSource.USER)
        hello_observation = NullObservation(content='')
        hello_observation._cause = hello_action._id
        event_stream.add_event(hello_observation, EventSource.USER)

        cmd_action_1 = CmdRunAction(command='ls')
        event_stream.add_event(cmd_action_1, EventSource.AGENT)
        cmd_observation_1 = CmdOutputObservation(
            command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
        )
        cmd_observation_1._cause = cmd_action_1._id
        event_stream.add_event(cmd_observation_1, EventSource.USER)

        read_action_1 = FileReadAction(path='file1.txt')
        event_stream.add_event(read_action_1, EventSource.AGENT)
        read_observation_1 = FileReadObservation(
            content='File content', path='file1.txt'
        )
        read_observation_1._cause = read_action_1._id
        event_stream.add_event(read_observation_1, EventSource.USER)

        cmd_action_2 = CmdRunAction(command='pwd')
        event_stream.add_event(cmd_action_2, EventSource.AGENT)
        cmd_observation_2 = CmdOutputObservation(
            command_id=2, command='pwd', content='/home/user'
        )
        cmd_observation_2._cause = cmd_action_2._id
        event_stream.add_event(cmd_observation_2, EventSource.USER)

        read_action_2 = FileReadAction(path='file2.txt')
        event_stream.add_event(read_action_2, EventSource.AGENT)
        read_observation_2 = FileReadObservation(
            content='Another file content', path='file2.txt'
        )
        read_observation_2._cause = read_action_2._id
        event_stream.add_event(read_observation_2, EventSource.USER)

        message_null_observation = NullObservation(content='')
        event_stream.add_event(message_action, EventSource.USER)
        event_stream.add_event(message_null_observation, EventSource.USER)

        cmd_action_3 = CmdRunAction(command='pwd')
        event_stream.add_event(cmd_action_3, EventSource.AGENT)
        cmd_observation_3 = CmdOutputObservation(
            command_id=cmd_action_3.id, command='pwd', content='/home/user'
        )
        cmd_observation_3._cause = cmd_action_3._id
        event_stream.add_event(cmd_observation_3, EventSource.USER)

        read_action_3 = FileReadAction(path='file2.txt')
        event_stream.add_event(read_action_3, EventSource.AGENT)
        read_observation_3 = FileReadObservation(
            content='Another file content', path='file2.txt'
        )
        read_observation_3._cause = read_action_3._id
        event_stream.add_event(read_observation_3, EventSource.USER)

        assert stuck_detector.is_stuck() is False

    def test_is_stuck_monologue(
        self, stuck_detector: StuckDetector, event_stream: EventStream
    ):
        """Three identical consecutive agent messages count as a monologue.

        Once an observation is added among the repeated messages, the detector
        no longer reports stuck.
        """
        # Add events to the event stream
        message_action_1 = MessageAction(content='Hi there!')
        event_stream.add_event(message_action_1, EventSource.USER)
        message_action_1._source = EventSource.USER

        message_action_2 = MessageAction(content='Hi there!')
        event_stream.add_event(message_action_2, EventSource.AGENT)
        message_action_2._source = EventSource.AGENT

        message_action_3 = MessageAction(content='How are you?')
        event_stream.add_event(message_action_3, EventSource.USER)
        message_action_3._source = EventSource.USER

        cmd_kill_action = CmdRunAction(
            command='echo 42', thought="I'm not stuck, he's stuck"
        )
        event_stream.add_event(cmd_kill_action, EventSource.AGENT)

        # Three identical agent messages in a row.
        message_action_4 = MessageAction(content="I'm doing well, thanks for asking.")
        event_stream.add_event(message_action_4, EventSource.AGENT)
        message_action_4._source = EventSource.AGENT

        message_action_5 = MessageAction(content="I'm doing well, thanks for asking.")
        event_stream.add_event(message_action_5, EventSource.AGENT)
        message_action_5._source = EventSource.AGENT

        message_action_6 = MessageAction(content="I'm doing well, thanks for asking.")
        event_stream.add_event(message_action_6, EventSource.AGENT)
        message_action_6._source = EventSource.AGENT

        assert stuck_detector.is_stuck()

        # Add an observation event between the repeated message actions; this
        # breaks the monologue.
        cmd_output_observation = CmdOutputObservation(
            content='OK, I was stuck, but no more.',
            command_id=42,
            command='storybook',
            exit_code=0,
        )
        cmd_output_observation._cause = cmd_kill_action._id
        event_stream.add_event(cmd_output_observation, EventSource.USER)

        message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
        event_stream.add_event(message_action_7, EventSource.AGENT)
        message_action_7._source = EventSource.AGENT

        message_action_8 = MessageAction(content="I'm doing well, thanks for asking.")
        event_stream.add_event(message_action_8, EventSource.AGENT)
        message_action_8._source = EventSource.AGENT

        assert not stuck_detector.is_stuck()
  423. class TestAgentController:
  424. @pytest.fixture
  425. def controller(self):
  426. controller = Mock(spec=AgentController)
  427. controller._is_stuck = AgentController._is_stuck.__get__(
  428. controller, AgentController
  429. )
  430. controller.delegate = None
  431. controller.state = Mock()
  432. controller.state.history = ShortTermHistory()
  433. return controller
  434. def test_is_stuck_delegate_stuck(self, controller: AgentController):
  435. controller.delegate = Mock()
  436. controller.delegate._is_stuck.return_value = True
  437. assert controller._is_stuck() is True