Ver código fonte

(fix) StuckDetector: syntax error loops were not detected (#3663)

Co-authored-by: mamoodi <mamoodiha@gmail.com>
tobitege 1 ano atrás
pai
commit
6111f530c2
2 arquivos alterados com 84 adições e 32 exclusões
  1. 36 16
      openhands/controller/stuck.py
  2. 48 16
      tests/unit/test_is_stuck.py

+ 36 - 16
openhands/controller/stuck.py

@@ -1,6 +1,7 @@
 from openhands.controller.state.state import State
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.action.action import Action
+from openhands.events.action.commands import IPythonRunCellAction
 from openhands.events.action.empty import NullAction
 from openhands.events.action.message import MessageAction
 from openhands.events.event import Event, EventSource
@@ -150,15 +151,14 @@ class StuckDetector:
                         ):
                             logger.warning(warning)
                             return True
-                    elif error_message in [
+                    elif error_message in (
                         'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
                         'SyntaxError: incomplete input',
-                    ]:
-                        if self._check_for_consistent_invalid_syntax(
-                            last_observations[:3], error_message
-                        ):
-                            logger.warning(warning)
-                            return True
+                    ) and self._check_for_consistent_invalid_syntax(
+                        last_observations[:3], error_message
+                    ):
+                        logger.warning(warning)
+                        return True
         return False
 
     def _check_for_consistent_invalid_syntax(self, observations, error_message):
@@ -169,18 +169,22 @@ class StuckDetector:
             content = obs.content
             lines = content.strip().split('\n')
 
-            if len(lines) < 4:
+            if len(lines) < 6:  # 6 because a real syntax error has at least 6 lines
                 return False
 
-            first_lines.append(lines[0])  # Store the first line of each observation
+            line1 = lines[0].strip()
+            if not line1.startswith('Cell In[1], line'):
+                return False
+
+            first_lines.append(line1)  # Store the first line of each observation
 
             # Check last three lines
-            if lines[-2].startswith('[Jupyter current working directory:') and lines[
-                -1
-            ].startswith('[Jupyter Python interpreter:'):
-                if error_message in lines[-3]:
-                    valid_observations.append(obs)
-                    break
+            if (
+                lines[-1].startswith('[Jupyter Python interpreter:')
+                and lines[-2].startswith('[Jupyter current working directory:')
+                and error_message in lines[-3]
+            ):
+                valid_observations.append(obs)
 
         # Check if:
         # 1. All first lines are identical
@@ -302,7 +306,23 @@ class StuckDetector:
         return False
 
     def _eq_no_pid(self, obj1, obj2):
