Procházet zdrojové kódy

refactor: move get_pairs from memory to shared utils (#4411)

Xingyao Wang před 1 rokem
rodič
revize
da23189e4c

+ 56 - 0
openhands/events/utils.py

@@ -0,0 +1,56 @@
+from openhands.core.logger import openhands_logger as logger
+from openhands.events.action.action import Action
+from openhands.events.action.empty import NullAction
+from openhands.events.event import Event
+from openhands.events.observation.commands import CmdOutputObservation
+from openhands.events.observation.empty import NullObservation
+from openhands.events.observation.observation import Observation
+
+
+def get_pairs_from_events(events: list[Event]) -> list[tuple[Action, Observation]]:
+    """Return the history as a list of tuples (action, observation)."""
+    tuples: list[tuple[Action, Observation]] = []
+    action_map: dict[int, Action] = {}
+    observation_map: dict[int, Observation] = {}
+
+    # runnable actions are set as cause of observations
+    # (MessageAction, NullObservation) for source=USER
+    # (MessageAction, NullObservation) for source=AGENT
+    # (other_action?, NullObservation)
+    # (NullAction, CmdOutputObservation) background CmdOutputObservations
+
+    for event in events:
+        if event.id is None or event.id == -1:
+            logger.debug(f'Event {event} has no ID')
+
+        if isinstance(event, Action):
+            action_map[event.id] = event
+
+        if isinstance(event, Observation):
+            if event.cause is None or event.cause == -1:
+                logger.debug(f'Observation {event} has no cause')
+
+            if event.cause is None:
+                # runnable actions are set as cause of observations
+                # NullObservations have no cause
+                continue
+
+            observation_map[event.cause] = event
+
+    for action_id, action in action_map.items():
+        observation = observation_map.get(action_id)
+        if observation:
+            # observation with a cause
+            tuples.append((action, observation))
+        else:
+            tuples.append((action, NullObservation('')))
+
+    for cause_id, observation in observation_map.items():
+        if cause_id not in action_map:
+            if isinstance(observation, NullObservation):
+                continue
+            if not isinstance(observation, CmdOutputObservation):
+                logger.debug(f'Observation {observation} has no cause')
+            tuples.append((NullAction(), observation))
+
+    return tuples.copy()

+ 4 - 50
openhands/memory/history.py

@@ -10,12 +10,12 @@ from openhands.events.action.empty import NullAction
 from openhands.events.action.message import MessageAction
 from openhands.events.event import Event, EventSource
 from openhands.events.observation.agent import AgentStateChangedObservation
-from openhands.events.observation.commands import CmdOutputObservation
 from openhands.events.observation.delegate import AgentDelegateObservation
 from openhands.events.observation.empty import NullObservation
 from openhands.events.observation.observation import Observation
 from openhands.events.serialization.event import event_to_dict
 from openhands.events.stream import EventStream
+from openhands.events.utils import get_pairs_from_events
 
 
 class ShortTermHistory(list[Event]):
@@ -216,55 +216,9 @@ class ShortTermHistory(list[Event]):
     def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
         history_pairs = []
 
-        for action, observation in self.get_pairs():
+        for action, observation in get_pairs_from_events(
+            self.get_events_as_list(include_delegates=True)
+        ):
             history_pairs.append((event_to_dict(action), event_to_dict(observation)))
 
         return history_pairs
-
-    def get_pairs(self) -> list[tuple[Action, Observation]]:
-        """Return the history as a list of tuples (action, observation)."""
-        tuples: list[tuple[Action, Observation]] = []
-        action_map: dict[int, Action] = {}
-        observation_map: dict[int, Observation] = {}
-
-        # runnable actions are set as cause of observations
-        # (MessageAction, NullObservation) for source=USER
-        # (MessageAction, NullObservation) for source=AGENT
-        # (other_action?, NullObservation)
-        # (NullAction, CmdOutputObservation) background CmdOutputObservations
-
-        for event in self.get_events_as_list(include_delegates=True):
-            if event.id is None or event.id == -1:
-                logger.debug(f'Event {event} has no ID')
-
-            if isinstance(event, Action):
-                action_map[event.id] = event
-
-            if isinstance(event, Observation):
-                if event.cause is None or event.cause == -1:
-                    logger.debug(f'Observation {event} has no cause')
-
-                if event.cause is None:
-                    # runnable actions are set as cause of observations
-                    # NullObservations have no cause
-                    continue
-
-                observation_map[event.cause] = event
-
-        for action_id, action in action_map.items():
-            observation = observation_map.get(action_id)
-            if observation:
-                # observation with a cause
-                tuples.append((action, observation))
-            else:
-                tuples.append((action, NullObservation('')))
-
-        for cause_id, observation in observation_map.items():
-            if cause_id not in action_map:
-                if isinstance(observation, NullObservation):
-                    continue
-                if not isinstance(observation, CmdOutputObservation):
-                    logger.debug(f'Observation {observation} has no cause')
-                tuples.append((NullAction(), observation))
-
-        return tuples.copy()

+ 21 - 2
tests/unit/test_is_stuck.py

@@ -17,6 +17,7 @@ from openhands.events.observation.commands import IPythonRunCellObservation
 from openhands.events.observation.empty import NullObservation
 from openhands.events.observation.error import ErrorObservation
 from openhands.events.stream import EventSource, EventStream
+from openhands.events.utils import get_pairs_from_events
 from openhands.memory.history import ShortTermHistory
 from openhands.storage import get_file_store
 
@@ -170,7 +171,16 @@ class TestStuckDetector:
 
         assert len(collect_events(event_stream)) == 10
         assert len(list(stuck_detector.state.history.get_events())) == 8
-        assert len(stuck_detector.state.history.get_pairs()) == 5
+        assert (
+            len(
+                get_pairs_from_events(
+                    stuck_detector.state.history.get_events_as_list(
+                        include_delegates=True
+                    )
+                )
+            )
+            == 5
+        )
 
         assert stuck_detector.is_stuck() is False
         assert stuck_detector.state.almost_stuck == 1
@@ -186,7 +196,16 @@ class TestStuckDetector:
 
         assert len(collect_events(event_stream)) == 12
         assert len(list(stuck_detector.state.history.get_events())) == 10
-        assert len(stuck_detector.state.history.get_pairs()) == 6
+        assert (
+            len(
+                get_pairs_from_events(
+                    stuck_detector.state.history.get_events_as_list(
+                        include_delegates=True
+                    )
+                )
+            )
+            == 6
+        )
 
         with patch('logging.Logger.warning') as mock_warning:
             assert stuck_detector.is_stuck() is True