|
|
@@ -6,7 +6,7 @@ import pytest
|
|
|
|
|
|
from openhands.controller.agent import Agent
|
|
|
from openhands.controller.agent_controller import AgentController
|
|
|
-from openhands.controller.state.state import TrafficControlState
|
|
|
+from openhands.controller.state.state import State, TrafficControlState
|
|
|
from openhands.core.config import AppConfig
|
|
|
from openhands.core.main import run_controller
|
|
|
from openhands.core.schema import AgentState
|
|
|
@@ -41,7 +41,9 @@ def mock_agent():
|
|
|
|
|
|
@pytest.fixture
|
|
|
def mock_event_stream():
|
|
|
- return MagicMock(spec=EventStream)
|
|
|
+ mock = MagicMock(spec=EventStream)
|
|
|
+ mock.get_latest_event_id.return_value = 0
|
|
|
+ return mock
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
@@ -278,7 +280,9 @@ async def test_delegate_step_different_states(
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
-async def test_step_max_iterations(mock_agent, mock_event_stream):
|
|
|
+async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
|
|
+ # Test with headless_mode=False - should extend max_iterations
|
|
|
+ initial_state = State(max_iterations=10)
|
|
|
controller = AgentController(
|
|
|
agent=mock_agent,
|
|
|
event_stream=mock_event_stream,
|
|
|
@@ -286,18 +290,34 @@ async def test_step_max_iterations(mock_agent, mock_event_stream):
|
|
|
sid='test',
|
|
|
confirmation_mode=False,
|
|
|
headless_mode=False,
|
|
|
+ initial_state=initial_state,
|
|
|
)
|
|
|
controller.state.agent_state = AgentState.RUNNING
|
|
|
controller.state.iteration = 10
|
|
|
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+
|
|
|
+ # Trigger throttling by calling _step() when we hit max_iterations
|
|
|
await controller._step()
|
|
|
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
|
|
assert controller.state.agent_state == AgentState.ERROR
|
|
|
- await controller.close()
|
|
|
|
|
|
+ # Simulate a new user message
|
|
|
+ message_action = MessageAction(content='Test message')
|
|
|
+ message_action._source = EventSource.USER
|
|
|
+ await controller.on_event(message_action)
|
|
|
+
|
|
|
+ # Max iterations should be extended to current iteration + initial max_iterations
|
|
|
+ assert (
|
|
|
+ controller.state.max_iterations == 20
|
|
|
+ ) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
|
|
|
+ assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+ assert controller.state.agent_state == AgentState.RUNNING
|
|
|
|
|
|
-@pytest.mark.asyncio
|
|
|
-async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
|
|
|
+ # Close the controller to clean up
|
|
|
+ await controller.close()
|
|
|
+
|
|
|
+ # Test with headless_mode=True - should NOT extend max_iterations
|
|
|
+ initial_state = State(max_iterations=10)
|
|
|
controller = AgentController(
|
|
|
agent=mock_agent,
|
|
|
event_stream=mock_event_stream,
|
|
|
@@ -305,13 +325,24 @@ async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
|
|
|
sid='test',
|
|
|
confirmation_mode=False,
|
|
|
headless_mode=True,
|
|
|
+ initial_state=initial_state,
|
|
|
)
|
|
|
controller.state.agent_state = AgentState.RUNNING
|
|
|
controller.state.iteration = 10
|
|
|
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
|
|
+
|
|
|
+ # Simulate a new user message
|
|
|
+ message_action = MessageAction(content='Test message')
|
|
|
+ message_action._source = EventSource.USER
|
|
|
+ await controller.on_event(message_action)
|
|
|
+
|
|
|
+ # Max iterations should NOT be extended in headless mode
|
|
|
+ assert controller.state.max_iterations == 10 # Original value unchanged
|
|
|
+
|
|
|
+ # Trigger throttling by calling _step() when we hit max_iterations
|
|
|
await controller._step()
|
|
|
+
|
|
|
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
|
|
- # In headless mode, throttling results in an error
|
|
|
assert controller.state.agent_state == AgentState.ERROR
|
|
|
await controller.close()
|
|
|
|