-        if isinstance(obj1, CmdOutputObservation) and isinstance(
+        if isinstance(obj1, IPythonRunCellAction) and isinstance(
+            obj2, IPythonRunCellAction
+        ):
+            # for loop detection on edit actions, ignore the thought, compare some code
+            # the code should have at least 3 lines, to avoid simple one-liners
+            if (
+                'edit_file_by_replace(' in obj1.code
+                and 'edit_file_by_replace(' in obj2.code
+            ):
+                return (
+                    len(obj1.code.split('\n')) > 2
+                    and obj1.code.split('\n')[:3] == obj2.code.split('\n')[:3]
+                )
+            else:
+                # default comparison
+                return obj1 == obj2
+        elif isinstance(obj1, CmdOutputObservation) and isinstance(
             obj2, CmdOutputObservation
         ):
             # for loop detection, ignore command_id, which is the pid

+ 48 - 16
tests/unit/test_is_stuck.py

@@ -1,5 +1,4 @@
 import logging
-import random
 from unittest.mock import Mock, patch
 
 import pytest
@@ -30,6 +29,13 @@ logging.basicConfig(level=logging.DEBUG)
 
 jupyter_line_1 = '\n[Jupyter current working directory:'
 jupyter_line_2 = '\n[Jupyter Python interpreter:'
+code_snippet = """
+edit_file_by_replace(
+    'book_store.py',
+    to_replace=\"""def total(basket):
+    if not basket:
+        return 0
+"""
 
 
 @pytest.fixture
@@ -55,35 +61,38 @@ class TestStuckDetector:
         return StuckDetector(state)
 
     def _impl_syntax_error_events(
-        self, event_stream: EventStream, error_message: str, random_line: bool
+        self,
+        event_stream: EventStream,
+        error_message: str,
+        random_line: bool,
+        incidents: int = 4,
     ):
-        for _ in range(4):
-            ipython_action = IPythonRunCellAction(code='print("hello')
+        for i in range(incidents):
+            ipython_action = IPythonRunCellAction(code=code_snippet)
             event_stream.add_event(ipython_action, EventSource.AGENT)
-            extra_number = random.randint(88, 222) if random_line else '42'
-            extra_line = '\n' * random.randint(2, 3) if random_line else ''
+            extra_number = (i + 1) * 10 if random_line else '42'
+            extra_line = '\n' * (i + 1) if random_line else ''
             ipython_observation = IPythonRunCellObservation(
                 content=f'  Cell In[1], line {extra_number}\n'
                 'to_replace="""def largest(min_factor, max_factor):\n            ^\n'
                 f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
-                code='print("hello',
+                code=code_snippet,
             )
-            print(ipython_observation.content)
             ipython_observation._cause = ipython_action._id
             event_stream.add_event(ipython_observation, EventSource.USER)
 
     def _impl_unterminated_string_error_events(
-        self, event_stream: EventStream, random_line: bool
+        self, event_stream: EventStream, random_line: bool, incidents: int = 4
     ):
-        for i in range(4):
-            ipython_action = IPythonRunCellAction(code='print("hello')
+        for i in range(incidents):
+            ipython_action = IPythonRunCellAction(code=code_snippet)
             event_stream.add_event(ipython_action, EventSource.AGENT)
-            line_number = i * 10 if random_line else '1'
+            line_number = (i + 1) * 10 if random_line else '1'
             ipython_observation = IPythonRunCellObservation(
                 content=f'print("  Cell In[1], line {line_number}\nhello\n       ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
                 + jupyter_line_1
                 + jupyter_line_2,
-                code='print("hello',
+                code=code_snippet,
             )
             ipython_observation._cause = ipython_action._id
             event_stream.add_event(ipython_observation, EventSource.USER)
@@ -251,7 +260,7 @@ class TestStuckDetector:
         with patch('logging.Logger.warning'):
             assert stuck_detector.is_stuck() is True
 
-    def test_is_not_stuck_invalid_syntax_error(
+    def test_is_not_stuck_invalid_syntax_error_random_lines(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
         self._impl_syntax_error_events(
@@ -263,6 +272,19 @@ class TestStuckDetector:
         with patch('logging.Logger.warning'):
             assert stuck_detector.is_stuck() is False
 
+    def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_syntax_error_events(
+            event_stream,
+            error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
+            random_line=True,
+            incidents=3,
+        )
+
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is False
+
     def test_is_stuck_incomplete_input_error(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
@@ -287,7 +309,7 @@ class TestStuckDetector:
         with patch('logging.Logger.warning'):
             assert stuck_detector.is_stuck() is False
 
-    def test_is_not_stuck_ipython_unterminated_string_error(
+    def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
         self._impl_unterminated_string_error_events(event_stream, random_line=True)
@@ -295,6 +317,16 @@ class TestStuckDetector:
         with patch('logging.Logger.warning'):
             assert stuck_detector.is_stuck() is False
 
+    def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
+        self, stuck_detector: StuckDetector, event_stream: EventStream
+    ):
+        self._impl_unterminated_string_error_events(
+            event_stream, random_line=False, incidents=3
+        )
+
+        with patch('logging.Logger.warning'):
+            assert stuck_detector.is_stuck() is False
+
     def test_is_stuck_ipython_unterminated_string_error(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
@@ -303,7 +335,7 @@ class TestStuckDetector:
         with patch('logging.Logger.warning'):
             assert stuck_detector.is_stuck() is True
 
-    def test_is_stuck_ipython_syntax_error_not_at_end(
+    def test_is_not_stuck_ipython_syntax_error_not_at_end(
         self, stuck_detector: StuckDetector, event_stream: EventStream
     ):
         # this test is to make sure we don't get false positives