test_is_stuck.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. import logging
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from opendevin.controller.agent_controller import AgentController
  5. from opendevin.controller.state.state import State
  6. from opendevin.controller.stuck import StuckDetector
  7. from opendevin.events.action import CmdRunAction, FileReadAction, MessageAction
  8. from opendevin.events.action.commands import IPythonRunCellAction
  9. from opendevin.events.observation import (
  10. CmdOutputObservation,
  11. FileReadObservation,
  12. )
  13. from opendevin.events.observation.commands import IPythonRunCellObservation
  14. from opendevin.events.observation.empty import NullObservation
  15. from opendevin.events.observation.error import ErrorObservation
  16. from opendevin.events.stream import EventSource, EventStream
  17. from opendevin.memory.history import ShortTermHistory
  18. def collect_events(stream):
  19. return [event for event in stream.get_events()]
  20. logging.basicConfig(level=logging.DEBUG)
  21. @pytest.fixture
  22. def event_stream():
  23. event_stream = EventStream('asdf')
  24. yield event_stream
  25. # clear after each test
  26. event_stream.clear()
  27. class TestStuckDetector:
  28. @pytest.fixture
  29. def stuck_detector(self, event_stream):
  30. state = State(inputs={}, max_iterations=50)
  31. # state.history = ShortTermHistory()
  32. state.history.set_event_stream(event_stream)
  33. return StuckDetector(state)
  34. def test_history_too_short(
  35. self, stuck_detector: StuckDetector, event_stream: EventStream
  36. ):
  37. message_action = MessageAction(content='Hello', wait_for_response=False)
  38. message_action._source = EventSource.USER
  39. observation = NullObservation(content='')
  40. observation._cause = message_action.id
  41. event_stream.add_event(message_action, EventSource.USER)
  42. event_stream.add_event(observation, EventSource.USER)
  43. cmd_action = CmdRunAction(command='ls')
  44. event_stream.add_event(cmd_action, EventSource.AGENT)
  45. cmd_observation = CmdOutputObservation(
  46. command_id=1, command='ls', content='file1.txt\nfile2.txt'
  47. )
  48. cmd_observation._cause = cmd_action._id
  49. event_stream.add_event(cmd_observation, EventSource.USER)
  50. # stuck_detector.state.history.set_event_stream(event_stream)
  51. assert stuck_detector.is_stuck() is False
  52. def test_is_stuck_repeating_action_observation(
  53. self, stuck_detector: StuckDetector, event_stream: EventStream
  54. ):
  55. message_action = MessageAction(content='Done', wait_for_response=False)
  56. message_action._source = EventSource.USER
  57. hello_action = MessageAction(content='Hello', wait_for_response=False)
  58. hello_observation = NullObservation('')
  59. # 2 events
  60. event_stream.add_event(hello_action, EventSource.USER)
  61. event_stream.add_event(hello_observation, EventSource.USER)
  62. cmd_action_1 = CmdRunAction(command='ls')
  63. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  64. cmd_observation_1 = CmdOutputObservation(
  65. content='', command='ls', command_id=cmd_action_1._id
  66. )
  67. cmd_observation_1._cause = cmd_action_1._id
  68. event_stream.add_event(cmd_observation_1, EventSource.USER)
  69. # 4 events
  70. cmd_action_2 = CmdRunAction(command='ls')
  71. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  72. cmd_observation_2 = CmdOutputObservation(
  73. content='', command='ls', command_id=cmd_action_2._id
  74. )
  75. cmd_observation_2._cause = cmd_action_2._id
  76. event_stream.add_event(cmd_observation_2, EventSource.USER)
  77. # 6 events
  78. # random user message just because we can
  79. message_null_observation = NullObservation(content='')
  80. event_stream.add_event(message_action, EventSource.USER)
  81. event_stream.add_event(message_null_observation, EventSource.USER)
  82. # 8 events
  83. assert stuck_detector.is_stuck() is False
  84. assert stuck_detector.state.almost_stuck == 2
  85. cmd_action_3 = CmdRunAction(command='ls')
  86. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  87. cmd_observation_3 = CmdOutputObservation(
  88. content='', command='ls', command_id=cmd_action_3._id
  89. )
  90. cmd_observation_3._cause = cmd_action_3._id
  91. event_stream.add_event(cmd_observation_3, EventSource.USER)
  92. # 10 events
  93. assert len(collect_events(event_stream)) == 10
  94. assert len(list(stuck_detector.state.history.get_events())) == 8
  95. assert len(stuck_detector.state.history.get_pairs()) == 5
  96. assert stuck_detector.is_stuck() is False
  97. assert stuck_detector.state.almost_stuck == 1
  98. cmd_action_4 = CmdRunAction(command='ls')
  99. event_stream.add_event(cmd_action_4, EventSource.AGENT)
  100. cmd_observation_4 = CmdOutputObservation(
  101. content='', command='ls', command_id=cmd_action_4._id
  102. )
  103. cmd_observation_4._cause = cmd_action_4._id
  104. event_stream.add_event(cmd_observation_4, EventSource.USER)
  105. # 12 events
  106. assert len(collect_events(event_stream)) == 12
  107. assert len(list(stuck_detector.state.history.get_events())) == 10
  108. assert len(stuck_detector.state.history.get_pairs()) == 6
  109. with patch('logging.Logger.warning') as mock_warning:
  110. assert stuck_detector.is_stuck() is True
  111. assert stuck_detector.state.almost_stuck == 0
  112. mock_warning.assert_called_once_with('Action, Observation loop detected')
  113. def test_is_stuck_repeating_action_error(
  114. self, stuck_detector: StuckDetector, event_stream: EventStream
  115. ):
  116. # (action, error_observation), not necessarily the same error
  117. message_action = MessageAction(content='Done', wait_for_response=False)
  118. message_action._source = EventSource.USER
  119. hello_action = MessageAction(content='Hello', wait_for_response=False)
  120. hello_observation = NullObservation(content='')
  121. event_stream.add_event(hello_action, EventSource.USER)
  122. hello_observation._cause = hello_action._id
  123. event_stream.add_event(hello_observation, EventSource.USER)
  124. # 2 events
  125. cmd_action_1 = CmdRunAction(command='invalid_command')
  126. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  127. error_observation_1 = ErrorObservation(content='Command not found')
  128. error_observation_1._cause = cmd_action_1._id
  129. event_stream.add_event(error_observation_1, EventSource.USER)
  130. # 4 events
  131. cmd_action_2 = CmdRunAction(command='invalid_command')
  132. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  133. error_observation_2 = ErrorObservation(
  134. content='Command still not found or another error'
  135. )
  136. error_observation_2._cause = cmd_action_2._id
  137. event_stream.add_event(error_observation_2, EventSource.USER)
  138. # 6 events
  139. message_null_observation = NullObservation(content='')
  140. event_stream.add_event(message_action, EventSource.USER)
  141. event_stream.add_event(message_null_observation, EventSource.USER)
  142. # 8 events
  143. cmd_action_3 = CmdRunAction(command='invalid_command')
  144. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  145. error_observation_3 = ErrorObservation(content='Different error')
  146. error_observation_3._cause = cmd_action_3._id
  147. event_stream.add_event(error_observation_3, EventSource.USER)
  148. # 10 events
  149. cmd_action_4 = CmdRunAction(command='invalid_command')
  150. event_stream.add_event(cmd_action_4, EventSource.AGENT)
  151. error_observation_4 = ErrorObservation(content='Command not found')
  152. error_observation_4._cause = cmd_action_4._id
  153. event_stream.add_event(error_observation_4, EventSource.USER)
  154. # 12 events
  155. with patch('logging.Logger.warning') as mock_warning:
  156. assert stuck_detector.is_stuck() is True
  157. mock_warning.assert_called_once_with(
  158. 'Action, ErrorObservation loop detected'
  159. )
  160. def test_is_stuck_ipython_syntax_error(
  161. self, stuck_detector: StuckDetector, event_stream: EventStream
  162. ):
  163. ipython_action_1 = IPythonRunCellAction(code='print("hello')
  164. event_stream.add_event(ipython_action_1, EventSource.AGENT)
  165. ipython_observation_1 = IPythonRunCellObservation(
  166. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
  167. code='print("hello',
  168. )
  169. ipython_observation_1._cause = ipython_action_1._id
  170. event_stream.add_event(ipython_observation_1, EventSource.USER)
  171. ipython_action_2 = IPythonRunCellAction(code='print("hello')
  172. event_stream.add_event(ipython_action_2, EventSource.AGENT)
  173. ipython_observation_2 = IPythonRunCellObservation(
  174. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
  175. code='print("hello',
  176. )
  177. ipython_observation_2._cause = ipython_action_2._id
  178. event_stream.add_event(ipython_observation_2, EventSource.USER)
  179. ipython_action_3 = IPythonRunCellAction(code='print("hello')
  180. event_stream.add_event(ipython_action_3, EventSource.AGENT)
  181. ipython_observation_3 = IPythonRunCellObservation(
  182. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)',
  183. code='print("hello',
  184. )
  185. ipython_observation_3._cause = ipython_action_3._id
  186. event_stream.add_event(ipython_observation_3, EventSource.USER)
  187. ipython_action_4 = IPythonRunCellAction(code='print("hello')
  188. event_stream.add_event(ipython_action_4, EventSource.AGENT)
  189. ipython_observation_4 = IPythonRunCellObservation(
  190. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)',
  191. code='print("hello',
  192. )
  193. ipython_observation_4._cause = ipython_action_4._id
  194. event_stream.add_event(ipython_observation_4, EventSource.USER)
  195. # stuck_detector.state.history.set_event_stream(event_stream)
  196. last_observations = [
  197. ipython_observation_1,
  198. ipython_observation_2,
  199. ipython_observation_3,
  200. ipython_observation_4,
  201. ]
  202. for observation in last_observations:
  203. has_string = (
  204. observation.content[-100:].find(
  205. 'SyntaxError: unterminated string literal (detected at line'
  206. )
  207. != -1
  208. )
  209. assert has_string
  210. string_is_last = (
  211. len(
  212. observation.content.split(
  213. 'SyntaxError: unterminated string literal (detected at line'
  214. )[-1]
  215. )
  216. < 10
  217. )
  218. assert string_is_last
  219. with patch('logging.Logger.warning') as mock_warning:
  220. assert stuck_detector.is_stuck() is True
  221. mock_warning.assert_called_once_with(
  222. 'Action, IPythonRunCellObservation loop detected'
  223. )
  224. def test_is_stuck_ipython_syntax_error_not_at_end(
  225. self, stuck_detector: StuckDetector, event_stream: EventStream
  226. ):
  227. ipython_action_1 = IPythonRunCellAction(code='print("hello')
  228. event_stream.add_event(ipython_action_1, EventSource.AGENT)
  229. ipython_observation_1 = IPythonRunCellObservation(
  230. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output',
  231. code='print("hello',
  232. )
  233. ipython_observation_1._cause = ipython_action_1._id
  234. event_stream.add_event(ipython_observation_1, EventSource.USER)
  235. ipython_action_2 = IPythonRunCellAction(code='print("hello')
  236. event_stream.add_event(ipython_action_2, EventSource.AGENT)
  237. ipython_observation_2 = IPythonRunCellObservation(
  238. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on',
  239. code='print("hello',
  240. )
  241. ipython_observation_2._cause = ipython_action_2._id
  242. event_stream.add_event(ipython_observation_2, EventSource.USER)
  243. ipython_action_3 = IPythonRunCellAction(code='print("hello')
  244. event_stream.add_event(ipython_action_3, EventSource.AGENT)
  245. ipython_observation_3 = IPythonRunCellObservation(
  246. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough',
  247. code='print("hello',
  248. )
  249. ipython_observation_3._cause = ipython_action_3._id
  250. event_stream.add_event(ipython_observation_3, EventSource.USER)
  251. ipython_action_4 = IPythonRunCellAction(code='print("hello')
  252. event_stream.add_event(ipython_action_4, EventSource.AGENT)
  253. ipython_observation_4 = IPythonRunCellObservation(
  254. content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output',
  255. code='print("hello',
  256. )
  257. ipython_observation_4._cause = ipython_action_4._id
  258. event_stream.add_event(ipython_observation_4, EventSource.USER)
  259. with patch('logging.Logger.warning') as mock_warning:
  260. assert stuck_detector.is_stuck() is False
  261. mock_warning.assert_not_called()
  262. def test_is_stuck_repeating_action_observation_pattern(
  263. self, stuck_detector: StuckDetector, event_stream: EventStream
  264. ):
  265. message_action = MessageAction(content='Come on', wait_for_response=False)
  266. message_action._source = EventSource.USER
  267. event_stream.add_event(message_action, EventSource.USER)
  268. message_observation = NullObservation(content='')
  269. event_stream.add_event(message_observation, EventSource.USER)
  270. cmd_action_1 = CmdRunAction(command='ls')
  271. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  272. cmd_observation_1 = CmdOutputObservation(
  273. command_id=1, command='ls', content='file1.txt\nfile2.txt'
  274. )
  275. cmd_observation_1._cause = cmd_action_1._id
  276. event_stream.add_event(cmd_observation_1, EventSource.USER)
  277. read_action_1 = FileReadAction(path='file1.txt')
  278. event_stream.add_event(read_action_1, EventSource.AGENT)
  279. read_observation_1 = FileReadObservation(
  280. content='File content', path='file1.txt'
  281. )
  282. read_observation_1._cause = read_action_1._id
  283. event_stream.add_event(read_observation_1, EventSource.USER)
  284. cmd_action_2 = CmdRunAction(command='ls')
  285. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  286. cmd_observation_2 = CmdOutputObservation(
  287. command_id=2, command='ls', content='file1.txt\nfile2.txt'
  288. )
  289. cmd_observation_2._cause = cmd_action_2._id
  290. event_stream.add_event(cmd_observation_2, EventSource.USER)
  291. read_action_2 = FileReadAction(path='file1.txt')
  292. event_stream.add_event(read_action_2, EventSource.AGENT)
  293. read_observation_2 = FileReadObservation(
  294. content='File content', path='file1.txt'
  295. )
  296. read_observation_2._cause = read_action_2._id
  297. event_stream.add_event(read_observation_2, EventSource.USER)
  298. # one more message to break the pattern
  299. message_null_observation = NullObservation(content='')
  300. event_stream.add_event(message_action, EventSource.USER)
  301. event_stream.add_event(message_null_observation, EventSource.USER)
  302. cmd_action_3 = CmdRunAction(command='ls')
  303. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  304. cmd_observation_3 = CmdOutputObservation(
  305. command_id=3, command='ls', content='file1.txt\nfile2.txt'
  306. )
  307. cmd_observation_3._cause = cmd_action_3._id
  308. event_stream.add_event(cmd_observation_3, EventSource.USER)
  309. read_action_3 = FileReadAction(path='file1.txt')
  310. event_stream.add_event(read_action_3, EventSource.AGENT)
  311. read_observation_3 = FileReadObservation(
  312. content='File content', path='file1.txt'
  313. )
  314. read_observation_3._cause = read_action_3._id
  315. event_stream.add_event(read_observation_3, EventSource.USER)
  316. with patch('logging.Logger.warning') as mock_warning:
  317. assert stuck_detector.is_stuck() is True
  318. mock_warning.assert_called_once_with('Action, Observation pattern detected')
  319. def test_is_stuck_not_stuck(
  320. self, stuck_detector: StuckDetector, event_stream: EventStream
  321. ):
  322. message_action = MessageAction(content='Done', wait_for_response=False)
  323. message_action._source = EventSource.USER
  324. hello_action = MessageAction(content='Hello', wait_for_response=False)
  325. event_stream.add_event(hello_action, EventSource.USER)
  326. hello_observation = NullObservation(content='')
  327. hello_observation._cause = hello_action._id
  328. event_stream.add_event(hello_observation, EventSource.USER)
  329. cmd_action_1 = CmdRunAction(command='ls')
  330. event_stream.add_event(cmd_action_1, EventSource.AGENT)
  331. cmd_observation_1 = CmdOutputObservation(
  332. command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
  333. )
  334. cmd_observation_1._cause = cmd_action_1._id
  335. event_stream.add_event(cmd_observation_1, EventSource.USER)
  336. read_action_1 = FileReadAction(path='file1.txt')
  337. event_stream.add_event(read_action_1, EventSource.AGENT)
  338. read_observation_1 = FileReadObservation(
  339. content='File content', path='file1.txt'
  340. )
  341. read_observation_1._cause = read_action_1._id
  342. event_stream.add_event(read_observation_1, EventSource.USER)
  343. cmd_action_2 = CmdRunAction(command='pwd')
  344. event_stream.add_event(cmd_action_2, EventSource.AGENT)
  345. cmd_observation_2 = CmdOutputObservation(
  346. command_id=2, command='pwd', content='/home/user'
  347. )
  348. cmd_observation_2._cause = cmd_action_2._id
  349. event_stream.add_event(cmd_observation_2, EventSource.USER)
  350. read_action_2 = FileReadAction(path='file2.txt')
  351. event_stream.add_event(read_action_2, EventSource.AGENT)
  352. read_observation_2 = FileReadObservation(
  353. content='Another file content', path='file2.txt'
  354. )
  355. read_observation_2._cause = read_action_2._id
  356. event_stream.add_event(read_observation_2, EventSource.USER)
  357. message_null_observation = NullObservation(content='')
  358. event_stream.add_event(message_action, EventSource.USER)
  359. event_stream.add_event(message_null_observation, EventSource.USER)
  360. cmd_action_3 = CmdRunAction(command='pwd')
  361. event_stream.add_event(cmd_action_3, EventSource.AGENT)
  362. cmd_observation_3 = CmdOutputObservation(
  363. command_id=cmd_action_3.id, command='pwd', content='/home/user'
  364. )
  365. cmd_observation_3._cause = cmd_action_3._id
  366. event_stream.add_event(cmd_observation_3, EventSource.USER)
  367. read_action_3 = FileReadAction(path='file2.txt')
  368. event_stream.add_event(read_action_3, EventSource.AGENT)
  369. read_observation_3 = FileReadObservation(
  370. content='Another file content', path='file2.txt'
  371. )
  372. read_observation_3._cause = read_action_3._id
  373. event_stream.add_event(read_observation_3, EventSource.USER)
  374. assert stuck_detector.is_stuck() is False
  375. def test_is_stuck_monologue(self, stuck_detector, event_stream):
  376. # Add events to the event stream
  377. message_action_1 = MessageAction(content='Hi there!')
  378. event_stream.add_event(message_action_1, EventSource.USER)
  379. message_action_1._source = EventSource.USER
  380. message_action_2 = MessageAction(content='Hi there!')
  381. event_stream.add_event(message_action_2, EventSource.AGENT)
  382. message_action_2._source = EventSource.AGENT
  383. message_action_3 = MessageAction(content='How are you?')
  384. event_stream.add_event(message_action_3, EventSource.USER)
  385. message_action_3._source = EventSource.USER
  386. cmd_kill_action = CmdRunAction(
  387. command='echo 42', thought="I'm not stuck, he's stuck"
  388. )
  389. event_stream.add_event(cmd_kill_action, EventSource.AGENT)
  390. message_action_4 = MessageAction(content="I'm doing well, thanks for asking.")
  391. event_stream.add_event(message_action_4, EventSource.AGENT)
  392. message_action_4._source = EventSource.AGENT
  393. message_action_5 = MessageAction(content="I'm doing well, thanks for asking.")
  394. event_stream.add_event(message_action_5, EventSource.AGENT)
  395. message_action_5._source = EventSource.AGENT
  396. message_action_6 = MessageAction(content="I'm doing well, thanks for asking.")
  397. event_stream.add_event(message_action_6, EventSource.AGENT)
  398. message_action_6._source = EventSource.AGENT
  399. # stuck_detector.state.history.set_event_stream(event_stream)
  400. assert stuck_detector.is_stuck()
  401. # Add an observation event between the repeated message actions
  402. cmd_output_observation = CmdOutputObservation(
  403. content='OK, I was stuck, but no more.',
  404. command_id=42,
  405. command='storybook',
  406. exit_code=0,
  407. )
  408. cmd_output_observation._cause = cmd_kill_action._id
  409. event_stream.add_event(cmd_output_observation, EventSource.USER)
  410. message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
  411. event_stream.add_event(message_action_7, EventSource.AGENT)
  412. message_action_7._source = EventSource.AGENT
  413. message_action_8 = MessageAction(content="I'm doing well, thanks for asking.")
  414. event_stream.add_event(message_action_8, EventSource.AGENT)
  415. message_action_8._source = EventSource.AGENT
  416. assert not stuck_detector.is_stuck()
  417. class TestAgentController:
  418. @pytest.fixture
  419. def controller(self):
  420. controller = Mock(spec=AgentController)
  421. controller._is_stuck = AgentController._is_stuck.__get__(
  422. controller, AgentController
  423. )
  424. controller.delegate = None
  425. controller.state = Mock()
  426. controller.state.history = ShortTermHistory()
  427. return controller
  428. def test_is_stuck_delegate_stuck(self, controller: AgentController):
  429. controller.delegate = Mock()
  430. controller.delegate._is_stuck.return_value = True
  431. assert controller._is_stuck() is True