Forráskód Böngészése

Ignore pid for loop detection (Was: override eq...) (#2045)

* rewrite, implement pid ignore in the controller

* make the helper method private
Engel Nyst 1 éve
szülő
commit
783fea62a0
2 módosított fájl, 124 hozzáadás és 13 törlés
  1. 43 4
      opendevin/controller/agent_controller.py
  2. 81 9
      tests/unit/test_is_stuck.py

+ 43 - 4
opendevin/controller/agent_controller.py

@@ -24,6 +24,7 @@ from opendevin.events.action import (
     ModifyTaskAction,
     ModifyTaskAction,
     NullAction,
     NullAction,
 )
 )
+from opendevin.events.action.commands import CmdKillAction
 from opendevin.events.event import Event
 from opendevin.events.event import Event
 from opendevin.events.observation import (
 from opendevin.events.observation import (
     AgentDelegateObservation,
     AgentDelegateObservation,
@@ -271,13 +272,29 @@ class AgentController:
         if len(filtered_history) < 4:
         if len(filtered_history) < 4:
             return False
             return False
 
 
+        # FIXME rewrite this to be more readable
+
         # Check if the last four (Action, Observation) tuples are too repetitive
         # Check if the last four (Action, Observation) tuples are too repetitive
         last_four_tuples = filtered_history[-4:]
         last_four_tuples = filtered_history[-4:]
-        if all(last_four_tuples[-1] == _tuple for _tuple in last_four_tuples):
+
+        if all(
+            # (Action, Observation) tuples
+            # compare the last action to the last four actions
+            self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
+            for _tuple in last_four_tuples
+        ) and all(
+            # compare the last observation to the last four observations
+            self._eq_no_pid(last_four_tuples[-1][1], _tuple[1])
+            for _tuple in last_four_tuples
+        ):
             logger.warning('Action, Observation loop detected')
             logger.warning('Action, Observation loop detected')
             return True
             return True
 
 
-        if all(last_four_tuples[-1][0] == _tuple[0] for _tuple in last_four_tuples):
+        # (action, error) tuples
+        if all(
+            self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
+            for _tuple in last_four_tuples
+        ):
             # It repeats the same action, give it a chance, but not if:
             # It repeats the same action, give it a chance, but not if:
             if all(
             if all(
                 isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
                 isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
@@ -287,13 +304,35 @@ class AgentController:
 
 
         # check if the agent repeats the same (Action, Observation)
         # check if the agent repeats the same (Action, Observation)
         # every other step in the last six tuples
         # every other step in the last six tuples
+
         if len(filtered_history) >= 6:
         if len(filtered_history) >= 6:
             last_six_tuples = filtered_history[-6:]
             last_six_tuples = filtered_history[-6:]
             if (
             if (
-                last_six_tuples[-1] == last_six_tuples[-3] == last_six_tuples[-5]
-                and last_six_tuples[-2] == last_six_tuples[-4] == last_six_tuples[-6]
+                # this pattern is every other step, like:
+                # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
+                self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-3][0])
+                and self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-5][0])
+                and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-4][0])
+                and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-6][0])
+                and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-3][1])
+                and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-5][1])
+                and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-4][1])
+                and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-6][1])
             ):
             ):
                 logger.warning('Action, Observation pattern detected')
                 logger.warning('Action, Observation pattern detected')
                 return True
                 return True
 
 
         return False
         return False
+
+    def _eq_no_pid(self, obj1, obj2):
+        if isinstance(obj1, CmdOutputObservation) and isinstance(
+            obj2, CmdOutputObservation
+        ):
+            # for loop detection, ignore command_id, which is the pid
+            return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
+        elif isinstance(obj1, CmdKillAction) and isinstance(obj2, CmdKillAction):
+            # for loop detection, ignore command_id, which is the pid
+            return obj1.thought == obj2.thought
+        else:
+            # this is the default comparison
+            return obj1 == obj2

+ 81 - 9
tests/unit/test_is_stuck.py

@@ -4,6 +4,7 @@ import pytest
 
 
 from opendevin.controller.agent_controller import AgentController
 from opendevin.controller.agent_controller import AgentController
 from opendevin.events.action import CmdRunAction, FileReadAction, MessageAction
 from opendevin.events.action import CmdRunAction, FileReadAction, MessageAction
+from opendevin.events.action.commands import CmdKillAction
 from opendevin.events.observation import (
 from opendevin.events.observation import (
     CmdOutputObservation,
     CmdOutputObservation,
     FileReadObservation,
     FileReadObservation,
@@ -21,6 +22,9 @@ class TestAgentController:
         controller._is_stuck = AgentController._is_stuck.__get__(
         controller._is_stuck = AgentController._is_stuck.__get__(
             controller, AgentController
             controller, AgentController
         )
         )
+        controller._eq_no_pid = AgentController._eq_no_pid.__get__(
+            controller, AgentController
+        )
         controller.delegate = None
         controller.delegate = None
         controller.state = Mock()
         controller.state = Mock()
         controller.state.history = []
         controller.state.history = []
@@ -75,7 +79,7 @@ class TestAgentController:
             ),
             ),
             (
             (
                 CmdRunAction(command='invalid_command'),
                 CmdRunAction(command='invalid_command'),
-                ErrorObservation(content='Command not found'),
+                ErrorObservation(content='Command still not found or another error'),
             ),
             ),
             # user message shouldn't break detection
             # user message shouldn't break detection
             (message_action, NullObservation(content='')),
             (message_action, NullObservation(content='')),
