test_traffic_control.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from unittest.mock import MagicMock
  2. import pytest
  3. from openhands.controller.agent_controller import AgentController
  4. from openhands.core.config import AgentConfig, LLMConfig
  5. from openhands.events import EventStream
  6. from openhands.llm.llm import LLM
  7. from openhands.storage import InMemoryFileStore
  8. @pytest.fixture
  9. def agent_controller():
  10. llm = LLM(config=LLMConfig())
  11. agent = MagicMock()
  12. agent.name = 'test_agent'
  13. agent.llm = llm
  14. agent.config = AgentConfig()
  15. event_stream = EventStream(sid='test', file_store=InMemoryFileStore())
  16. controller = AgentController(
  17. agent=agent,
  18. event_stream=event_stream,
  19. max_iterations=100,
  20. max_budget_per_task=10.0,
  21. sid='test',
  22. headless_mode=False,
  23. )
  24. return controller
  25. @pytest.mark.asyncio
  26. async def test_traffic_control_iteration_message(agent_controller):
  27. """Test that iteration messages are formatted as integers."""
  28. # Mock _react_to_exception to capture the error
  29. error = None
  30. async def mock_react_to_exception(e):
  31. nonlocal error
  32. error = e
  33. agent_controller._react_to_exception = mock_react_to_exception
  34. await agent_controller._handle_traffic_control('iteration', 200.0, 100.0)
  35. assert error is not None
  36. assert 'Current iteration: 200, max iteration: 100' in str(error)
  37. @pytest.mark.asyncio
  38. async def test_traffic_control_budget_message(agent_controller):
  39. """Test that budget messages keep decimal points."""
  40. # Mock _react_to_exception to capture the error
  41. error = None
  42. async def mock_react_to_exception(e):
  43. nonlocal error
  44. error = e
  45. agent_controller._react_to_exception = mock_react_to_exception
  46. await agent_controller._handle_traffic_control('budget', 15.75, 10.0)
  47. assert error is not None
  48. assert 'Current budget: 15.75, max budget: 10.00' in str(error)
  49. @pytest.mark.asyncio
  50. async def test_traffic_control_headless_mode(agent_controller):
  51. """Test that headless mode messages are formatted correctly."""
  52. # Mock _react_to_exception to capture the error
  53. error = None
  54. async def mock_react_to_exception(e):
  55. nonlocal error
  56. error = e
  57. agent_controller._react_to_exception = mock_react_to_exception
  58. agent_controller.headless_mode = True
  59. await agent_controller._handle_traffic_control('iteration', 200.0, 100.0)
  60. assert error is not None
  61. assert 'in headless mode' in str(error)
  62. assert 'Current iteration: 200, max iteration: 100' in str(error)