# test_prompt_caching.py (9.7 KB)


  1. from unittest.mock import MagicMock, Mock, patch
  2. import pytest
  3. from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from openhands.core.config import AgentConfig, LLMConfig
  5. from openhands.events import EventSource, EventStream
  6. from openhands.events.action import CmdRunAction, MessageAction
  7. from openhands.events.observation import CmdOutputObservation
  8. from openhands.llm.llm import LLM
  9. from openhands.storage import get_file_store
  @pytest.fixture
  def mock_llm():
      """Return a Mock restricted to the LLM interface with prompt caching on.

      The mock carries a real LLMConfig for a Claude 3.5 Sonnet model with
      ``caching_prompt=True``, and ``is_caching_prompt_active()`` on it is
      stubbed to report True.
      """
      fake_llm = Mock(spec=LLM)
      fake_llm.config = LLMConfig(
          model='claude-3-5-sonnet-20241022',
          caching_prompt=True,
      )
      fake_llm.is_caching_prompt_active.return_value = True
      return fake_llm
  @pytest.fixture
  def mock_event_stream(tmp_path):
      """Return a real (non-mock) EventStream for session 'test_session'.

      Events are persisted through a 'local' file store rooted at pytest's
      per-test ``tmp_path``, so nothing leaks between tests.
      """
      store_root = str(tmp_path)
      local_store = get_file_store('local', store_root)
      stream = EventStream('test_session', local_store)
      return stream
  @pytest.fixture(params=[False, True])
  def codeact_agent(mock_llm, request):
      """Build a CodeActAgent around the mocked LLM.

      The fixture is parametrized, so every test that uses it runs twice:
      once with ``function_calling`` disabled and once with it enabled.
      """
      use_function_calling = request.param
      agent_config = AgentConfig()
      agent_config.function_calling = use_function_calling
      agent = CodeActAgent(mock_llm, agent_config)
      return agent
  def response_mock(content: str):
      """Build a minimal stand-in for a chat-completion response object.

      The returned object exposes ``choices``: a one-element list whose
      message dict holds ``content`` plus a single ``execute_bash`` tool call
      with empty JSON arguments. It also has ``model_dump()``, which returns
      ``{'choices': <that same list>}``.
      """

      class MockModelResponse:
          def __init__(self, content):
              bash_call = {
                  'function': {'name': 'execute_bash', 'arguments': '{}'},
              }
              message = {'content': content, 'tool_calls': [bash_call]}
              self.choices = [{'message': message}]

          def model_dump(self):
              return {'choices': self.choices}

      return MockModelResponse(content)
  def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
      """Check message assembly, cache breakpoints and the turn reminder.

      Five alternating user/agent messages are written to the stream. The
      assembled prompt is expected to be: system, initial user message,
      assistant, user, assistant, user. The system message and all three user
      messages carry a cache breakpoint; assistant messages do not. Without
      function calling, the initial user message is preceded by a content item
      ending in "LET'S START!", and the last user message gets an
      "ENVIRONMENT REMINDER" content item.
      """
      # Add some events to the stream
      mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
      mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
      mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
      mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
      mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
      codeact_agent.reset()
      messages = codeact_agent._get_messages(
          Mock(history=mock_event_stream, max_iterations=5, iteration=0)
      )

      # system, initial user, assistant, user, assistant, last user
      assert len(messages) == 6
      assert messages[0].content[0].cache_prompt  # system message

      assert messages[1].role == 'user'
      if not codeact_agent.config.function_calling:
          assert messages[1].content[0].text.endswith("LET'S START!")
          assert messages[1].content[1].text.endswith('Initial user message')
      else:
          assert messages[1].content[0].text.endswith('Initial user message')
      # we add cache breakpoint to the last 3 user messages
      assert messages[1].content[-1].cache_prompt

      # The first assistant reply was previously unchecked; it must not be
      # merged away and must not carry a cache breakpoint.
      assert messages[2].role == 'assistant'
      assert messages[2].content[0].text == 'Sure!'
      assert not messages[2].content[0].cache_prompt

      assert messages[3].role == 'user'
      assert messages[3].content[0].text == 'Hello, agent!'
      assert messages[3].content[0].cache_prompt

      assert messages[4].role == 'assistant'
      assert messages[4].content[0].text == 'Hello, user!'
      assert not messages[4].content[0].cache_prompt

      assert messages[5].role == 'user'
      assert messages[5].content[0].text.startswith('Laaaaaaaast!')
      assert messages[5].content[0].cache_prompt
      if not codeact_agent.config.function_calling:
          assert (
              messages[5]
              .content[1]
              .text.endswith(
                  'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
              )
          )
  def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
      """Check that cache breakpoints are placed only where expected.

      After 15 user/agent exchanges, exactly four system/user messages should
      have ``cache_prompt`` set on their first content item: the system
      message, the initial user message, and the last two user messages
      ('User message 13' and 'User message 14').
      """
      # Add multiple user and agent messages
      for i in range(15):
          mock_event_stream.add_event(
              MessageAction(f'User message {i}'), EventSource.USER
          )
          mock_event_stream.add_event(
              MessageAction(f'Agent message {i}'), EventSource.AGENT
          )
      codeact_agent.reset()
      messages = codeact_agent._get_messages(
          Mock(history=mock_event_stream, max_iterations=10, iteration=5)
      )

      # Collect every system/user message carrying a cache breakpoint.
      cached_user_messages = [
          msg
          for msg in messages
          if msg.role in ('user', 'system') and msg.content[0].cache_prompt
      ]
      # system + initial user message + the last two user messages
      assert len(cached_user_messages) == 4

      if not codeact_agent.config.function_calling:
          assert (
              cached_user_messages[0].content[0].text.startswith('A chat between')
          )  # system message
      # Pin exact message numbers: a bare 'User message 1' prefix also matches
      # 'User message 10'..'14', so it could not tell which ones were cached.
      assert cached_user_messages[2].content[0].text.startswith('User message 13')
      assert cached_user_messages[3].content[0].text.startswith('User message 14')
  def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
      """Check message assembly for interleaved commands and observations.

      The stream holds: a user request, an `ls -l` command with its output, an
      agent message, and a `mkdir` command with empty output. The test checks
      where each piece lands in the assembled messages, that the system,
      initial user and two observation messages carry cache breakpoints, and
      (without function calling) the bash-tag rendering and turn reminder.
      """
      listing = 'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'

      # First exchange: the user asks for a listing, the agent runs `ls -l`.
      mock_event_stream.add_event(
          MessageAction("Let's list the contents of the current directory."),
          EventSource.USER,
      )
      ls_action = CmdRunAction('ls -l', thought='List files in current directory')
      mock_event_stream.add_event(ls_action, EventSource.AGENT)
      # `_id` is read only after the action has been added to the stream.
      mock_event_stream.add_event(
          CmdOutputObservation(
              content=listing,
              command_id=ls_action._id,
              command='ls -l',
              exit_code=0,
          ),
          EventSource.USER,
      )

      # Second exchange: the agent announces and runs `mkdir`.
      mock_event_stream.add_event(
          MessageAction("Now, let's create a new directory."), EventSource.AGENT
      )
      mkdir_action = CmdRunAction('mkdir new_directory', thought='Create a new directory')
      mock_event_stream.add_event(mkdir_action, EventSource.AGENT)
      mock_event_stream.add_event(
          CmdOutputObservation(
              content='',
              command_id=mkdir_action._id,
              command='mkdir new_directory',
              exit_code=0,
          ),
          EventSource.USER,
      )

      codeact_agent.reset()
      messages = codeact_agent._get_messages(
          Mock(history=mock_event_stream, max_iterations=5, iteration=0)
      )
      plain_text_mode = not codeact_agent.config.function_calling

      def first_items_mention(snippet):
          # Lazily scan each message's first content item for `snippet`.
          return any(snippet in m.content[0].text for m in messages)

      # user, included in the initial message
      assert messages[1].content[-1].text.startswith(
          "Let's list the contents of the current directory."
      )
      if plain_text_mode:
          assert first_items_mention(
              'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
          )  # agent
      assert first_items_mention(listing)  # user, observation
      assert first_items_mention("Now, let's create a new directory.")  # agent
      if plain_text_mode:
          assert messages[4].content[1].text.startswith('Create a new directory')  # agent
      assert first_items_mention('finished with exit code 0')  # user, observation
      assert messages[5].content[0].text.startswith('OBSERVATION:\n\n')  # user, observation

      # Cache breakpoints: system message, initial user message, and the two
      # most recent user-side (observation) messages.
      for cached_index, item in ((0, 0), (1, -1), (3, 0), (5, 0)):
          assert messages[cached_index].content[item].cache_prompt

      # reminder is added to the last user message
      if plain_text_mode:
          assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
  def test_prompt_caching_headers(codeact_agent, mock_event_stream):
      """Run CodeActAgent.step with completion and action parsing patched.

      The stand-in for ``openhands.llm.llm.litellm_completion`` asserts that
      the Anthropic prompt-caching beta header is passed. The action parser is
      forced to return a fixed MessageAction, which step() must hand back.
      Skipped in function-calling mode.

      NOTE(review): ``codeact_agent`` is built on ``Mock(spec=LLM)``, so
      step() presumably never reaches the patched ``litellm_completion``, and
      the header assertions in ``check_headers`` may never execute. Also, the
      state's ``history`` is a MagicMock, so the events written to
      ``mock_event_stream`` below do not appear to be read by step(). Confirm,
      and consider driving a real LLM instance if the header check matters.
      """
      if codeact_agent.config.function_calling:
          pytest.skip('Skipping this test for function calling')

      # Setup
      mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
      mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
      history = MagicMock()
      history.get_last_user_message.return_value = 'Hello, agent!'
      state = Mock(history=history, max_iterations=5, iteration=0)
      codeact_agent.reset()

      canned_reply = 'Hello! How can I assist you today?'

      def check_headers(**kwargs):
          # Stand-in for litellm_completion: verify the caching beta header,
          # then return an object shaped like a completion response.
          assert 'extra_headers' in kwargs
          headers = kwargs['extra_headers']
          assert 'anthropic-beta' in headers
          assert headers['anthropic-beta'] == 'prompt-caching-2024-07-31'
          reply_message = Mock(content=canned_reply)
          return Mock(choices=[Mock(message=reply_message)])

      # Patch completion first, then force the parser to yield a MessageAction.
      with patch(
          'openhands.llm.llm.litellm_completion', side_effect=check_headers
      ), patch.object(
          codeact_agent.action_parser,
          'parse',
          return_value=MessageAction(canned_reply),
      ):
          result = codeact_agent.step(state)

      assert isinstance(result, MessageAction)
      assert result.content == canned_reply