|
@@ -5,6 +5,7 @@ import traceback
|
|
|
from typing import Callable, ClassVar, Type
|
|
from typing import Callable, ClassVar, Type
|
|
|
|
|
|
|
|
import litellm
|
|
import litellm
|
|
|
|
|
+from litellm.exceptions import ContextWindowExceededError
|
|
|
|
|
|
|
|
from openhands.controller.agent import Agent
|
|
from openhands.controller.agent import Agent
|
|
|
from openhands.controller.state.state import State, TrafficControlState
|
|
from openhands.controller.state.state import State, TrafficControlState
|
|
@@ -485,6 +486,15 @@ class AgentController:
|
|
|
EventSource.AGENT,
|
|
EventSource.AGENT,
|
|
|
)
|
|
)
|
|
|
return
|
|
return
|
|
|
|
|
+ except ContextWindowExceededError:
|
|
|
|
|
+ # When context window is exceeded, keep roughly half of agent interactions
|
|
|
|
|
+ self.state.history = self._apply_conversation_window(self.state.history)
|
|
|
|
|
+
|
|
|
|
|
+ # Save the ID of the first event in our truncated history for future reloading
|
|
|
|
|
+ if self.state.history:
|
|
|
|
|
+ self.state.start_id = self.state.history[0].id
|
|
|
|
|
+ # Don't add error event - let the agent retry with reduced context
|
|
|
|
|
+ return
|
|
|
|
|
|
|
|
if action.runnable:
|
|
if action.runnable:
|
|
|
if self.state.confirmation_mode and (
|
|
if self.state.confirmation_mode and (
|
|
@@ -659,6 +669,12 @@ class AgentController:
|
|
|
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
|
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
|
|
- Excludes all events between the action and observation
|
|
- Excludes all events between the action and observation
|
|
|
- Includes the delegate action and observation themselves
|
|
- Includes the delegate action and observation themselves
|
|
|
|
|
+
|
|
|
|
|
+ The history is loaded in two parts if truncation_id is set:
|
|
|
|
|
+ 1. First user message from start_id onwards
|
|
|
|
|
+ 2. Rest of history from truncation_id to the end
|
|
|
|
|
+
|
|
|
|
|
+ Otherwise loads normally from start_id.
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
# define range of events to fetch
|
|
# define range of events to fetch
|
|
@@ -680,8 +696,33 @@ class AgentController:
|
|
|
self.state.history = []
|
|
self.state.history = []
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
- # Get all events, filtering out backend events and hidden events
|
|
|
|
|
- events = list(
|
|
|
|
|
|
|
+ events: list[Event] = []
|
|
|
|
|
+
|
|
|
|
|
+ # If we have a truncation point, get first user message and then rest of history
|
|
|
|
|
+ if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
|
|
|
|
|
+ # Find first user message from stream
|
|
|
|
|
+ first_user_msg = next(
|
|
|
|
|
+ (
|
|
|
|
|
+ e
|
|
|
|
|
+ for e in self.event_stream.get_events(
|
|
|
|
|
+ start_id=start_id,
|
|
|
|
|
+ end_id=end_id,
|
|
|
|
|
+ reverse=False,
|
|
|
|
|
+ filter_out_type=self.filter_out,
|
|
|
|
|
+ filter_hidden=True,
|
|
|
|
|
+ )
|
|
|
|
|
+ if isinstance(e, MessageAction) and e.source == EventSource.USER
|
|
|
|
|
+ ),
|
|
|
|
|
+ None,
|
|
|
|
|
+ )
|
|
|
|
|
+ if first_user_msg:
|
|
|
|
|
+ events.append(first_user_msg)
|
|
|
|
|
+
|
|
|
|
|
+ # the rest of the events are from the truncation point
|
|
|
|
|
+ start_id = self.state.truncation_id
|
|
|
|
|
+
|
|
|
|
|
+ # Get rest of history
|
|
|
|
|
+ events_to_add = list(
|
|
|
self.event_stream.get_events(
|
|
self.event_stream.get_events(
|
|
|
start_id=start_id,
|
|
start_id=start_id,
|
|
|
end_id=end_id,
|
|
end_id=end_id,
|
|
@@ -690,6 +731,7 @@ class AgentController:
|
|
|
filter_hidden=True,
|
|
filter_hidden=True,
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
|
|
+ events.extend(events_to_add)
|
|
|
|
|
|
|
|
# Find all delegate action/observation pairs
|
|
# Find all delegate action/observation pairs
|
|
|
delegate_ranges: list[tuple[int, int]] = []
|
|
delegate_ranges: list[tuple[int, int]] = []
|
|
@@ -744,6 +786,92 @@ class AgentController:
|
|
|
# make sure history is in sync
|
|
# make sure history is in sync
|
|
|
self.state.start_id = start_id
|
|
self.state.start_id = start_id
|
|
|
|
|
|
|
|
|
|
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
    """Cut the history roughly in half when the context window is exceeded.

    The second half of ``events`` is kept. The beginning of the kept half is
    then repaired so an Observation is not left without its Action, and the
    first user message is always part of the result.

    The algorithm:
    1. Cut history in half, keeping the second half.
    2. Scan forward from the start of the kept half:
        - MessageActions and user-sourced Actions are kept and skipped over.
        - At the first Observation that has a ``cause``, make sure its
          Action is present: if it is already among the kept events that
          were skipped over, nothing changes; if it is in the discarded
          half, it is prepended; otherwise a warning is logged and that
          Observation is removed. The scan stops here.
        - Any other event stops the scan.
    3. Always include the first user message.

    Side effects:
        - ``self.state.truncation_id`` is set to the id of the first kept
          event (before the first user message is re-added), if any event
          is kept.
        - ``self.state.start_id`` is set to the id of the first user
          message, if there is one.

    Args:
        events: List of events to filter.
            NOTE(review): presumably ordered oldest-to-newest; confirm
            against the caller.

    Returns:
        Filtered list of events keeping the newest half while preserving
        the Action-Observation pair at the cut. An empty input is returned
        unchanged.
    """
    if not events:
        return events

    # Find first user message - we'll need to ensure it's included
    first_user_msg = next(
        (
            e
            for e in events
            if isinstance(e, MessageAction) and e.source == EventSource.USER
        ),
        None,
    )

    # Cut in half; max(1, ...) guarantees at least one event is discarded.
    mid_point = max(1, len(events) // 2)
    dropped_events = events[:mid_point]
    kept_events = events[mid_point:]

    # Repair the start of the truncated history. Rebinding kept_events inside
    # the loop is safe because every branch that rebinds it breaks right away.
    for index, event in enumerate(kept_events):
        if isinstance(event, Observation) and event.cause:
            cause_id = event.cause

            # The Action may already be kept: a user-sourced Action that was
            # skipped over earlier in this scan can be the cause.
            pair_is_intact = any(
                isinstance(e, Action) and e.id == cause_id
                for e in kept_events[:index]
            )
            if not pair_is_intact:
                # Search the discarded half from its end backwards for the
                # Action that caused this Observation.
                matching_action = next(
                    (
                        e
                        for e in reversed(dropped_events)
                        if isinstance(e, Action) and e.id == cause_id
                    ),
                    None,
                )
                if matching_action:
                    kept_events = [matching_action] + kept_events
                else:
                    self.log(
                        'warning',
                        f'Found Observation without matching Action at id={event.id}',
                    )
                    # Drop this observation itself (it sits at `index`, which
                    # is not necessarily the head of the list).
                    kept_events = kept_events[:index] + kept_events[index + 1 :]
            break

        is_message_or_user_action = isinstance(event, MessageAction) or (
            isinstance(event, Action) and event.source == EventSource.USER
        )
        if not is_message_or_user_action:
            # Any other event (e.g. an agent-sourced action): we're good.
            break
        # Message action or user action: keep it and look at the next event.

    # Save where to continue from in next reload
    if kept_events:
        self.state.truncation_id = kept_events[0].id

    # Ensure first user message is included
    if first_user_msg and first_user_msg not in kept_events:
        kept_events = [first_user_msg] + kept_events

    # start_id points to first user message
    if first_user_msg:
        self.state.start_id = first_user_msg.id

    return kept_events
|
|
|
|
|
+
|
|
|
def _is_stuck(self):
|
|
def _is_stuck(self):
|
|
|
"""Checks if the agent or its delegate is stuck in a loop.
|
|
"""Checks if the agent or its delegate is stuck in a loop.
|
|
|
|
|
|