Forráskód Böngészése

Context Window Exceeded fix (#4977)

Engel Nyst 1 éve
szülő
commit
8dee334236

+ 130 - 2
openhands/controller/agent_controller.py

@@ -5,6 +5,7 @@ import traceback
 from typing import Callable, ClassVar, Type
 
 import litellm
+from litellm.exceptions import ContextWindowExceededError
 
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State, TrafficControlState
@@ -485,6 +486,15 @@ class AgentController:
                 EventSource.AGENT,
             )
             return
+        except ContextWindowExceededError:
+            # When context window is exceeded, keep roughly half of agent interactions
+            self.state.history = self._apply_conversation_window(self.state.history)
+
+            # Save the ID of the first event in our truncated history for future reloading
+            if self.state.history:
+                self.state.start_id = self.state.history[0].id
+            # Don't add error event - let the agent retry with reduced context
+            return
 
         if action.runnable:
             if self.state.confirmation_mode and (
@@ -659,6 +669,12 @@ class AgentController:
         - For delegate events (between AgentDelegateAction and AgentDelegateObservation):
             - Excludes all events between the action and observation
             - Includes the delegate action and observation themselves
+
+        The history is loaded in two parts if truncation_id is set:
+        1. First user message from start_id onwards
+        2. Rest of history from truncation_id to the end
+
+        Otherwise loads normally from start_id.
         """
 
         # define range of events to fetch
@@ -680,8 +696,33 @@ class AgentController:
             self.state.history = []
             return
 
-        # Get all events, filtering out backend events and hidden events
-        events = list(
+        events: list[Event] = []
+
+        # If we have a truncation point, get first user message and then rest of history
+        if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
+            # Find first user message from stream
+            first_user_msg = next(
+                (
+                    e
+                    for e in self.event_stream.get_events(
+                        start_id=start_id,
+                        end_id=end_id,
+                        reverse=False,
+                        filter_out_type=self.filter_out,
+                        filter_hidden=True,
+                    )
+                    if isinstance(e, MessageAction) and e.source == EventSource.USER
+                ),
+                None,
+            )
+            if first_user_msg:
+                events.append(first_user_msg)
+
+            # the rest of the events are from the truncation point
+            start_id = self.state.truncation_id
+
+        # Get rest of history
+        events_to_add = list(
             self.event_stream.get_events(
                 start_id=start_id,
                 end_id=end_id,
@@ -690,6 +731,7 @@ class AgentController:
                 filter_hidden=True,
             )
         )
+        events.extend(events_to_add)
 
         # Find all delegate action/observation pairs
         delegate_ranges: list[tuple[int, int]] = []
@@ -744,6 +786,92 @@ class AgentController:
         # make sure history is in sync
         self.state.start_id = start_id
 
+    def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
+        """Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
+        and ensuring the first user message is always included.
+
+        The algorithm:
+        1. Cut history in half
+        2. Check first event in new history:
+           - If Observation: find and include its Action
+           - If MessageAction: ensure its related Action-Observation pair isn't split
+        3. Always include the first user message
+
+        Args:
+            events: List of events to filter
+
+        Returns:
+            Filtered list of events keeping newest half while preserving pairs
+        """
+        if not events:
+            return events
+
+        # Find first user message - we'll need to ensure it's included
+        first_user_msg = next(
+            (
+                e
+                for e in events
+                if isinstance(e, MessageAction) and e.source == EventSource.USER
+            ),
+            None,
+        )
+
+        # cut in half
+        mid_point = max(1, len(events) // 2)
+        kept_events = events[mid_point:]
+
+        # Handle first event in truncated history
+        if kept_events:
+            i = 0
+            while i < len(kept_events):
+                first_event = kept_events[i]
+                if isinstance(first_event, Observation) and first_event.cause:
+                    # Find its action and include it
+                    matching_action = next(
+                        (
+                            e
+                            for e in reversed(events[:mid_point])
+                            if isinstance(e, Action) and e.id == first_event.cause
+                        ),
+                        None,
+                    )
+                    if matching_action:
+                        kept_events = [matching_action] + kept_events
+                    else:
+                        self.log(
+                            'warning',
+                            f'Found Observation without matching Action at id={first_event.id}',
+                        )
+                        # drop this observation
+                        kept_events = kept_events[1:]
+                    break
+
+                elif isinstance(first_event, MessageAction) or (
+                    isinstance(first_event, Action)
+                    and first_event.source == EventSource.USER
+                ):
+                    # if it's a message action or a user action, keep it and continue to find the next event
+                    i += 1
+                    continue
+
+                else:
+                    # if it's an action with source == EventSource.AGENT, we're good
+                    break
+
+        # Save where to continue from in next reload
+        if kept_events:
+            self.state.truncation_id = kept_events[0].id
+
+        # Ensure first user message is included
+        if first_user_msg and first_user_msg not in kept_events:
+            kept_events = [first_user_msg] + kept_events
+
+        # start_id points to first user message
+        if first_user_msg:
+            self.state.start_id = first_user_msg.id
+
+        return kept_events
+
     def _is_stuck(self):
         """Checks if the agent or its delegate is stuck in a loop.
 

+ 2 - 0
openhands/controller/state/state.py

@@ -92,6 +92,8 @@ class State:
     # start_id and end_id track the range of events in history
     start_id: int = -1
     end_id: int = -1
+    # truncation_id tracks where to load history after context window truncation
+    truncation_id: int = -1
     almost_stuck: int = 0
     delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
     # NOTE: This will never be used by the controller, but it can be used by different

+ 188 - 0
tests/unit/test_truncation.py

@@ -0,0 +1,188 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from openhands.controller.agent_controller import AgentController
+from openhands.events import EventSource
+from openhands.events.action import CmdRunAction, MessageAction
+from openhands.events.observation import CmdOutputObservation
+
+
+@pytest.fixture
+def mock_event_stream():
+    stream = MagicMock()
+    # Mock get_events to return an empty list by default
+    stream.get_events.return_value = []
+    return stream
+
+
+@pytest.fixture
+def mock_agent():
+    agent = MagicMock()
+    agent.llm = MagicMock()
+    agent.llm.config = MagicMock()
+    return agent
+
+
+class TestTruncation:
+    def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
+        controller = AgentController(
+            agent=mock_agent,
+            event_stream=mock_event_stream,
+            max_iterations=10,
+            sid='test_truncation',
+            confirmation_mode=False,
+            headless_mode=True,
+        )
+
+        # Create a sequence of events with IDs
+        first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
+        first_msg._source = EventSource.USER
+        first_msg._id = 1
+
+        cmd1 = CmdRunAction(command='ls')
+        cmd1._id = 2
+        obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
+        obs1._id = 3
+        obs1._cause = 2
+
+        cmd2 = CmdRunAction(command='pwd')
+        cmd2._id = 4
+        obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
+        obs2._id = 5
+        obs2._cause = 4
+
+        events = [first_msg, cmd1, obs1, cmd2, obs2]
+
+        # Apply truncation
+        truncated = controller._apply_conversation_window(events)
+
+        # Should keep first user message and roughly half of other events
+        assert (
+            len(truncated) >= 3
+        )  # First message + at least one action-observation pair
+        assert truncated[0] == first_msg  # First message always preserved
+        assert controller.state.start_id == first_msg._id
+        assert controller.state.truncation_id is not None
+
+        # Verify pairs aren't split
+        for i, event in enumerate(truncated[1:]):
+            if isinstance(event, CmdOutputObservation):
+                assert any(e._id == event._cause for e in truncated[: i + 1])
+
+    def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
+        controller = AgentController(
+            agent=mock_agent,
+            event_stream=mock_event_stream,
+            max_iterations=10,
+            sid='test_truncation',
+            confirmation_mode=False,
+            headless_mode=True,
+        )
+
+        # Setup initial history with IDs
+        first_msg = MessageAction(content='Start task', wait_for_response=False)
+        first_msg._source = EventSource.USER
+        first_msg._id = 1
+
+        # Add agent question
+        agent_msg = MessageAction(
+            content='What task would you like me to perform?', wait_for_response=True
+        )
+        agent_msg._source = EventSource.AGENT
+        agent_msg._id = 2
+
+        # Add user response
+        user_response = MessageAction(
+            content='Please list all files and show me current directory',
+            wait_for_response=False,
+        )
+        user_response._source = EventSource.USER
+        user_response._id = 3
+
+        cmd1 = CmdRunAction(command='ls')
+        cmd1._id = 4
+        obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
+        obs1._id = 5
+        obs1._cause = 4
+
+        # Update mock event stream to include new messages
+        mock_event_stream.get_events.return_value = [
+            first_msg,
+            agent_msg,
+            user_response,
+            cmd1,
+            obs1,
+        ]
+        controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
+        original_history_len = len(controller.state.history)
+
+        # Simulate ContextWindowExceededError and truncation
+        controller.state.history = controller._apply_conversation_window(
+            controller.state.history
+        )
+
+        # Verify truncation occurred
+        assert len(controller.state.history) < original_history_len
+        assert controller.state.start_id == first_msg._id
+        assert controller.state.truncation_id is not None
+        assert controller.state.truncation_id > controller.state.start_id
+
+    def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
+        controller = AgentController(
+            agent=mock_agent,
+            event_stream=mock_event_stream,
+            max_iterations=10,
+            sid='test_truncation',
+            confirmation_mode=False,
+            headless_mode=True,
+        )
+
+        # Create events with IDs
+        first_msg = MessageAction(content='Start task', wait_for_response=False)
+        first_msg._source = EventSource.USER
+        first_msg._id = 1
+
+        events = [first_msg]
+        for i in range(5):
+            cmd = CmdRunAction(command=f'cmd{i}')
+            cmd._id = i + 2
+            obs = CmdOutputObservation(
+                command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
+            )
+            obs._cause = cmd._id
+            events.extend([cmd, obs])
+
+        # Set up initial history
+        controller.state.history = events.copy()
+
+        # Force truncation
+        controller.state.history = controller._apply_conversation_window(
+            controller.state.history
+        )
+
+        # Save state
+        saved_start_id = controller.state.start_id
+        saved_truncation_id = controller.state.truncation_id
+        saved_history_len = len(controller.state.history)
+
+        # Set up mock event stream for new controller
+        mock_event_stream.get_events.return_value = controller.state.history
+
+        # Create new controller with saved state
+        new_controller = AgentController(
+            agent=mock_agent,
+            event_stream=mock_event_stream,
+            max_iterations=10,
+            sid='test_truncation',
+            confirmation_mode=False,
+            headless_mode=True,
+        )
+        new_controller.state.start_id = saved_start_id
+        new_controller.state.truncation_id = saved_truncation_id
+        new_controller.state.history = mock_event_stream.get_events()
+
+        # Verify restoration
+        assert len(new_controller.state.history) == saved_history_len
+        assert new_controller.state.history[0] == first_msg
+        assert new_controller.state.start_id == saved_start_id