stuck.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. from openhands.controller.state.state import State
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.events.action.action import Action
  4. from openhands.events.action.commands import IPythonRunCellAction
  5. from openhands.events.action.empty import NullAction
  6. from openhands.events.action.message import MessageAction
  7. from openhands.events.event import Event, EventSource
  8. from openhands.events.observation.commands import (
  9. CmdOutputObservation,
  10. IPythonRunCellObservation,
  11. )
  12. from openhands.events.observation.empty import NullObservation
  13. from openhands.events.observation.error import ErrorObservation
  14. from openhands.events.observation.observation import Observation
  15. class StuckDetector:
  16. SYNTAX_ERROR_MESSAGES = [
  17. 'SyntaxError: unterminated string literal (detected at line',
  18. 'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  19. 'SyntaxError: incomplete input',
  20. ]
  21. def __init__(self, state: State):
  22. self.state = state
  23. def is_stuck(self, headless_mode: bool = True):
  24. """Checks if the agent is stuck in a loop.
  25. Args:
  26. headless_mode: Matches AgentController's headless_mode.
  27. If True: Consider all history (automated/testing)
  28. If False: Consider only history after last user message (interactive)
  29. Returns:
  30. bool: True if the agent is stuck in a loop, False otherwise.
  31. """
  32. if not headless_mode:
  33. # In interactive mode, only look at history after the last user message
  34. last_user_msg_idx = -1
  35. for i, event in enumerate(reversed(self.state.history)):
  36. if (
  37. isinstance(event, MessageAction)
  38. and event.source == EventSource.USER
  39. ):
  40. last_user_msg_idx = len(self.state.history) - i - 1
  41. break
  42. history_to_check = self.state.history[last_user_msg_idx + 1 :]
  43. else:
  44. # In headless mode, look at all history
  45. history_to_check = self.state.history
  46. # Filter out user messages and null events
  47. filtered_history = [
  48. event
  49. for event in history_to_check
  50. if not (
  51. # Filter works elegantly in both modes:
  52. # - In headless: actively filters out user messages from full history
  53. # - In non-headless: no-op since we already sliced after last user message
  54. (isinstance(event, MessageAction) and event.source == EventSource.USER)
  55. # there might be some NullAction or NullObservation in the history at least for now
  56. or isinstance(event, (NullAction, NullObservation))
  57. )
  58. ]
  59. # it takes 3 actions minimum to detect a loop, otherwise nothing to do here
  60. if len(filtered_history) < 3:
  61. return False
  62. # the first few scenarios detect 3 or 4 repeated steps
  63. # prepare the last 4 actions and observations, to check them out
  64. last_actions: list[Event] = []
  65. last_observations: list[Event] = []
  66. # retrieve the last four actions and observations starting from the end of history, wherever they are
  67. for event in reversed(filtered_history):
  68. if isinstance(event, Action) and len(last_actions) < 4:
  69. last_actions.append(event)
  70. elif isinstance(event, Observation) and len(last_observations) < 4:
  71. last_observations.append(event)
  72. if len(last_actions) == 4 and len(last_observations) == 4:
  73. break
  74. # scenario 1: same action, same observation
  75. if self._is_stuck_repeating_action_observation(last_actions, last_observations):
  76. return True
  77. # scenario 2: same action, errors
  78. if self._is_stuck_repeating_action_error(last_actions, last_observations):
  79. return True
  80. # scenario 3: monologue
  81. if self._is_stuck_monologue(filtered_history):
  82. return True
  83. # scenario 4: action, observation pattern on the last six steps
  84. if len(filtered_history) < 6:
  85. return False
  86. if self._is_stuck_action_observation_pattern(filtered_history):
  87. return True
  88. return False
  89. def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
  90. # scenario 1: same action, same observation
  91. # it takes 4 actions and 4 observations to detect a loop
  92. # assert len(last_actions) == 4 and len(last_observations) == 4
  93. # Check for a loop of 4 identical action-observation pairs
  94. if len(last_actions) == 4 and len(last_observations) == 4:
  95. actions_equal = all(
  96. self._eq_no_pid(last_actions[0], action) for action in last_actions
  97. )
  98. observations_equal = all(
  99. self._eq_no_pid(last_observations[0], observation)
  100. for observation in last_observations
  101. )
  102. if actions_equal and observations_equal:
  103. logger.warning('Action, Observation loop detected')
  104. return True
  105. return False
  106. def _is_stuck_repeating_action_error(self, last_actions, last_observations):
  107. # scenario 2: same action, errors
  108. # it takes 3 actions and 3 observations to detect a loop
  109. # check if the last three actions are the same and result in errors
  110. if len(last_actions) < 4 or len(last_observations) < 4:
  111. return False
  112. # are the last three actions the "same"?
  113. if all(self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]):
  114. # and the last three observations are all errors?
  115. if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
  116. logger.warning('Action, ErrorObservation loop detected')
  117. return True
  118. # or, are the last three observations all IPythonRunCellObservation with SyntaxError?
  119. elif all(
  120. isinstance(obs, IPythonRunCellObservation)
  121. for obs in last_observations[:3]
  122. ):
  123. warning = 'Action, IPythonRunCellObservation loop detected'
  124. for error_message in self.SYNTAX_ERROR_MESSAGES:
  125. if error_message.startswith(
  126. 'SyntaxError: unterminated string literal (detected at line'
  127. ):
  128. if self._check_for_consistent_line_error(
  129. last_observations[:3], error_message
  130. ):
  131. logger.warning(warning)
  132. return True
  133. elif error_message in (
  134. 'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  135. 'SyntaxError: incomplete input',
  136. ) and self._check_for_consistent_invalid_syntax(
  137. last_observations[:3], error_message
  138. ):
  139. logger.warning(warning)
  140. return True
  141. return False
  142. def _check_for_consistent_invalid_syntax(self, observations, error_message):
  143. first_lines = []
  144. valid_observations = []
  145. for obs in observations:
  146. content = obs.content
  147. lines = content.strip().split('\n')
  148. if len(lines) < 6: # 6 because a real syntax error has at least 6 lines
  149. return False
  150. line1 = lines[0].strip()
  151. if not line1.startswith('Cell In[1], line'):
  152. return False
  153. first_lines.append(line1) # Store the first line of each observation
  154. # Check last three lines
  155. if (
  156. lines[-1].startswith('[Jupyter Python interpreter:')
  157. and lines[-2].startswith('[Jupyter current working directory:')
  158. and error_message in lines[-3]
  159. ):
  160. valid_observations.append(obs)
  161. # Check if:
  162. # 1. All first lines are identical
  163. # 2. We have exactly 3 valid observations
  164. # 3. The error message line is identical in all valid observations
  165. return (
  166. len(set(first_lines)) == 1
  167. and len(valid_observations) == 3
  168. and len(
  169. set(
  170. obs.content.strip().split('\n')[:-2][-1]
  171. for obs in valid_observations
  172. )
  173. )
  174. == 1
  175. )
  176. def _check_for_consistent_line_error(self, observations, error_message):
  177. error_lines = []
  178. for obs in observations:
  179. content = obs.content
  180. lines = content.strip().split('\n')
  181. if len(lines) < 3:
  182. return False
  183. last_lines = lines[-3:]
  184. # Check if the last two lines are our own
  185. if not (
  186. last_lines[-2].startswith('[Jupyter current working directory:')
  187. and last_lines[-1].startswith('[Jupyter Python interpreter:')
  188. ):
  189. return False
  190. # Check for the error message in the 3rd-to-last line
  191. if error_message in last_lines[-3]:
  192. error_lines.append(last_lines[-3])
  193. # Check if we found the error message in all 3 observations
  194. # and the 3rd-to-last line is identical across all occurrences
  195. return len(error_lines) == 3 and len(set(error_lines)) == 1
  196. def _is_stuck_monologue(self, filtered_history):
  197. # scenario 3: monologue
  198. # check for repeated MessageActions with source=AGENT
  199. # see if the agent is engaged in a good old monologue, telling itself the same thing over and over
  200. agent_message_actions = [
  201. (i, event)
  202. for i, event in enumerate(filtered_history)
  203. if isinstance(event, MessageAction) and event.source == EventSource.AGENT
  204. ]
  205. # last three message actions will do for this check
  206. if len(agent_message_actions) >= 3:
  207. last_agent_message_actions = agent_message_actions[-3:]
  208. if all(
  209. (last_agent_message_actions[0][1] == action[1])
  210. for action in last_agent_message_actions
  211. ):
  212. # check if there are any observations between the repeated MessageActions
  213. # then it's not yet a loop, maybe it can recover
  214. start_index = last_agent_message_actions[0][0]
  215. end_index = last_agent_message_actions[-1][0]
  216. has_observation_between = False
  217. for event in filtered_history[start_index + 1 : end_index]:
  218. if isinstance(event, Observation):
  219. has_observation_between = True
  220. break
  221. if not has_observation_between:
  222. logger.warning('Repeated MessageAction with source=AGENT detected')
  223. return True
  224. return False
  225. def _is_stuck_action_observation_pattern(self, filtered_history):
  226. # scenario 4: action, observation pattern on the last six steps
  227. # check if the agent repeats the same (Action, Observation)
  228. # every other step in the last six steps
  229. last_six_actions: list[Event] = []
  230. last_six_observations: list[Event] = []
  231. # the end of history is most interesting
  232. for event in reversed(filtered_history):
  233. if isinstance(event, Action) and len(last_six_actions) < 6:
  234. last_six_actions.append(event)
  235. elif isinstance(event, Observation) and len(last_six_observations) < 6:
  236. last_six_observations.append(event)
  237. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  238. break
  239. # this pattern is every other step, like:
  240. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  241. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  242. actions_equal = (
  243. # action_0 == action_2 == action_4
  244. self._eq_no_pid(last_six_actions[0], last_six_actions[2])
  245. and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
  246. # action_1 == action_3 == action_5
  247. and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
  248. and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
  249. )
  250. observations_equal = (
  251. # obs_0 == obs_2 == obs_4
  252. self._eq_no_pid(last_six_observations[0], last_six_observations[2])
  253. and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
  254. # obs_1 == obs_3 == obs_5
  255. and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
  256. and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
  257. )
  258. if actions_equal and observations_equal:
  259. logger.warning('Action, Observation pattern detected')
  260. return True
  261. return False
  262. def _eq_no_pid(self, obj1, obj2):
  263. if isinstance(obj1, IPythonRunCellAction) and isinstance(
  264. obj2, IPythonRunCellAction
  265. ):
  266. # for loop detection on edit actions, ignore the thought, compare some code
  267. # the code should have at least 3 lines, to avoid simple one-liners
  268. if (
  269. 'edit_file_by_replace(' in obj1.code
  270. and 'edit_file_by_replace(' in obj2.code
  271. ):
  272. return (
  273. len(obj1.code.split('\n')) > 2
  274. and obj1.code.split('\n')[:3] == obj2.code.split('\n')[:3]
  275. )
  276. else:
  277. # default comparison
  278. return obj1 == obj2
  279. elif isinstance(obj1, CmdOutputObservation) and isinstance(
  280. obj2, CmdOutputObservation
  281. ):
  282. # for loop detection, ignore command_id, which is the pid
  283. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  284. else:
  285. # this is the default comparison
  286. return obj1 == obj2