@@ -115,8 +119,9 @@ class TestAgentController:
             ),
             ),
             (
             (
                 CmdRunAction(command='ls'),
                 CmdRunAction(command='ls'),
+                # command_id is ignored for the eq check, it's a pid
                 CmdOutputObservation(
                 CmdOutputObservation(
-                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                    command_id=2, command='ls', content='file1.txt\nfile2.txt'
                 ),
                 ),
             ),
             ),
             (
             (
@@ -128,7 +133,7 @@ class TestAgentController:
             (
             (
                 CmdRunAction(command='ls'),
                 CmdRunAction(command='ls'),
                 CmdOutputObservation(
                 CmdOutputObservation(
-                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                    command_id=3, command='ls', content='file1.txt\nfile2.txt'
                 ),
                 ),
             ),
             ),
             (
             (
@@ -160,7 +165,8 @@ class TestAgentController:
             ),
             ),
             (
             (
                 CmdRunAction(command='pwd'),
                 CmdRunAction(command='pwd'),
-                CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
+                # command_id is ignored for the eq check, it's the pid
+                CmdOutputObservation(command_id=2, command='pwd', content='/home/user'),
             ),
             ),
             (
             (
                 FileReadAction(path='file2.txt'),
                 FileReadAction(path='file2.txt'),
@@ -170,7 +176,7 @@ class TestAgentController:
             (message_action, NullObservation(content='')),
             (message_action, NullObservation(content='')),
             (
             (
                 CmdRunAction(command='pwd'),
                 CmdRunAction(command='pwd'),
-                CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
+                CmdOutputObservation(command_id=3, command='pwd', content='/home/user'),
             ),
             ),
             (
             (
                 FileReadAction(path='file2.txt'),
                 FileReadAction(path='file2.txt'),
@@ -195,22 +201,88 @@ class TestAgentController:
             ),
             ),
             (
             (
                 CmdRunAction(command='ls'),
                 CmdRunAction(command='ls'),
+                # command_id is ignored for the eq check, it's just the pid
                 CmdOutputObservation(
                 CmdOutputObservation(
-                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                    command_id=2, command='ls', content='file1.txt\nfile2.txt'
                 ),
                 ),
             ),
             ),
-            # message from the user
+            # message from the user shouldn't interfere with the detection
             (message_action, NullObservation(content='')),
             (message_action, NullObservation(content='')),
             (
             (
                 CmdRunAction(command='ls'),
                 CmdRunAction(command='ls'),
                 CmdOutputObservation(
                 CmdOutputObservation(
-                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                    command_id=3, command='ls', content='file1.txt\nfile2.txt'
                 ),
                 ),
             ),
             ),
             (
             (
                 CmdRunAction(command='ls'),
                 CmdRunAction(command='ls'),
                 CmdOutputObservation(
                 CmdOutputObservation(
-                    command_id=1, command='ls', content='file1.txt\nfile2.txt'
+                    command_id=4, command='ls', content='file1.txt\nfile2.txt'
+                ),
+            ),
+        ]
+        with patch('logging.Logger.warning') as mock_warning:
+            assert controller._is_stuck() is True
+            mock_warning.assert_called_once_with('Action, Observation loop detected')
+
+    def test_is_stuck_four_tuples_cmd_kill_and_output(self, controller):
+        message_action = MessageAction(content='Done', wait_for_response=False)
+        message_action._source = EventSource.USER
+        controller.state.history = [
+            (
+                MessageAction(content='Hello', wait_for_response=False),
+                Observation(content='Response 1'),
+            ),
+            (
+                CmdKillAction(
+                    command_id=1,
+                    thought='It looks like storybook is stuck, lets kill it',
+                ),
+                CmdOutputObservation(
+                    content='Background command storybook has been killed.',
+                    command_id=1,
+                    command='storybook',
+                    exit_code=0,
+                ),
+            ),
+            (
+                # command_id is ignored for the eq check, it's the pid
+                CmdKillAction(
+                    command_id=2,
+                    thought='It looks like storybook is stuck, lets kill it',
+                ),
+                # command_id here too
+                CmdOutputObservation(
+                    content='Background command storybook has been killed.',
+                    command_id=2,
+                    command='storybook',
+                    exit_code=0,
+                ),
+            ),
+            # message from the user, shouldn't be counted
+            (message_action, NullObservation(content='')),
+            (
+                CmdKillAction(
+                    command_id=3,
+                    thought='It looks like storybook is stuck, lets kill it',
+                ),
+                CmdOutputObservation(
+                    content='Background command storybook has been killed.',
+                    command_id=3,
+                    command='storybook',
+                    exit_code=0,
+                ),
+            ),
+            (
+                CmdKillAction(
+                    command_id=4,
+                    thought='It looks like storybook is stuck, lets kill it',
+                ),
+                CmdOutputObservation(
+                    content='Background command storybook has been killed.',
+                    command_id=4,
+                    command='storybook',
+                    exit_code=0,
                 ),
                 ),
             ),
             ),
         ]
         ]