| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- from unittest.mock import MagicMock
- import pytest
- from openhands.controller.agent_controller import AgentController
- from openhands.events import EventSource
- from openhands.events.action import CmdRunAction, MessageAction
- from openhands.events.observation import CmdOutputObservation
- @pytest.fixture
- def mock_event_stream():
- stream = MagicMock()
- # Mock get_events to return an empty list by default
- stream.get_events.return_value = []
- return stream
- @pytest.fixture
- def mock_agent():
- agent = MagicMock()
- agent.llm = MagicMock()
- agent.llm.config = MagicMock()
- return agent
- class TestTruncation:
- def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
- controller = AgentController(
- agent=mock_agent,
- event_stream=mock_event_stream,
- max_iterations=10,
- sid='test_truncation',
- confirmation_mode=False,
- headless_mode=True,
- )
- # Create a sequence of events with IDs
- first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
- first_msg._source = EventSource.USER
- first_msg._id = 1
- cmd1 = CmdRunAction(command='ls')
- cmd1._id = 2
- obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
- obs1._id = 3
- obs1._cause = 2
- cmd2 = CmdRunAction(command='pwd')
- cmd2._id = 4
- obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
- obs2._id = 5
- obs2._cause = 4
- events = [first_msg, cmd1, obs1, cmd2, obs2]
- # Apply truncation
- truncated = controller._apply_conversation_window(events)
- # Should keep first user message and roughly half of other events
- assert (
- len(truncated) >= 3
- ) # First message + at least one action-observation pair
- assert truncated[0] == first_msg # First message always preserved
- assert controller.state.start_id == first_msg._id
- assert controller.state.truncation_id is not None
- # Verify pairs aren't split
- for i, event in enumerate(truncated[1:]):
- if isinstance(event, CmdOutputObservation):
- assert any(e._id == event._cause for e in truncated[: i + 1])
- def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
- controller = AgentController(
- agent=mock_agent,
- event_stream=mock_event_stream,
- max_iterations=10,
- sid='test_truncation',
- confirmation_mode=False,
- headless_mode=True,
- )
- # Setup initial history with IDs
- first_msg = MessageAction(content='Start task', wait_for_response=False)
- first_msg._source = EventSource.USER
- first_msg._id = 1
- # Add agent question
- agent_msg = MessageAction(
- content='What task would you like me to perform?', wait_for_response=True
- )
- agent_msg._source = EventSource.AGENT
- agent_msg._id = 2
- # Add user response
- user_response = MessageAction(
- content='Please list all files and show me current directory',
- wait_for_response=False,
- )
- user_response._source = EventSource.USER
- user_response._id = 3
- cmd1 = CmdRunAction(command='ls')
- cmd1._id = 4
- obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
- obs1._id = 5
- obs1._cause = 4
- # Update mock event stream to include new messages
- mock_event_stream.get_events.return_value = [
- first_msg,
- agent_msg,
- user_response,
- cmd1,
- obs1,
- ]
- controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
- original_history_len = len(controller.state.history)
- # Simulate ContextWindowExceededError and truncation
- controller.state.history = controller._apply_conversation_window(
- controller.state.history
- )
- # Verify truncation occurred
- assert len(controller.state.history) < original_history_len
- assert controller.state.start_id == first_msg._id
- assert controller.state.truncation_id is not None
- assert controller.state.truncation_id > controller.state.start_id
- def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
- controller = AgentController(
- agent=mock_agent,
- event_stream=mock_event_stream,
- max_iterations=10,
- sid='test_truncation',
- confirmation_mode=False,
- headless_mode=True,
- )
- # Create events with IDs
- first_msg = MessageAction(content='Start task', wait_for_response=False)
- first_msg._source = EventSource.USER
- first_msg._id = 1
- events = [first_msg]
- for i in range(5):
- cmd = CmdRunAction(command=f'cmd{i}')
- cmd._id = i + 2
- obs = CmdOutputObservation(
- command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
- )
- obs._cause = cmd._id
- events.extend([cmd, obs])
- # Set up initial history
- controller.state.history = events.copy()
- # Force truncation
- controller.state.history = controller._apply_conversation_window(
- controller.state.history
- )
- # Save state
- saved_start_id = controller.state.start_id
- saved_truncation_id = controller.state.truncation_id
- saved_history_len = len(controller.state.history)
- # Set up mock event stream for new controller
- mock_event_stream.get_events.return_value = controller.state.history
- # Create new controller with saved state
- new_controller = AgentController(
- agent=mock_agent,
- event_stream=mock_event_stream,
- max_iterations=10,
- sid='test_truncation',
- confirmation_mode=False,
- headless_mode=True,
- )
- new_controller.state.start_id = saved_start_id
- new_controller.state.truncation_id = saved_truncation_id
- new_controller.state.history = mock_event_stream.get_events()
- # Verify restoration
- assert len(new_controller.state.history) == saved_history_len
- assert new_controller.state.history[0] == first_msg
- assert new_controller.state.start_id == saved_start_id
|