# stuck.py
  1. from typing import cast
  2. from openhands.controller.state.state import State
  3. from openhands.core.logger import openhands_logger as logger
  4. from openhands.events.action.action import Action
  5. from openhands.events.action.empty import NullAction
  6. from openhands.events.action.message import MessageAction
  7. from openhands.events.event import Event, EventSource
  8. from openhands.events.observation.commands import (
  9. CmdOutputObservation,
  10. IPythonRunCellObservation,
  11. )
  12. from openhands.events.observation.empty import NullObservation
  13. from openhands.events.observation.error import ErrorObservation
  14. from openhands.events.observation.observation import Observation
  15. class StuckDetector:
  16. def __init__(self, state: State):
  17. self.state = state
  18. def is_stuck(self):
  19. # filter out MessageAction with source='user' from history
  20. filtered_history = [
  21. event
  22. for event in self.state.history.get_events()
  23. if not (
  24. (isinstance(event, MessageAction) and event.source == EventSource.USER)
  25. or
  26. # there might be some NullAction or NullObservation in the history at least for now
  27. isinstance(event, NullAction)
  28. or isinstance(event, NullObservation)
  29. )
  30. ]
  31. # it takes 3 actions minimum to detect a loop, otherwise nothing to do here
  32. if len(filtered_history) < 3:
  33. return False
  34. # the first few scenarios detect 3 or 4 repeated steps
  35. # prepare the last 4 actions and observations, to check them out
  36. last_actions: list[Event] = []
  37. last_observations: list[Event] = []
  38. # retrieve the last four actions and observations starting from the end of history, wherever they are
  39. for event in reversed(filtered_history):
  40. if isinstance(event, Action) and len(last_actions) < 4:
  41. last_actions.append(event)
  42. elif isinstance(event, Observation) and len(last_observations) < 4:
  43. last_observations.append(event)
  44. if len(last_actions) == 4 and len(last_observations) == 4:
  45. break
  46. # scenario 1: same action, same observation
  47. if self._is_stuck_repeating_action_observation(last_actions, last_observations):
  48. return True
  49. # scenario 2: same action, errors
  50. if self._is_stuck_repeating_action_error(last_actions, last_observations):
  51. return True
  52. # scenario 3: monologue
  53. if self._is_stuck_monologue(filtered_history):
  54. return True
  55. # scenario 4: action, observation pattern on the last six steps
  56. if len(filtered_history) < 6:
  57. return False
  58. if self._is_stuck_action_observation_pattern(filtered_history):
  59. return True
  60. return False
  61. def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
  62. # scenario 1: same action, same observation
  63. # it takes 4 actions and 4 observations to detect a loop
  64. # assert len(last_actions) == 4 and len(last_observations) == 4
  65. # reset almost_stuck reminder
  66. self.state.almost_stuck = 0
  67. # almost stuck? if two actions, obs are the same, we're almost stuck
  68. if len(last_actions) >= 2 and len(last_observations) >= 2:
  69. actions_equal = all(
  70. self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
  71. )
  72. observations_equal = all(
  73. self._eq_no_pid(last_observations[0], observation)
  74. for observation in last_observations[:2]
  75. )
  76. # the last two actions and obs are the same?
  77. if actions_equal and observations_equal:
  78. self.state.almost_stuck = 2
  79. # the last three actions and observations are the same?
  80. if len(last_actions) >= 3 and len(last_observations) >= 3:
  81. if (
  82. actions_equal
  83. and observations_equal
  84. and self._eq_no_pid(last_actions[0], last_actions[2])
  85. and self._eq_no_pid(last_observations[0], last_observations[2])
  86. ):
  87. self.state.almost_stuck = 1
  88. if len(last_actions) == 4 and len(last_observations) == 4:
  89. if (
  90. actions_equal
  91. and observations_equal
  92. and self._eq_no_pid(last_actions[0], last_actions[3])
  93. and self._eq_no_pid(last_observations[0], last_observations[3])
  94. ):
  95. logger.warning('Action, Observation loop detected')
  96. self.state.almost_stuck = 0
  97. return True
  98. return False
  99. def _is_stuck_repeating_action_error(self, last_actions, last_observations):
  100. # scenario 2: same action, errors
  101. # it takes 4 actions and 4 observations to detect a loop
  102. # check if the last four actions are the same and result in errors
  103. # are the last four actions the same?
  104. if len(last_actions) == 4 and all(
  105. self._eq_no_pid(last_actions[0], action) for action in last_actions
  106. ):
  107. # and the last four observations all errors?
  108. if all(isinstance(obs, ErrorObservation) for obs in last_observations):
  109. logger.warning('Action, ErrorObservation loop detected')
  110. return True
  111. # or, are the last four observations all IPythonRunCellObservation with SyntaxError?
  112. elif all(
  113. isinstance(obs, IPythonRunCellObservation) for obs in last_observations
  114. ) and all(
  115. cast(IPythonRunCellObservation, obs)
  116. .content[-100:]
  117. .find('SyntaxError: unterminated string literal (detected at line')
  118. != -1
  119. and len(
  120. cast(IPythonRunCellObservation, obs).content.split(
  121. 'SyntaxError: unterminated string literal (detected at line'
  122. )[-1]
  123. )
  124. < 10
  125. for obs in last_observations
  126. ):
  127. logger.warning('Action, IPythonRunCellObservation loop detected')
  128. return True
  129. return False
  130. def _is_stuck_monologue(self, filtered_history):
  131. # scenario 3: monologue
  132. # check for repeated MessageActions with source=AGENT
  133. # see if the agent is engaged in a good old monologue, telling itself the same thing over and over
  134. agent_message_actions = [
  135. (i, event)
  136. for i, event in enumerate(filtered_history)
  137. if isinstance(event, MessageAction) and event.source == EventSource.AGENT
  138. ]
  139. # last three message actions will do for this check
  140. if len(agent_message_actions) >= 3:
  141. last_agent_message_actions = agent_message_actions[-3:]
  142. if all(
  143. (last_agent_message_actions[0][1] == action[1])
  144. for action in last_agent_message_actions
  145. ):
  146. # check if there are any observations between the repeated MessageActions
  147. # then it's not yet a loop, maybe it can recover
  148. start_index = last_agent_message_actions[0][0]
  149. end_index = last_agent_message_actions[-1][0]
  150. has_observation_between = False
  151. for event in filtered_history[start_index + 1 : end_index]:
  152. if isinstance(event, Observation):
  153. has_observation_between = True
  154. break
  155. if not has_observation_between:
  156. logger.warning('Repeated MessageAction with source=AGENT detected')
  157. return True
  158. return False
  159. def _is_stuck_action_observation_pattern(self, filtered_history):
  160. # scenario 4: action, observation pattern on the last six steps
  161. # check if the agent repeats the same (Action, Observation)
  162. # every other step in the last six steps
  163. last_six_actions: list[Event] = []
  164. last_six_observations: list[Event] = []
  165. # the end of history is most interesting
  166. for event in reversed(filtered_history):
  167. if isinstance(event, Action) and len(last_six_actions) < 6:
  168. last_six_actions.append(event)
  169. elif isinstance(event, Observation) and len(last_six_observations) < 6:
  170. last_six_observations.append(event)
  171. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  172. break
  173. # this pattern is every other step, like:
  174. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  175. if len(last_six_actions) == 6 and len(last_six_observations) == 6:
  176. actions_equal = (
  177. # action_0 == action_2 == action_4
  178. self._eq_no_pid(last_six_actions[0], last_six_actions[2])
  179. and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
  180. # action_1 == action_3 == action_5
  181. and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
  182. and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
  183. )
  184. observations_equal = (
  185. # obs_0 == obs_2 == obs_4
  186. self._eq_no_pid(last_six_observations[0], last_six_observations[2])
  187. and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
  188. # obs_1 == obs_3 == obs_5
  189. and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
  190. and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
  191. )
  192. if actions_equal and observations_equal:
  193. logger.warning('Action, Observation pattern detected')
  194. return True
  195. return False
  196. def _eq_no_pid(self, obj1, obj2):
  197. if isinstance(obj1, CmdOutputObservation) and isinstance(
  198. obj2, CmdOutputObservation
  199. ):
  200. # for loop detection, ignore command_id, which is the pid
  201. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  202. else:
  203. # this is the default comparison
  204. return obj1 == obj2