# test_truncation.py (6.4 KB)
  1. from unittest.mock import MagicMock
  2. import pytest
  3. from openhands.controller.agent_controller import AgentController
  4. from openhands.events import EventSource
  5. from openhands.events.action import CmdRunAction, MessageAction
  6. from openhands.events.observation import CmdOutputObservation
  7. @pytest.fixture
  8. def mock_event_stream():
  9. stream = MagicMock()
  10. # Mock get_events to return an empty list by default
  11. stream.get_events.return_value = []
  12. return stream
  13. @pytest.fixture
  14. def mock_agent():
  15. agent = MagicMock()
  16. agent.llm = MagicMock()
  17. agent.llm.config = MagicMock()
  18. return agent
  19. class TestTruncation:
  20. def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
  21. controller = AgentController(
  22. agent=mock_agent,
  23. event_stream=mock_event_stream,
  24. max_iterations=10,
  25. sid='test_truncation',
  26. confirmation_mode=False,
  27. headless_mode=True,
  28. )
  29. # Create a sequence of events with IDs
  30. first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
  31. first_msg._source = EventSource.USER
  32. first_msg._id = 1
  33. cmd1 = CmdRunAction(command='ls')
  34. cmd1._id = 2
  35. obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
  36. obs1._id = 3
  37. obs1._cause = 2
  38. cmd2 = CmdRunAction(command='pwd')
  39. cmd2._id = 4
  40. obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
  41. obs2._id = 5
  42. obs2._cause = 4
  43. events = [first_msg, cmd1, obs1, cmd2, obs2]
  44. # Apply truncation
  45. truncated = controller._apply_conversation_window(events)
  46. # Should keep first user message and roughly half of other events
  47. assert (
  48. len(truncated) >= 3
  49. ) # First message + at least one action-observation pair
  50. assert truncated[0] == first_msg # First message always preserved
  51. assert controller.state.start_id == first_msg._id
  52. assert controller.state.truncation_id is not None
  53. # Verify pairs aren't split
  54. for i, event in enumerate(truncated[1:]):
  55. if isinstance(event, CmdOutputObservation):
  56. assert any(e._id == event._cause for e in truncated[: i + 1])
  57. def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
  58. controller = AgentController(
  59. agent=mock_agent,
  60. event_stream=mock_event_stream,
  61. max_iterations=10,
  62. sid='test_truncation',
  63. confirmation_mode=False,
  64. headless_mode=True,
  65. )
  66. # Setup initial history with IDs
  67. first_msg = MessageAction(content='Start task', wait_for_response=False)
  68. first_msg._source = EventSource.USER
  69. first_msg._id = 1
  70. # Add agent question
  71. agent_msg = MessageAction(
  72. content='What task would you like me to perform?', wait_for_response=True
  73. )
  74. agent_msg._source = EventSource.AGENT
  75. agent_msg._id = 2
  76. # Add user response
  77. user_response = MessageAction(
  78. content='Please list all files and show me current directory',
  79. wait_for_response=False,
  80. )
  81. user_response._source = EventSource.USER
  82. user_response._id = 3
  83. cmd1 = CmdRunAction(command='ls')
  84. cmd1._id = 4
  85. obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
  86. obs1._id = 5
  87. obs1._cause = 4
  88. # Update mock event stream to include new messages
  89. mock_event_stream.get_events.return_value = [
  90. first_msg,
  91. agent_msg,
  92. user_response,
  93. cmd1,
  94. obs1,
  95. ]
  96. controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
  97. original_history_len = len(controller.state.history)
  98. # Simulate ContextWindowExceededError and truncation
  99. controller.state.history = controller._apply_conversation_window(
  100. controller.state.history
  101. )
  102. # Verify truncation occurred
  103. assert len(controller.state.history) < original_history_len
  104. assert controller.state.start_id == first_msg._id
  105. assert controller.state.truncation_id is not None
  106. assert controller.state.truncation_id > controller.state.start_id
  107. def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
  108. controller = AgentController(
  109. agent=mock_agent,
  110. event_stream=mock_event_stream,
  111. max_iterations=10,
  112. sid='test_truncation',
  113. confirmation_mode=False,
  114. headless_mode=True,
  115. )
  116. # Create events with IDs
  117. first_msg = MessageAction(content='Start task', wait_for_response=False)
  118. first_msg._source = EventSource.USER
  119. first_msg._id = 1
  120. events = [first_msg]
  121. for i in range(5):
  122. cmd = CmdRunAction(command=f'cmd{i}')
  123. cmd._id = i + 2
  124. obs = CmdOutputObservation(
  125. command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
  126. )
  127. obs._cause = cmd._id
  128. events.extend([cmd, obs])
  129. # Set up initial history
  130. controller.state.history = events.copy()
  131. # Force truncation
  132. controller.state.history = controller._apply_conversation_window(
  133. controller.state.history
  134. )
  135. # Save state
  136. saved_start_id = controller.state.start_id
  137. saved_truncation_id = controller.state.truncation_id
  138. saved_history_len = len(controller.state.history)
  139. # Set up mock event stream for new controller
  140. mock_event_stream.get_events.return_value = controller.state.history
  141. # Create new controller with saved state
  142. new_controller = AgentController(
  143. agent=mock_agent,
  144. event_stream=mock_event_stream,
  145. max_iterations=10,
  146. sid='test_truncation',
  147. confirmation_mode=False,
  148. headless_mode=True,
  149. )
  150. new_controller.state.start_id = saved_start_id
  151. new_controller.state.truncation_id = saved_truncation_id
  152. new_controller.state.history = mock_event_stream.get_events()
  153. # Verify restoration
  154. assert len(new_controller.state.history) == saved_history_len
  155. assert new_controller.state.history[0] == first_msg
  156. assert new_controller.state.start_id == saved_start_id