# test_prompt_caching.py — tests for CodeActAgent prompt caching behavior.
  1. from unittest.mock import Mock
  2. import pytest
  3. from litellm import ModelResponse
  4. from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
  5. from openhands.core.config import AgentConfig, LLMConfig
  6. from openhands.events.action import MessageAction
  7. from openhands.llm.llm import LLM
  8. @pytest.fixture
  9. def mock_llm():
  10. llm = LLM(
  11. LLMConfig(
  12. model='claude-3-5-sonnet-20241022',
  13. api_key='fake',
  14. caching_prompt=True,
  15. )
  16. )
  17. return llm
  18. @pytest.fixture
  19. def codeact_agent(mock_llm):
  20. config = AgentConfig()
  21. return CodeActAgent(mock_llm, config)
  22. def response_mock(content: str, tool_call_id: str):
  23. class MockModelResponse:
  24. def __init__(self, content, tool_call_id):
  25. self.choices = [
  26. {
  27. 'message': {
  28. 'content': content,
  29. 'tool_calls': [
  30. {
  31. 'function': {
  32. 'id': tool_call_id,
  33. 'name': 'execute_bash',
  34. 'arguments': '{}',
  35. }
  36. }
  37. ],
  38. }
  39. }
  40. ]
  41. def model_dump(self):
  42. return {'choices': self.choices}
  43. return ModelResponse(**MockModelResponse(content, tool_call_id).model_dump())
  44. def test_get_messages(codeact_agent: CodeActAgent):
  45. # Add some events to history
  46. history = list()
  47. message_action_1 = MessageAction('Initial user message')
  48. message_action_1._source = 'user'
  49. history.append(message_action_1)
  50. message_action_2 = MessageAction('Sure!')
  51. message_action_2._source = 'assistant'
  52. history.append(message_action_2)
  53. message_action_3 = MessageAction('Hello, agent!')
  54. message_action_3._source = 'user'
  55. history.append(message_action_3)
  56. message_action_4 = MessageAction('Hello, user!')
  57. message_action_4._source = 'assistant'
  58. history.append(message_action_4)
  59. message_action_5 = MessageAction('Laaaaaaaast!')
  60. message_action_5._source = 'user'
  61. history.append(message_action_5)
  62. codeact_agent.reset()
  63. messages = codeact_agent._get_messages(
  64. Mock(history=history, max_iterations=5, iteration=0)
  65. )
  66. assert (
  67. len(messages) == 6
  68. ) # System, initial user + user message, agent message, last user message
  69. assert messages[0].content[0].cache_prompt # system message
  70. assert messages[1].role == 'user'
  71. assert messages[1].content[0].text.endswith('Initial user message')
  72. # we add cache breakpoint to the last 3 user messages
  73. assert messages[1].content[0].cache_prompt
  74. assert messages[3].role == 'user'
  75. assert messages[3].content[0].text == ('Hello, agent!')
  76. assert messages[3].content[0].cache_prompt
  77. assert messages[4].role == 'assistant'
  78. assert messages[4].content[0].text == 'Hello, user!'
  79. assert not messages[4].content[0].cache_prompt
  80. assert messages[5].role == 'user'
  81. assert messages[5].content[0].text.startswith('Laaaaaaaast!')
  82. assert messages[5].content[0].cache_prompt
  83. def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
  84. history = list()
  85. # Add multiple user and agent messages
  86. for i in range(15):
  87. message_action_user = MessageAction(f'User message {i}')
  88. message_action_user._source = 'user'
  89. history.append(message_action_user)
  90. message_action_agent = MessageAction(f'Agent message {i}')
  91. message_action_agent._source = 'assistant'
  92. history.append(message_action_agent)
  93. codeact_agent.reset()
  94. messages = codeact_agent._get_messages(
  95. Mock(history=history, max_iterations=10, iteration=5)
  96. )
  97. # Check that only the last two user messages have cache_prompt=True
  98. cached_user_messages = [
  99. msg
  100. for msg in messages
  101. if msg.role in ('user', 'system') and msg.content[0].cache_prompt
  102. ]
  103. assert (
  104. len(cached_user_messages) == 4
  105. ) # Including the initial system+user + 2 last user message
  106. # Verify that these are indeed the last two user messages (from start)
  107. assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent')
  108. assert cached_user_messages[2].content[0].text.startswith('User message 1')
  109. assert cached_user_messages[3].content[0].text.startswith('User message 1')
  110. def test_prompt_caching_headers(codeact_agent: CodeActAgent):
  111. history = list()
  112. # Setup
  113. msg1 = MessageAction('Hello, agent!')
  114. msg1._source = 'user'
  115. history.append(msg1)
  116. msg2 = MessageAction('Hello, user!')
  117. msg2._source = 'agent'
  118. history.append(msg2)
  119. mock_state = Mock()
  120. mock_state.history = history
  121. mock_state.max_iterations = 5
  122. mock_state.iteration = 0
  123. codeact_agent.reset()
  124. # Create a mock for litellm_completion
  125. def check_headers(**kwargs):
  126. assert 'extra_headers' in kwargs
  127. assert 'anthropic-beta' in kwargs['extra_headers']
  128. assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
  129. return ModelResponse(
  130. choices=[{'message': {'content': 'Hello! How can I assist you today?'}}]
  131. )
  132. codeact_agent.llm._completion_unwrapped = check_headers
  133. result = codeact_agent.step(mock_state)
  134. # Assert
  135. assert isinstance(result, MessageAction)
  136. assert result.content == 'Hello! How can I assist you today?'