浏览代码

(enh) StuckDetector: fix+enhance syntax error loop detection (#3628)

* fix StuckDetector and add more errors for detection

* more stringent error detection and more unit tests
tobitege 1 年之前
父节点
当前提交
a2d94c9cb1
共有 2 个文件被更改，包括 192 次插入和 87 次删除
  1. 100 25
      openhands/controller/stuck.py
  2. 92 62
      tests/unit/test_is_stuck.py

+ 100 - 25
openhands/controller/stuck.py

@@ -1,5 +1,3 @@
-from typing import cast
-
 from openhands.controller.state.state import State
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.action.action import Action
@@ -16,6 +14,12 @@ from openhands.events.observation.observation import Observation
 
 
 class StuckDetector:
+    SYNTAX_ERROR_MESSAGES = [
+        'SyntaxError: unterminated string literal (detected at line',
+        'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
+        'SyntaxError: incomplete input',
+    ]
+
     def __init__(self, state: State):
         self.state = state
 
@@ -119,36 +123,107 @@ class StuckDetector:
 
     def _is_stuck_repeating_action_error(self, last_actions, last_observations):
         # scenario 2: same action, errors
-        # it takes 4 actions and 4 observations to detect a loop
-        # check if the last four actions are the same and result in errors
+        # it takes 3 actions and 3 observations to detect a loop
+        # check if the last three actions are the same and result in errors
 
-        # are the last four actions the same?
-        if len(last_actions) == 4 and all(
-            self._eq_no_pid(last_actions[0], action) for action in last_actions
-        ):
-            # and the last four observations all errors?
-            if all(isinstance(obs, ErrorObservation) for obs in last_observations):
+        if len(last_actions) < 4 or len(last_observations) < 4:
+            return False
+
+        # are the last three actions the "same"?
+        if all(self._eq_no_pid(last_actions[0], action) for action in last_actions[:3]):
+            # and the last three observations are all errors?
+            if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
                 logger.warning('Action, ErrorObservation loop detected')
                 return True
-            # or, are the last four observations all IPythonRunCellObservation with SyntaxError?
+            # or, are the last three observations all IPythonRunCellObservation with SyntaxError?
             elif all(
-                isinstance(obs, IPythonRunCellObservation) for obs in last_observations
-            ) and all(
-                cast(IPythonRunCellObservation, obs)
-                .content[-100:]
-                .find('SyntaxError: unterminated string literal (detected at line')
-                != -1
-                and len(
-                    cast(IPythonRunCellObservation, obs).content.split(
+                isinstance(obs, IPythonRunCellObservation)
+                for obs in last_observations[:3]
+            ):
+                warning = 'Action, IPythonRunCellObservation loop detected'
+                for error_message in self.SYNTAX_ERROR_MESSAGES:
+                    if error_message.startswith(
                         'SyntaxError: unterminated string literal (detected at line'
-                    )[-1]
+                    ):
+                        if self._check_for_consistent_line_error(
+                            last_observations[:3], error_message
+                        ):
+                            logger.warning(warning)
+                            return True
+                    elif error_message in [
+                        'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
+                        'SyntaxError: incomplete input',
+                    ]:
+                        if self._check_for_consistent_invalid_syntax(
+                            last_observations[:3], error_message
+                        ):
+                            logger.warning(warning)
+                            return True
+        return False
+
+    def _check_for_consistent_invalid_syntax(self, observations, error_message):
+        """True if all 3 observations share first line and final error line."""
+        first_lines = []
+        valid_observations = []
+
+        for obs in observations:
+            lines = obs.content.strip().split('\n')
+            if len(lines) < 4:
+                return False
+
+            first_lines.append(lines[0])  # Store the first line of each observation
+
+            # The Jupyter footer must directly follow the error line. Do not break
+            # after a match: all 3 observations must be counted for the check below.
+            if (
+                lines[-2].startswith('[Jupyter current working directory:')
+                and lines[-1].startswith('[Jupyter Python interpreter:')
+                and error_message in lines[-3]
+            ):
+                valid_observations.append(obs)
+
+        # Check if:
+        # 1. All first lines are identical
+        # 2. We have exactly 3 valid observations
+        # 3. The error message line is identical in all valid observations
+        return (
+            len(set(first_lines)) == 1
+            and len(valid_observations) == 3
+            and len(
+                set(
+                    obs.content.strip().split('\n')[:-2][-1]
+                    for obs in valid_observations
                 )
-                < 10
-                for obs in last_observations
+            )
+            == 1
+        )
+
+    def _check_for_consistent_line_error(self, observations, error_message):
+        error_lines = []
+
+        for obs in observations:
+            content = obs.content
+            lines = content.strip().split('\n')
+
+            if len(lines) < 3:
+                return False
+
+            last_lines = lines[-3:]
+
+            # Check if the last two lines are our own
+            if not (
+                last_lines[-2].startswith('[Jupyter current working directory:')
+                and last_lines[-1].startswith('[Jupyter Python interpreter:')
             ):
-                logger.warning('Action, IPythonRunCellObservation loop detected')
-                return True
-        return False
+                return False
+
+            # Check for the error message in the 3rd-to-last line
+            if error_message in last_lines[-3]:
+                error_lines.append(last_lines[-3])
+
+        # Check if we found the error message in all 3 observations
+        # and the 3rd-to-last line is identical across all occurrences
+        return len(error_lines) == 3 and len(set(error_lines)) == 1
 
     def _is_stuck_monologue(self, filtered_history):
         # scenario 3: monologue

+ 92 - 62
tests/unit/test_is_stuck.py

@@ -1,4 +1,5 @@
 import logging
+import random
 from unittest.mock import Mock, patch
 
 import pytest
@@ -27,6 +28,9 @@ def collect_events(stream):
 
 logging.basicConfig(level=logging.DEBUG)
 
+jupyter_line_1 = '\n[Jupyter current working directory:'
+jupyter_line_2 = '\n[Jupyter Python interpreter:'
+
 
 @pytest.fixture
 def temp_dir(tmp_path_factory: TempPathFactory) -> str:
@@ -46,11 +50,44 @@ class TestStuckDetector:
     @pytest.fixture
     def stuck_detector(self, event_stream):
         state = State(inputs={}, max_iterations=50)
-        # state.history = ShortTermHistory()
         state.history.set_event_stream(event_stream)
 
         return StuckDetector(state)
 
+    def _impl_syntax_error_events(
+        self, event_stream: EventStream, error_message: str, random_line: bool
+    ):
+        for _ in range(4):
+            ipython_action = IPythonRunCellAction(code='print("hello')
+            event_stream.add_event(ipython_action, EventSource.AGENT)
+            extra_number = random.randint(88, 222) if random_line else '42'
+            extra_line = '\n' * random.randint(2, 3) if random_line else ''
+            ipython_observation = IPythonRunCellObservation(
+                content=f'  Cell In[1], line {extra_number}\n'
+                'to_replace="""def largest(min_factor, max_factor):\n            ^\n'
+                f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
+                code='print("hello',
+            )
+            print(ipython_observation.content)
+            ipython_observation._cause = ipython_action._id
+            event_stream.add_event(ipython_observation, EventSource.USER)
+
+    def _impl_unterminated_string_error_events(
+        self, event_stream: EventStream, random_line: bool
+    ):
+        for _ in range(4):
+            ipython_action = IPythonRunCellAction(code='print("hello')
+            event_stream.add_event(ipython_action, EventSource.AGENT)
+            line_number = str(random.randint(1, 10)) if random_line else '1'
+            ipython_observation = IPythonRunCellObservation(
+                content=f'print("hello\n       ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
+                + jupyter_line_1
+                + jupyter_line_2,
+                code='print("hello',
+            )
+            ipython_observation._cause = ipython_action._id
+            event_stream.add_event(ipython_observation, EventSource.USER)
+
     def test_history_too_short(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
@@ -202,81 +239,75 @@ class TestStuckDetector:
                 'Action, ErrorObservation loop detected'
             )
 
-    def test_is_stuck_ipython_syntax_error(
+    def test_is_stuck_invalid_syntax_error(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
-        ipython_action_1 = IPythonRunCellAction(code='print("hello')
-        event_stream.add_event(ipython_action_1, EventSource.AGENT)
-        ipython_observation_1 = IPythonRunCellObservation(
-            content='print("hello\n       ^\nSyntaxError: unterminated string literal (detected at line 1)',
-            code='print("hello',
+        self._impl_syntax_error_events(
+            event_stream,
+            error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
+            random_line=False,
         )
-        ipython_observation_1._cause = ipython_action_1._id
-        event_stream.add_event(ipython_observation_1, EventSource.USER)
 
-        ipython_action_2 = IPythonRunCellAction(code='print("hello')
-        event_stream.add_event(ipython_action_2, EventSource.AGENT)
-        ipython_observation_2 = IPythonRunCellObservation(
-            content='print("hello\n       ^\nSyntaxError: unterminated string literal (detected at line 1)',
-            code='print("hello',
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is True
+
+    def test_is_not_stuck_invalid_syntax_error(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_syntax_error_events(
+            event_stream,
+            error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
+            random_line=True,
         )
-        ipython_observation_2._cause = ipython_action_2._id
-        event_stream.add_event(ipython_observation_2, EventSource.USER)
 
-        ipython_action_3 = IPythonRunCellAction(code='print("hello')
-        event_stream.add_event(ipython_action_3, EventSource.AGENT)
-        ipython_observation_3 = IPythonRunCellObservation(
-            content='print("hello\n       ^\nSyntaxError: unterminated string literal (detected at line 3)',
-            code='print("hello',
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is False
+
+    def test_is_stuck_incomplete_input_error(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_syntax_error_events(
+            event_stream,
+            error_message='SyntaxError: incomplete input',
+            random_line=False,
         )
-        ipython_observation_3._cause = ipython_action_3._id
-        event_stream.add_event(ipython_observation_3, EventSource.USER)
 
-        ipython_action_4 = IPythonRunCellAction(code='print("hello')
-        event_stream.add_event(ipython_action_4, EventSource.AGENT)
-        ipython_observation_4 = IPythonRunCellObservation(
-            content='print("hello\n       ^\nSyntaxError: unterminated string literal (detected at line 2)',
-            code='print("hello',
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is True
+
+    def test_is_not_stuck_incomplete_input_error(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_syntax_error_events(
+            event_stream,
+            error_message='SyntaxError: incomplete input',
+            random_line=True,
         )
-        ipython_observation_4._cause = ipython_action_4._id
-        event_stream.add_event(ipython_observation_4, EventSource.USER)
 
-        # stuck_detector.state.history.set_event_stream(event_stream)
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is False
 
-        last_observations = [
-            ipython_observation_1,
-            ipython_observation_2,
-            ipython_observation_3,
-            ipython_observation_4,
-        ]
-        for observation in last_observations:
-            has_string = (
-                observation.content[-100:].find(
-                    'SyntaxError: unterminated string literal (detected at line'
-                )
-                != -1
-            )
-            assert has_string
-
-            string_is_last = (
-                len(
-                    observation.content.split(
-                        'SyntaxError: unterminated string literal (detected at line'
-                    )[-1]
-                )
-                < 10
-            )
-            assert string_is_last
+    def test_is_not_stuck_ipython_unterminated_string_error(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_unterminated_string_error_events(event_stream, random_line=True)
 
-        with patch('logging.Logger.warning') as mock_warning:
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is False
+
+    def test_is_stuck_ipython_unterminated_string_error(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_unterminated_string_error_events(event_stream, random_line=False)
+
+        with patch('logging.Logger.warning'):
             assert stuck_detector.is_stuck() is True
-            mock_warning.assert_called_once_with(
-                'Action, IPythonRunCellObservation loop detected'
-            )
 
     def test_is_stuck_ipython_syntax_error_not_at_end(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
+        # this test is to make sure we don't get false positives
+        # since the "at line x" is changing in between!
         ipython_action_1 = IPythonRunCellAction(code='print("hello')
         event_stream.add_event(ipython_action_1, EventSource.AGENT)
         ipython_observation_1 = IPythonRunCellObservation(
@@ -480,8 +511,6 @@ class TestStuckDetector:
         event_stream.add_event(message_action_6, EventSource.AGENT)
         message_action_6._source = EventSource.AGENT
 
-        # stuck_detector.state.history.set_event_stream(event_stream)
-
         assert stuck_detector.is_stuck()
 
         # Add an observation event between the repeated message actions
@@ -502,7 +531,8 @@ class TestStuckDetector:
         event_stream.add_event(message_action_8, EventSource.AGENT)
         message_action_8._source = EventSource.AGENT
 
-        assert not stuck_detector.is_stuck()
+        with patch('logging.Logger.warning'):
+            assert not stuck_detector.is_stuck()
 
 
 class TestAgentController: