stuck.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. from openhands.controller.state.state import State
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.events.action.action import Action
  4. from openhands.events.action.commands import IPythonRunCellAction
  5. from openhands.events.action.empty import NullAction
  6. from openhands.events.action.message import MessageAction
  7. from openhands.events.event import Event, EventSource
  8. from openhands.events.observation.commands import (
  9. CmdOutputObservation,
  10. IPythonRunCellObservation,
  11. )
  12. from openhands.events.observation.empty import NullObservation
  13. from openhands.events.observation.error import ErrorObservation
  14. from openhands.events.observation.observation import Observation
  15. class StuckDetector:
  16. SYNTAX_ERROR_MESSAGES = [
  17. 'SyntaxError: unterminated string literal (detected at line',
  18. 'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  19. 'SyntaxError: incomplete input',
  20. ]
  21. def __init__(self, state: State):
  22. self.state = state
  23. def is_stuck(self):
  24. # filter out MessageAction with source='user' from history
  25. filtered_history = [
  26. event
  27. for event in self.state.history.get_events()
  28. if not (
  29. (isinstance(event, MessageAction) and event.source == EventSource.USER)
  30. or
  31. # there might be some NullAction or NullObservation in the history at least for now
  32. isinstance(event, NullAction)
  33. or isinstance(event, NullObservation)
  34. )
  35. ]
  36. # it takes 3 actions minimum to detect a loop, otherwise nothing to do here
  37. if len(filtered_history) < 3:
  38. return False
  39. # the first few scenarios detect 3 or 4 repeated steps
  40. # prepare the last 4 actions and observations, to check them out
  41. last_actions: list[Event] = []
  42. last_observations: list[Event] = []
  43. # retrieve the last four actions and observations starting from the end of history, wherever they are
  44. for event in reversed(filtered_history):
  45. if isinstance(event, Action) and len(last_actions) < 4:
  46. last_actions.append(event)
  47. elif isinstance(event, Observation) and len(last_observations) < 4:
  48. last_observations.append(event)
  49. if len(last_actions) == 4 and len(last_observations) == 4:
  50. break
  51. # scenario 1: same action, same observation
  52. if self._is_stuck_repeating_action_observation(last_actions, last_observations):
  53. return True
  54. # scenario 2: same action, errors
  55. if self._is_stuck_repeating_action_error(last_actions, last_observations):
  56. return True
  57. # scenario 3: monologue
  58. if self._is_stuck_monologue(filtered_history):
  59. return True
  60. # scenario 4: action, observation pattern on the last six steps
  61. if len(filtered_history) < 6:
  62. return False
  63. if self._is_stuck_action_observation_pattern(filtered_history):
  64. return True
  65. return False
  66. def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
  67. # scenario 1: same action, same observation
  68. # it takes 4 actions and 4 observations to detect a loop
  69. # assert len(last_actions) == 4 and len(last_observations) == 4
  70. # reset almost_stuck reminder
  71. self.state.almost_stuck = 0
  72. # almost stuck? if two actions, obs are the same, we're almost stuck
  73. if len(last_actions) >= 2 and len(last_observations) >= 2:
  74. actions_equal = all(
  75. self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
  76. )
  77. observations_equal = all(
  78. self._eq_no_pid(last_observations[0], observation)
  79. for observation in last_observations[:2]
  80. )
  81. # the last two actions and obs are the same?
  82. if actions_equal and observations_equal:
  83. self.state.almost_stuck = 2
  84. # the last three actions and observations are the same?
  85. if len(last_actions) >= 3 and len(last_observations) >= 3:
  86. if (
  87. actions_equal
  88. and observations_equal
  89. and self._eq_no_pid(last_actions[0], last_actions[2])
  90. and self._eq_no_pid(last_observations[0], last_observations[2])
  91. ):
  92. self.state.almost_stuck = 1
  93. if len(last_actions) == 4 and len(last_observations) == 4:
  94. if (
  95. actions_equal
  96. and observations_equal
  97. and self._eq_no_pid(last_actions[0], last_actions[3])
  98. and self._eq_no_pid(last_observations[0], last_observations[3])
  99. ):
  100. logger.warning('Action, Observation loop detected')
  101. self.state.almost_stuck = 0
  102. return True
  103. return False
  104. def _is_stuck_repeating_action_error(self, last_actions, last_observations):
  105. # scenario 2: same action, errors
  106. # it takes 3 actions and 3 observations to detect a loop
  107. # check if the last three actions are the same and result in errors
  108. if len(last_actions) < 4 or len(last_observations) < 4:
  109. return False
  110. # are the last three actions the "same"?
  111. if all(self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]):
  112. # and the last three observations are all errors?
  113. if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
  114. logger.warning('Action, ErrorObservation loop detected')
  115. return True
  116. # or, are the last three observations all IPythonRunCellObservation with SyntaxError?
  117. elif all(
  118. isinstance(obs, IPythonRunCellObservation)
  119. for obs in last_observations[:3]
  120. ):
  121. warning = 'Action, IPythonRunCellObservation loop detected'
  122. for error_message in self.SYNTAX_ERROR_MESSAGES:
  123. if error_message.startswith(
  124. 'SyntaxError: unterminated string literal (detected at line'
  125. ):
  126. if self._check_for_consistent_line_error(
  127. last_observations[:3], error_message
  128. ):
  129. logger.warning(warning)
  130. return True
  131. elif error_message in (
  132. 'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  133. 'SyntaxError: incomplete input',
  134. ) and self._check_for_consistent_invalid_syntax(
  135. last_observations[:3], error_message
  136. ):
  137. logger.warning(warning)
  138. return True
  139. return False
  140. def _check_for_consistent_invalid_syntax(self, observations, error_message):
  141. first_lines = []
  142. valid_observations = []
  143. for obs in observations:
  144. content = obs.content
  145. lines = content.strip().split('\n')
  146. if len(lines) < 6: # 6 because a real syntax error has at least 6 lines
  147. return False
  148. line1 = lines[0].strip()
  149. if not line1.startswith('Cell In[1], line'):
  150. return False
  151. first_lines.append(line1) # Store the first line of each observation
  152. # Check last three lines
  153. if (
  154. lines[-1].startswith('[Jupyter Python interpreter:')
  155. and lines[-2].startswith('[Jupyter current working directory:')
  156. and error_message in lines[-3]
  157. ):
  158. valid_observations.append(obs)
  159. # Check if:
  160. # 1. All first lines are identical
  161. # 2. We have exactly 3 valid observations
  162. # 3. The error message line is identical in all valid observations
  163. return (
  164. len(set(first_lines)) == 1
  165. and len(valid_observations) == 3
  166. and len(
  167. set(
  168. obs.content.strip().split('\n')[:-2][-1]
  169. for obs in valid_observations
  170. )
  171. )
  172. == 1
  173. )
  174. def _check_for_consistent_line_error(self, observations, error_message):
  175. error_lines = []
  176. for obs in observations:
  177. content = obs.content
  178. lines = content.strip().split('\n')
  179. if len(lines) < 3:
  180. return False
  181. last_lines = lines[-3:]
  182. # Check if the last two lines are our own
  183. if not (
  184. last_lines[-2].startswith('[Jupyter current working directory:')
  185. and last_lines[-1].startswith('[Jupyter Python interpreter:')
  186. ):
  187. return False
  188. # Check for the error message in the 3rd-to-last line
  189. if error_message in last_lines[-3]:
  190. error_lines.append(last_lines[-3])
  191. # Check if we found the error message in all 3 observations
  192. # and the 3rd-to-last line is identical across all occurrences
  193. return len(error_lines) == 3 and len(set(error_lines)) == 1
  194. def _is_stuck_monologue(self, filtered_history):
  195. # scenario 3: monologue
  196. # check for repeated MessageActions with source=AGENT
  197. # see if the agent is engaged in a good old monologue, telling itself the same thing over and over
  198. agent_message_actions = [
  199. (i, event)
  200. for i, event in enumerate(filtered_history)
  201. if isinstance(event, MessageAction) and event.source == EventSource.AGENT
  202. ]
  203. # last three message actions will do for this check
  204. if len(agent_message_actions) >= 3:
  205. last_agent_message_actions = agent_message_actions[-3:]
  206. if all(
  207. (last_agent_message_actions[0][1] == action[1])
  208. for action in last_agent_message_actions
  209. ):
  210. # check if there are any observations between the repeated MessageActions
  211. # then it's not yet a loop, maybe it can recover
  212. start_index = last_agent_message_actions[0][0]
  213. end_index = last_agent_message_actions[-1][0]
  214. has_observation_between = False
  215. for event in filtered_history[start_index + 1 : end_index]:
  216. if isinstance(event, Observation):
  217. has_observation_between = True
  218. break
  219. if not has_observation_between:
  220. logger.warning('Repeated MessageAction with source=AGENT detected')
  221. return True
  222. return False
  223. def _is_stuck_action_observation_pattern(self, filtered_history):
  224. # scenario 4: action, observation pattern on the last six steps
  225. # check if the agent repeats the same (Action, Observation)
  226. # every other step in the last six steps
  227. last_six_actions: list[Event] = []
  228. last_six_observations: list[Event] = []
  229. # the end of history is most interesting
  230. for event in reversed(filtered_history):
  231. if isinstance(event, Action) and len(last_six_actions) < 6:
  232. last_six_actions.append(event)
  233. elif isinstance(event, Observation) and len(last_six_observations) < 6:
  234. last_six_observations.append(event)
  235. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  236. break
  237. # this pattern is every other step, like:
  238. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  239. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  240. actions_equal = (
  241. # action_0 == action_2 == action_4
  242. self._eq_no_pid(last_six_actions[0], last_six_actions[2])
  243. and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
  244. # action_1 == action_3 == action_5
  245. and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
  246. and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
  247. )
  248. observations_equal = (
  249. # obs_0 == obs_2 == obs_4
  250. self._eq_no_pid(last_six_observations[0], last_six_observations[2])
  251. and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
  252. # obs_1 == obs_3 == obs_5
  253. and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
  254. and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
  255. )
  256. if actions_equal and observations_equal:
  257. logger.warning('Action, Observation pattern detected')
  258. return True
  259. return False
  260. def _eq_no_pid(self, obj1, obj2):
  261. if isinstance(obj1, IPythonRunCellAction) and isinstance(
  262. obj2, IPythonRunCellAction
  263. ):
  264. # for loop detection on edit actions, ignore the thought, compare some code
  265. # the code should have at least 3 lines, to avoid simple one-liners
  266. if (
  267. 'edit_file_by_replace(' in obj1.code
  268. and 'edit_file_by_replace(' in obj2.code
  269. ):
  270. return (
  271. len(obj1.code.split('\n')) > 2
  272. and obj1.code.split('\n')[:3] == obj2.code.split('\n')[:3]
  273. )
  274. else:
  275. # default comparison
  276. return obj1 == obj2
  277. elif isinstance(obj1, CmdOutputObservation) and isinstance(
  278. obj2, CmdOutputObservation
  279. ):
  280. # for loop detection, ignore command_id, which is the pid
  281. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  282. else:
  283. # this is the default comparison
  284. return obj1 == obj2