# stuck.py
  1. from openhands.controller.state.state import State
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.events.action.action import Action
  4. from openhands.events.action.empty import NullAction
  5. from openhands.events.action.message import MessageAction
  6. from openhands.events.event import Event, EventSource
  7. from openhands.events.observation.commands import (
  8. CmdOutputObservation,
  9. IPythonRunCellObservation,
  10. )
  11. from openhands.events.observation.empty import NullObservation
  12. from openhands.events.observation.error import ErrorObservation
  13. from openhands.events.observation.observation import Observation
  14. class StuckDetector:
  15. SYNTAX_ERROR_MESSAGES = [
  16. 'SyntaxError: unterminated string literal (detected at line',
  17. 'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  18. 'SyntaxError: incomplete input',
  19. ]
  20. def __init__(self, state: State):
  21. self.state = state
  22. def is_stuck(self):
  23. # filter out MessageAction with source='user' from history
  24. filtered_history = [
  25. event
  26. for event in self.state.history.get_events()
  27. if not (
  28. (isinstance(event, MessageAction) and event.source == EventSource.USER)
  29. or
  30. # there might be some NullAction or NullObservation in the history at least for now
  31. isinstance(event, NullAction)
  32. or isinstance(event, NullObservation)
  33. )
  34. ]
  35. # it takes 3 actions minimum to detect a loop, otherwise nothing to do here
  36. if len(filtered_history) < 3:
  37. return False
  38. # the first few scenarios detect 3 or 4 repeated steps
  39. # prepare the last 4 actions and observations, to check them out
  40. last_actions: list[Event] = []
  41. last_observations: list[Event] = []
  42. # retrieve the last four actions and observations starting from the end of history, wherever they are
  43. for event in reversed(filtered_history):
  44. if isinstance(event, Action) and len(last_actions) < 4:
  45. last_actions.append(event)
  46. elif isinstance(event, Observation) and len(last_observations) < 4:
  47. last_observations.append(event)
  48. if len(last_actions) == 4 and len(last_observations) == 4:
  49. break
  50. # scenario 1: same action, same observation
  51. if self._is_stuck_repeating_action_observation(last_actions, last_observations):
  52. return True
  53. # scenario 2: same action, errors
  54. if self._is_stuck_repeating_action_error(last_actions, last_observations):
  55. return True
  56. # scenario 3: monologue
  57. if self._is_stuck_monologue(filtered_history):
  58. return True
  59. # scenario 4: action, observation pattern on the last six steps
  60. if len(filtered_history) < 6:
  61. return False
  62. if self._is_stuck_action_observation_pattern(filtered_history):
  63. return True
  64. return False
  65. def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
  66. # scenario 1: same action, same observation
  67. # it takes 4 actions and 4 observations to detect a loop
  68. # assert len(last_actions) == 4 and len(last_observations) == 4
  69. # reset almost_stuck reminder
  70. self.state.almost_stuck = 0
  71. # almost stuck? if two actions, obs are the same, we're almost stuck
  72. if len(last_actions) >= 2 and len(last_observations) >= 2:
  73. actions_equal = all(
  74. self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
  75. )
  76. observations_equal = all(
  77. self._eq_no_pid(last_observations[0], observation)
  78. for observation in last_observations[:2]
  79. )
  80. # the last two actions and obs are the same?
  81. if actions_equal and observations_equal:
  82. self.state.almost_stuck = 2
  83. # the last three actions and observations are the same?
  84. if len(last_actions) >= 3 and len(last_observations) >= 3:
  85. if (
  86. actions_equal
  87. and observations_equal
  88. and self._eq_no_pid(last_actions[0], last_actions[2])
  89. and self._eq_no_pid(last_observations[0], last_observations[2])
  90. ):
  91. self.state.almost_stuck = 1
  92. if len(last_actions) == 4 and len(last_observations) == 4:
  93. if (
  94. actions_equal
  95. and observations_equal
  96. and self._eq_no_pid(last_actions[0], last_actions[3])
  97. and self._eq_no_pid(last_observations[0], last_observations[3])
  98. ):
  99. logger.warning('Action, Observation loop detected')
  100. self.state.almost_stuck = 0
  101. return True
  102. return False
  103. def _is_stuck_repeating_action_error(self, last_actions, last_observations):
  104. # scenario 2: same action, errors
  105. # it takes 3 actions and 3 observations to detect a loop
  106. # check if the last three actions are the same and result in errors
  107. if len(last_actions) < 4 or len(last_observations) < 4:
  108. return False
  109. # are the last three actions the "same"?
  110. if all(self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]):
  111. # and the last three observations are all errors?
  112. if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
  113. logger.warning('Action, ErrorObservation loop detected')
  114. return True
  115. # or, are the last three observations all IPythonRunCellObservation with SyntaxError?
  116. elif all(
  117. isinstance(obs, IPythonRunCellObservation)
  118. for obs in last_observations[:3]
  119. ):
  120. warning = 'Action, IPythonRunCellObservation loop detected'
  121. for error_message in self.SYNTAX_ERROR_MESSAGES:
  122. if error_message.startswith(
  123. 'SyntaxError: unterminated string literal (detected at line'
  124. ):
  125. if self._check_for_consistent_line_error(
  126. last_observations[:3], error_message
  127. ):
  128. logger.warning(warning)
  129. return True
  130. elif error_message in [
  131. 'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
  132. 'SyntaxError: incomplete input',
  133. ]:
  134. if self._check_for_consistent_invalid_syntax(
  135. last_observations[:3], error_message
  136. ):
  137. logger.warning(warning)
  138. return True
  139. return False
  140. def _check_for_consistent_invalid_syntax(self, observations, error_message):
  141. first_lines = []
  142. valid_observations = []
  143. for obs in observations:
  144. content = obs.content
  145. lines = content.strip().split('\n')
  146. if len(lines) < 4:
  147. return False
  148. first_lines.append(lines[0]) # Store the first line of each observation
  149. # Check last three lines
  150. if lines[-2].startswith('[Jupyter current working directory:') and lines[
  151. -1
  152. ].startswith('[Jupyter Python interpreter:'):
  153. if error_message in lines[-3]:
  154. valid_observations.append(obs)
  155. break
  156. # Check if:
  157. # 1. All first lines are identical
  158. # 2. We have exactly 3 valid observations
  159. # 3. The error message line is identical in all valid observations
  160. return (
  161. len(set(first_lines)) == 1
  162. and len(valid_observations) == 3
  163. and len(
  164. set(
  165. obs.content.strip().split('\n')[:-2][-1]
  166. for obs in valid_observations
  167. )
  168. )
  169. == 1
  170. )
  171. def _check_for_consistent_line_error(self, observations, error_message):
  172. error_lines = []
  173. for obs in observations:
  174. content = obs.content
  175. lines = content.strip().split('\n')
  176. if len(lines) < 3:
  177. return False
  178. last_lines = lines[-3:]
  179. # Check if the last two lines are our own
  180. if not (
  181. last_lines[-2].startswith('[Jupyter current working directory:')
  182. and last_lines[-1].startswith('[Jupyter Python interpreter:')
  183. ):
  184. return False
  185. # Check for the error message in the 3rd-to-last line
  186. if error_message in last_lines[-3]:
  187. error_lines.append(last_lines[-3])
  188. # Check if we found the error message in all 3 observations
  189. # and the 3rd-to-last line is identical across all occurrences
  190. return len(error_lines) == 3 and len(set(error_lines)) == 1
  191. def _is_stuck_monologue(self, filtered_history):
  192. # scenario 3: monologue
  193. # check for repeated MessageActions with source=AGENT
  194. # see if the agent is engaged in a good old monologue, telling itself the same thing over and over
  195. agent_message_actions = [
  196. (i, event)
  197. for i, event in enumerate(filtered_history)
  198. if isinstance(event, MessageAction) and event.source == EventSource.AGENT
  199. ]
  200. # last three message actions will do for this check
  201. if len(agent_message_actions) >= 3:
  202. last_agent_message_actions = agent_message_actions[-3:]
  203. if all(
  204. (last_agent_message_actions[0][1] == action[1])
  205. for action in last_agent_message_actions
  206. ):
  207. # check if there are any observations between the repeated MessageActions
  208. # then it's not yet a loop, maybe it can recover
  209. start_index = last_agent_message_actions[0][0]
  210. end_index = last_agent_message_actions[-1][0]
  211. has_observation_between = False
  212. for event in filtered_history[start_index + 1 : end_index]:
  213. if isinstance(event, Observation):
  214. has_observation_between = True
  215. break
  216. if not has_observation_between:
  217. logger.warning('Repeated MessageAction with source=AGENT detected')
  218. return True
  219. return False
  220. def _is_stuck_action_observation_pattern(self, filtered_history):
  221. # scenario 4: action, observation pattern on the last six steps
  222. # check if the agent repeats the same (Action, Observation)
  223. # every other step in the last six steps
  224. last_six_actions: list[Event] = []
  225. last_six_observations: list[Event] = []
  226. # the end of history is most interesting
  227. for event in reversed(filtered_history):
  228. if isinstance(event, Action) and len(last_six_actions) < 6:
  229. last_six_actions.append(event)
  230. elif isinstance(event, Observation) and len(last_six_observations) < 6:
  231. last_six_observations.append(event)
  232. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  233. break
  234. # this pattern is every other step, like:
  235. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  236. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  237. actions_equal = (
  238. # action_0 == action_2 == action_4
  239. self._eq_no_pid(last_six_actions[0], last_six_actions[2])
  240. and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
  241. # action_1 == action_3 == action_5
  242. and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
  243. and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
  244. )
  245. observations_equal = (
  246. # obs_0 == obs_2 == obs_4
  247. self._eq_no_pid(last_six_observations[0], last_six_observations[2])
  248. and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
  249. # obs_1 == obs_3 == obs_5
  250. and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
  251. and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
  252. )
  253. if actions_equal and observations_equal:
  254. logger.warning('Action, Observation pattern detected')
  255. return True
  256. return False
  257. def _eq_no_pid(self, obj1, obj2):
  258. if isinstance(obj1, CmdOutputObservation) and isinstance(
  259. obj2, CmdOutputObservation
  260. ):
  261. # for loop detection, ignore command_id, which is the pid
  262. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  263. else:
  264. # this is the default comparison
  265. return obj1 == obj2