test_prompt_caching.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from unittest.mock import Mock, patch
  2. import pytest
  3. from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from openhands.core.config import AgentConfig, LLMConfig
  5. from openhands.events.action import CmdRunAction, MessageAction
  6. from openhands.events.observation import CmdOutputObservation
  7. from openhands.llm.llm import LLM
  8. @pytest.fixture
  9. def mock_llm():
  10. llm = Mock(spec=LLM)
  11. llm.config = LLMConfig(model='claude-3-5-sonnet-20241022', caching_prompt=True)
  12. llm.is_caching_prompt_active.return_value = True
  13. return llm
  14. @pytest.fixture(params=[False, True])
  15. def codeact_agent(mock_llm, request):
  16. config = AgentConfig()
  17. config.function_calling = request.param
  18. return CodeActAgent(mock_llm, config)
  19. def response_mock(content: str):
  20. class MockModelResponse:
  21. def __init__(self, content):
  22. self.choices = [
  23. {
  24. 'message': {
  25. 'content': content,
  26. 'tool_calls': [
  27. {
  28. 'function': {
  29. 'name': 'execute_bash',
  30. 'arguments': '{}',
  31. }
  32. }
  33. ],
  34. }
  35. }
  36. ]
  37. def model_dump(self):
  38. return {'choices': self.choices}
  39. return MockModelResponse(content)
  40. def test_get_messages_with_reminder(codeact_agent: CodeActAgent):
  41. # Add some events to history
  42. history = list()
  43. message_action_1 = MessageAction('Initial user message')
  44. message_action_1._source = 'user'
  45. history.append(message_action_1)
  46. message_action_2 = MessageAction('Sure!')
  47. message_action_2._source = 'assistant'
  48. history.append(message_action_2)
  49. message_action_3 = MessageAction('Hello, agent!')
  50. message_action_3._source = 'user'
  51. history.append(message_action_3)
  52. message_action_4 = MessageAction('Hello, user!')
  53. message_action_4._source = 'assistant'
  54. history.append(message_action_4)
  55. message_action_5 = MessageAction('Laaaaaaaast!')
  56. message_action_5._source = 'user'
  57. history.append(message_action_5)
  58. codeact_agent.reset()
  59. messages = codeact_agent._get_messages(
  60. Mock(history=history, max_iterations=5, iteration=0)
  61. )
  62. assert (
  63. len(messages) == 6
  64. ) # System, initial user + user message, agent message, last user message
  65. assert messages[0].content[0].cache_prompt # system message
  66. assert messages[1].role == 'user'
  67. if not codeact_agent.config.function_calling:
  68. assert messages[1].content[0].text.endswith("LET'S START!")
  69. assert messages[1].content[1].text.endswith('Initial user message')
  70. else:
  71. assert messages[1].content[0].text.endswith('Initial user message')
  72. # we add cache breakpoint to the last 3 user messages
  73. assert messages[1].content[-1].cache_prompt
  74. assert messages[3].role == 'user'
  75. assert messages[3].content[0].text == ('Hello, agent!')
  76. assert messages[3].content[0].cache_prompt
  77. assert messages[4].role == 'assistant'
  78. assert messages[4].content[0].text == 'Hello, user!'
  79. assert not messages[4].content[0].cache_prompt
  80. assert messages[5].role == 'user'
  81. assert messages[5].content[0].text.startswith('Laaaaaaaast!')
  82. assert messages[5].content[0].cache_prompt
  83. if not codeact_agent.config.function_calling:
  84. assert (
  85. messages[5]
  86. .content[1]
  87. .text.endswith(
  88. 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
  89. )
  90. )
  91. def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
  92. history = list()
  93. # Add multiple user and agent messages
  94. for i in range(15):
  95. message_action_user = MessageAction(f'User message {i}')
  96. message_action_user._source = 'user'
  97. history.append(message_action_user)
  98. message_action_agent = MessageAction(f'Agent message {i}')
  99. message_action_agent._source = 'assistant'
  100. history.append(message_action_agent)
  101. codeact_agent.reset()
  102. messages = codeact_agent._get_messages(
  103. Mock(history=history, max_iterations=10, iteration=5)
  104. )
  105. # Check that only the last two user messages have cache_prompt=True
  106. cached_user_messages = [
  107. msg
  108. for msg in messages
  109. if msg.role in ('user', 'system') and msg.content[0].cache_prompt
  110. ]
  111. assert (
  112. len(cached_user_messages) == 4
  113. ) # Including the initial system+user + 2 last user message
  114. # Verify that these are indeed the last two user messages (from start)
  115. if not codeact_agent.config.function_calling:
  116. assert (
  117. cached_user_messages[0].content[0].text.startswith('A chat between')
  118. ) # system message
  119. assert cached_user_messages[2].content[0].text.startswith('User message 1')
  120. assert cached_user_messages[3].content[0].text.startswith('User message 1')
  121. def test_get_messages_with_cmd_action(codeact_agent: CodeActAgent):
  122. if codeact_agent.config.function_calling:
  123. pytest.skip('Skipping this test for function calling')
  124. history = list()
  125. # Add a mix of actions and observations
  126. message_action_1 = MessageAction(
  127. "Let's list the contents of the current directory."
  128. )
  129. message_action_1._source = 'user'
  130. history.append(message_action_1)
  131. cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
  132. cmd_action_1._source = 'agent'
  133. cmd_action_1._id = 'cmd_1'
  134. history.append(cmd_action_1)
  135. cmd_observation_1 = CmdOutputObservation(
  136. content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
  137. command_id=cmd_action_1._id,
  138. command='ls -l',
  139. exit_code=0,
  140. )
  141. cmd_observation_1._source = 'user'
  142. history.append(cmd_observation_1)
  143. message_action_2 = MessageAction("Now, let's create a new directory.")
  144. message_action_2._source = 'agent'
  145. history.append(message_action_2)
  146. cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
  147. cmd_action_2._source = 'agent'
  148. cmd_action_2._id = 'cmd_2'
  149. history.append(cmd_action_2)
  150. cmd_observation_2 = CmdOutputObservation(
  151. content='',
  152. command_id=cmd_action_2._id,
  153. command='mkdir new_directory',
  154. exit_code=0,
  155. )
  156. cmd_observation_2._source = 'user'
  157. history.append(cmd_observation_2)
  158. codeact_agent.reset()
  159. messages = codeact_agent._get_messages(
  160. Mock(history=history, max_iterations=5, iteration=0)
  161. )
  162. # Assert the presence of key elements in the messages
  163. assert (
  164. messages[1]
  165. .content[-1]
  166. .text.startswith("Let's list the contents of the current directory.")
  167. ) # user, included in the initial message
  168. if not codeact_agent.config.function_calling:
  169. assert any(
  170. 'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
  171. in msg.content[0].text
  172. for msg in messages
  173. ) # agent
  174. assert any(
  175. 'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
  176. in msg.content[0].text
  177. for msg in messages
  178. ) # user, observation
  179. assert any(
  180. "Now, let's create a new directory." in msg.content[0].text for msg in messages
  181. ) # agent
  182. if not codeact_agent.config.function_calling:
  183. assert messages[4].content[1].text.startswith('Create a new directory') # agent
  184. assert any(
  185. 'finished with exit code 0' in msg.content[0].text for msg in messages
  186. ) # user, observation
  187. assert (
  188. messages[5].content[0].text.startswith('OBSERVATION:\n\n')
  189. ) # user, observation
  190. # prompt cache is added to the system message
  191. assert messages[0].content[0].cache_prompt
  192. # and the first initial user message
  193. assert messages[1].content[-1].cache_prompt
  194. # and to the last two user messages
  195. assert messages[3].content[0].cache_prompt
  196. assert messages[5].content[0].cache_prompt
  197. # reminder is added to the last user message
  198. if not codeact_agent.config.function_calling:
  199. assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
  200. def test_prompt_caching_headers(codeact_agent: CodeActAgent):
  201. history = list()
  202. if codeact_agent.config.function_calling:
  203. pytest.skip('Skipping this test for function calling')
  204. # Setup
  205. history.append(MessageAction('Hello, agent!'))
  206. history.append(MessageAction('Hello, user!'))
  207. mock_state = Mock()
  208. mock_state.history = history
  209. mock_state.max_iterations = 5
  210. mock_state.iteration = 0
  211. codeact_agent.reset()
  212. # Create a mock for litellm_completion
  213. def check_headers(**kwargs):
  214. assert 'extra_headers' in kwargs
  215. assert 'anthropic-beta' in kwargs['extra_headers']
  216. assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
  217. # Create a mock response with the expected structure
  218. mock_response = Mock()
  219. mock_response.choices = [Mock()]
  220. mock_response.choices[0].message = Mock()
  221. mock_response.choices[0].message.content = 'Hello! How can I assist you today?'
  222. return mock_response
  223. # Use patch to replace litellm_completion with our check_headers function
  224. with patch('openhands.llm.llm.litellm_completion', side_effect=check_headers):
  225. # Also patch the action parser to return a MessageAction
  226. with patch.object(
  227. codeact_agent.action_parser,
  228. 'parse',
  229. return_value=MessageAction('Hello! How can I assist you today?'),
  230. ):
  231. # Act
  232. result = codeact_agent.step(mock_state)
  233. # Assert
  234. assert isinstance(result, MessageAction)
  235. assert result.content == 'Hello! How can I assist you today?'