test_iteration_limit.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import asyncio
  2. import pytest
  3. from openhands.controller.agent_controller import AgentController
  4. from openhands.core.schema import AgentState
  5. from openhands.events import EventStream
  6. from openhands.events.action import MessageAction
  7. from openhands.events.event import EventSource
  8. class DummyAgent:
  9. def __init__(self):
  10. self.name = 'dummy'
  11. self.llm = type(
  12. 'DummyLLM',
  13. (),
  14. {'metrics': type('DummyMetrics', (), {'merge': lambda x: None})()},
  15. )()
  16. def reset(self):
  17. pass
  18. @pytest.mark.asyncio
  19. async def test_iteration_limit_extends_on_user_message():
  20. # Initialize test components
  21. from openhands.storage.memory import InMemoryFileStore
  22. file_store = InMemoryFileStore()
  23. event_stream = EventStream(sid='test', file_store=file_store)
  24. agent = DummyAgent()
  25. initial_max_iterations = 100
  26. controller = AgentController(
  27. agent=agent,
  28. event_stream=event_stream,
  29. max_iterations=initial_max_iterations,
  30. sid='test',
  31. headless_mode=False,
  32. )
  33. # Set initial state
  34. await controller.set_agent_state_to(AgentState.RUNNING)
  35. controller.state.iteration = 90 # Close to the limit
  36. assert controller.state.max_iterations == initial_max_iterations
  37. # Simulate user message
  38. user_message = MessageAction('test message', EventSource.USER)
  39. event_stream.add_event(user_message, EventSource.USER)
  40. await asyncio.sleep(0.1) # Give time for event to be processed
  41. # Verify max_iterations was extended
  42. assert controller.state.max_iterations == 90 + initial_max_iterations
  43. # Simulate more iterations and another user message
  44. controller.state.iteration = 180 # Close to new limit
  45. user_message2 = MessageAction('another message', EventSource.USER)
  46. event_stream.add_event(user_message2, EventSource.USER)
  47. await asyncio.sleep(0.1) # Give time for event to be processed
  48. # Verify max_iterations was extended again
  49. assert controller.state.max_iterations == 180 + initial_max_iterations