Browse Source

Fix/update controller is_stuck() (#1891)

* Refactor monologue to use the messages in state history

remove now unused method

* is_stuck update

* fix is_stuck

* unit tests

* fix tests

* Revert "Refactor monologue to use the messages in state history"

This reverts commit 76b4b765ef31f2cb116184827bb69ff1b7100e80.

* Override eq for CmdOutputObservation to ignore the pid, compare the actual command only

* Revert "Override eq for CmdOutputObservation to ignore the pid, compare the actual command only"

This reverts commit 6418d856b565c72f900ef447e7595869894dc2a3.
Engel Nyst 1 year ago
parent
commit
1e51bb9276
2 changed files with 255 additions and 20 deletions
  1. 31 20
      opendevin/controller/agent_controller.py
  2. 224 0
      tests/unit/test_is_stuck.py

+ 31 - 20
opendevin/controller/agent_controller.py

@@ -244,32 +244,43 @@ class AgentController:
         # check if delegate stuck
         if self.delegate and self.delegate._is_stuck():
             return True
-        if len(self.state.history) < 3:
+
+        # filter out MessageAction with source='user' from history
+        filtered_history = [
+            _tuple
+            for _tuple in self.state.history
+            if not (
+                isinstance(_tuple[0], MessageAction)
+                and _tuple[0].source == EventSource.USER
+            )
+        ]
+
+        if len(filtered_history) < 4:
             return False
 
-        # if the last three (Action, Observation) tuples are too repetitive
-        # the agent got stuck in a loop
-        if all(
-            [
-                self.state.history[-i][0] == self.state.history[-3][0]
-                for i in range(1, 3)
-            ]
-        ):
-            # it repeats same action, give it a chance, but not if:
+        # Check if the last four (Action, Observation) tuples are too repetitive
+        last_four_tuples = filtered_history[-4:]
+        if all(last_four_tuples[-1] == _tuple for _tuple in last_four_tuples):
+            logger.warning('Action, Observation loop detected')
+            return True
+
+        if all(last_four_tuples[-1][0] == _tuple[0] for _tuple in last_four_tuples):
+            # It repeats the same action, give it a chance, but not if:
             if all(
-                isinstance(self.state.history[-i][1], NullObservation)
-                for i in range(1, 4)
+                isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
             ):
-                # same (Action, NullObservation): like 'think' the same thought over and over
-                logger.warning('Action, NullObservation loop detected')
+                logger.warning('Action, ErrorObservation loop detected')
                 return True
-            elif all(
-                isinstance(self.state.history[-i][1], ErrorObservation)
-                for i in range(1, 4)
+
+        # check if the agent repeats the same (Action, Observation)
+        # every other step in the last six tuples
+        if len(filtered_history) >= 6:
+            last_six_tuples = filtered_history[-6:]
+            if (
+                last_six_tuples[-1] == last_six_tuples[-3] == last_six_tuples[-5]
+                and last_six_tuples[-2] == last_six_tuples[-4] == last_six_tuples[-6]
             ):
-                # (NullAction, ErrorObservation): errors coming from an exception
-                # (Action, ErrorObservation): the same action getting an error, even if not necessarily the same error
-                logger.warning('Action, ErrorObservation loop detected')
+                logger.warning('Action, Observation pattern detected')
                 return True
 
         return False

+ 224 - 0
tests/unit/test_is_stuck.py

@@ -0,0 +1,224 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from opendevin.controller.agent_controller import AgentController
+from opendevin.events.action import CmdRunAction, FileReadAction, MessageAction
+from opendevin.events.observation import (
+    CmdOutputObservation,
+    FileReadObservation,
+    Observation,
+)
+from opendevin.events.observation.empty import NullObservation
+from opendevin.events.observation.error import ErrorObservation
+from opendevin.events.stream import EventSource
+
+
+class TestAgentController:
+    @pytest.fixture
+    def controller(self):
+        controller = Mock(spec=AgentController)
+        controller._is_stuck = AgentController._is_stuck.__get__(
+            controller, AgentController
+        )
+        controller.delegate = None
+        controller.state = Mock()
+        controller.state.history = []
+        return controller
+
+    def test_history_too_short(self, controller):
+        controller.state.history = [
+            (
+                MessageAction(content='Hello', wait_for_response=False),
+                Observation(content='Response 1'),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+        ]
+        assert controller._is_stuck() is False
+
+    def test_is_stuck_repeating_action_null_observation(self, controller):
+        # message actions with source USER are not considered in the stuck check
+        message_action = MessageAction(content='Done', wait_for_response=False)
+        message_action._source = EventSource.USER
+        controller.state.history = [
+            (
+                MessageAction(content='Hello', wait_for_response=False),
+                Observation(content='Response 1'),
+            ),
+            (CmdRunAction(command='ls'), NullObservation(content='')),
+            (CmdRunAction(command='ls'), NullObservation(content='')),
+            # user message shouldn't break detection
+            (message_action, NullObservation(content='')),
+            (CmdRunAction(command='ls'), NullObservation(content='')),
+            (CmdRunAction(command='ls'), NullObservation(content='')),
+        ]
+        with patch('logging.Logger.warning') as mock_warning:
+            assert controller._is_stuck() is True
+            mock_warning.assert_called_once_with('Action, Observation loop detected')
+
+    def test_is_stuck_repeating_action_error_observation(self, controller):
+        message_action = MessageAction(content='Done', wait_for_response=False)
+        message_action._source = EventSource.USER
+        controller.state.history = [
+            (
+                MessageAction(content='Hello', wait_for_response=False),
+                Observation(content='Response 1'),
+            ),
+            (
+                CmdRunAction(command='invalid_command'),
+                ErrorObservation(content='Command not found'),
+            ),
+            (
+                CmdRunAction(command='invalid_command'),
+                ErrorObservation(content='Command not found'),
+            ),
+            # user message shouldn't break detection
+            (message_action, NullObservation(content='')),
+            (
+                CmdRunAction(command='invalid_command'),
+                ErrorObservation(content='Different error'),
+            ),
+            (
+                CmdRunAction(command='invalid_command'),
+                ErrorObservation(content='Command not found'),
+            ),
+        ]
+        with patch('logging.Logger.warning') as mock_warning:
+            assert controller._is_stuck() is True
+            mock_warning.assert_called_once_with(
+                'Action, ErrorObservation loop detected'
+            )
+
+    def test_is_stuck_repeating_action_observation_pattern(self, controller):
+        # six tuples of action, observation
+        message_action = MessageAction(content='Come on', wait_for_response=False)
+        message_action._source = EventSource.USER
+        controller.state.history = [
+            (
+                message_action,
+                Observation(content=''),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            (
+                FileReadAction(path='file1.txt'),
+                FileReadObservation(content='File content', path='file1.txt'),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            (
+                FileReadAction(path='file1.txt'),
+                FileReadObservation(content='File content', path='file1.txt'),
+            ),
+            # insert a message just because they can, it shouldn't break detection
+            (message_action, NullObservation(content='')),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            (
+                FileReadAction(path='file1.txt'),
+                FileReadObservation(content='File content', path='file1.txt'),
+            ),
+        ]
+        with patch('logging.Logger.warning') as mock_warning:
+            assert controller._is_stuck() is True
+            mock_warning.assert_called_once_with('Action, Observation pattern detected')
+
+    def test_is_stuck_not_stuck(self, controller):
+        message_action = MessageAction(content='Done', wait_for_response=False)
+        message_action._source = EventSource.USER
+        controller.state.history = [
+            (
+                MessageAction(content='Hello', wait_for_response=False),
+                Observation(content='Response 1'),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            (
+                FileReadAction(path='file1.txt'),
+                FileReadObservation(content='File content', path='file1.txt'),
+            ),
+            (
+                CmdRunAction(command='pwd'),
+                CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
+            ),
+            (
+                FileReadAction(path='file2.txt'),
+                Observation(content='Another file content'),
+            ),
+            # insert a message from the user
+            (message_action, NullObservation(content='')),
+            (
+                CmdRunAction(command='pwd'),
+                CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
+            ),
+            (
+                FileReadAction(path='file2.txt'),
+                Observation(content='Another file content'),
+            ),
+        ]
+        assert controller._is_stuck() is False
+
+    def test_is_stuck_four_identical_tuples(self, controller):
+        message_action = MessageAction(content='Done', wait_for_response=False)
+        message_action._source = EventSource.USER
+        controller.state.history = [
+            (
+                MessageAction(content='Hello', wait_for_response=False),
+                Observation(content='Response 1'),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            # message from the user
+            (message_action, NullObservation(content='')),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+            (
+                CmdRunAction(command='ls'),
+                CmdOutputObservation(
+                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+        ]
+        with patch('logging.Logger.warning') as mock_warning:
+            assert controller._is_stuck() is True
+            mock_warning.assert_called_once_with('Action, Observation loop detected')
+
+    def test_is_stuck_delegate_stuck(self, controller):
+        controller.delegate = Mock()
+        controller.delegate._is_stuck.return_value = True
+        assert controller._is_stuck() is True