test_codeact_agent.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from unittest.mock import Mock
  2. import pytest
  3. from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from openhands.core.config import AgentConfig, LLMConfig
  5. from openhands.core.message import TextContent
  6. from openhands.events.observation.commands import (
  7. CmdOutputObservation,
  8. IPythonRunCellObservation,
  9. )
  10. from openhands.events.observation.delegate import AgentDelegateObservation
  11. from openhands.events.observation.error import ErrorObservation
  12. from openhands.llm.llm import LLM
  13. @pytest.fixture
  14. def agent() -> CodeActAgent:
  15. agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
  16. agent.llm = Mock()
  17. agent.llm.config = Mock()
  18. agent.llm.config.max_message_chars = 100
  19. return agent
  20. def test_cmd_output_observation_message(agent: CodeActAgent):
  21. agent.config.function_calling = False
  22. obs = CmdOutputObservation(
  23. command='echo hello', content='Command output', command_id=1, exit_code=0
  24. )
  25. results = agent.get_observation_message(obs, tool_call_id_to_message={})
  26. assert len(results) == 1
  27. result = results[0]
  28. assert result is not None
  29. assert result.role == 'user'
  30. assert len(result.content) == 1
  31. assert isinstance(result.content[0], TextContent)
  32. assert 'OBSERVATION:' in result.content[0].text
  33. assert 'Command output' in result.content[0].text
  34. assert 'Command finished with exit code 0' in result.content[0].text
  35. def test_ipython_run_cell_observation_message(agent: CodeActAgent):
  36. agent.config.function_calling = False
  37. obs = IPythonRunCellObservation(
  38. code='plt.plot()',
  39. content='IPython output\n![image](data:image/png;base64,ABC123)',
  40. )
  41. results = agent.get_observation_message(obs, tool_call_id_to_message={})
  42. assert len(results) == 1
  43. result = results[0]
  44. assert result is not None
  45. assert result.role == 'user'
  46. assert len(result.content) == 1
  47. assert isinstance(result.content[0], TextContent)
  48. assert 'OBSERVATION:' in result.content[0].text
  49. assert 'IPython output' in result.content[0].text
  50. assert (
  51. '![image](data:image/png;base64, ...) already displayed to user'
  52. in result.content[0].text
  53. )
  54. assert 'ABC123' not in result.content[0].text
  55. def test_agent_delegate_observation_message(agent: CodeActAgent):
  56. agent.config.function_calling = False
  57. obs = AgentDelegateObservation(
  58. content='Content', outputs={'content': 'Delegated agent output'}
  59. )
  60. results = agent.get_observation_message(obs, tool_call_id_to_message={})
  61. assert len(results) == 1
  62. result = results[0]
  63. assert result is not None
  64. assert result.role == 'user'
  65. assert len(result.content) == 1
  66. assert isinstance(result.content[0], TextContent)
  67. assert 'OBSERVATION:' in result.content[0].text
  68. assert 'Delegated agent output' in result.content[0].text
  69. def test_error_observation_message(agent: CodeActAgent):
  70. agent.config.function_calling = False
  71. obs = ErrorObservation('Error message')
  72. results = agent.get_observation_message(obs, tool_call_id_to_message={})
  73. assert len(results) == 1
  74. result = results[0]
  75. assert result is not None
  76. assert result.role == 'user'
  77. assert len(result.content) == 1
  78. assert isinstance(result.content[0], TextContent)
  79. assert 'OBSERVATION:' in result.content[0].text
  80. assert 'Error message' in result.content[0].text
  81. assert 'Error occurred in processing last action' in result.content[0].text
  82. def test_unknown_observation_message(agent: CodeActAgent):
  83. obs = Mock()
  84. with pytest.raises(ValueError, match='Unknown observation type'):
  85. agent.get_observation_message(obs, tool_call_id_to_message={})