# test_prompt_caching.py (8.3 KB)


  1. from unittest.mock import MagicMock, Mock, patch
  2. import pytest
  3. from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from openhands.core.config import AgentConfig, LLMConfig
  5. from openhands.events import EventSource, EventStream
  6. from openhands.events.action import CmdRunAction, MessageAction
  7. from openhands.events.observation import CmdOutputObservation
  8. from openhands.llm.llm import LLM
  9. from openhands.storage import get_file_store
  10. @pytest.fixture
  11. def mock_llm():
  12. llm = Mock(spec=LLM)
  13. llm.config = LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True)
  14. llm.is_caching_prompt_active.return_value = True
  15. return llm
  16. @pytest.fixture
  17. def mock_event_stream(tmp_path):
  18. file_store = get_file_store('local', str(tmp_path))
  19. return EventStream('test_session', file_store)
  20. @pytest.fixture
  21. def codeact_agent(mock_llm):
  22. config = AgentConfig()
  23. return CodeActAgent(mock_llm, config)
  24. def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
  25. # Add some events to the stream
  26. mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
  27. mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
  28. mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
  29. mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
  30. mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
  31. codeact_agent.reset()
  32. messages = codeact_agent._get_messages(
  33. Mock(history=mock_event_stream, max_iterations=5, iteration=0)
  34. )
  35. assert (
  36. len(messages) == 6
  37. ) # System, initial user + user message, agent message, last user message
  38. assert messages[0].content[0].cache_prompt # system message
  39. assert messages[1].role == 'user'
  40. assert messages[1].content[0].text.endswith("LET'S START!")
  41. assert messages[1].content[1].text.endswith('Initial user message')
  42. assert messages[1].content[0].cache_prompt
  43. assert messages[3].role == 'user'
  44. assert messages[3].content[0].text == ('Hello, agent!')
  45. assert messages[3].content[0].cache_prompt
  46. assert messages[4].role == 'assistant'
  47. assert messages[4].content[0].text == 'Hello, user!'
  48. assert not messages[4].content[0].cache_prompt
  49. assert messages[5].role == 'user'
  50. assert messages[5].content[0].text.startswith('Laaaaaaaast!')
  51. assert messages[5].content[0].cache_prompt
  52. assert (
  53. messages[5]
  54. .content[1]
  55. .text.endswith(
  56. 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
  57. )
  58. )
  59. def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
  60. # Add multiple user and agent messages
  61. for i in range(15):
  62. mock_event_stream.add_event(
  63. MessageAction(f'User message {i}'), EventSource.USER
  64. )
  65. mock_event_stream.add_event(
  66. MessageAction(f'Agent message {i}'), EventSource.AGENT
  67. )
  68. codeact_agent.reset()
  69. messages = codeact_agent._get_messages(
  70. Mock(history=mock_event_stream, max_iterations=10, iteration=5)
  71. )
  72. # Check that only the last two user messages have cache_prompt=True
  73. cached_user_messages = [
  74. msg
  75. for msg in messages
  76. if msg.role in ('user', 'system') and msg.content[0].cache_prompt
  77. ]
  78. assert (
  79. len(cached_user_messages) == 4
  80. ) # Including the initial system+user + 2 last user message
  81. # Verify that these are indeed the last two user messages (from start)
  82. assert (
  83. cached_user_messages[0].content[0].text.startswith('A chat between')
  84. ) # system message
  85. assert cached_user_messages[2].content[0].text.startswith('User message 1')
  86. assert cached_user_messages[3].content[0].text.startswith('User message 1')
  87. def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
  88. # Add a mix of actions and observations
  89. message_action_1 = MessageAction(
  90. "Let's list the contents of the current directory."
  91. )
  92. mock_event_stream.add_event(message_action_1, EventSource.USER)
  93. cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
  94. mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)
  95. cmd_observation_1 = CmdOutputObservation(
  96. content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
  97. command_id=cmd_action_1._id,
  98. command='ls -l',
  99. exit_code=0,
  100. )
  101. mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
  102. message_action_2 = MessageAction("Now, let's create a new directory.")
  103. mock_event_stream.add_event(message_action_2, EventSource.AGENT)
  104. cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
  105. mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)
  106. cmd_observation_2 = CmdOutputObservation(
  107. content='',
  108. command_id=cmd_action_2._id,
  109. command='mkdir new_directory',
  110. exit_code=0,
  111. )
  112. mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
  113. codeact_agent.reset()
  114. messages = codeact_agent._get_messages(
  115. Mock(history=mock_event_stream, max_iterations=5, iteration=0)
  116. )
  117. # Assert the presence of key elements in the messages
  118. assert (
  119. messages[1]
  120. .content[1]
  121. .text.startswith("Let's list the contents of the current directory.")
  122. ) # user, included in the initial message
  123. assert any(
  124. 'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
  125. in msg.content[0].text
  126. for msg in messages
  127. ) # agent
  128. assert any(
  129. 'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
  130. in msg.content[0].text
  131. for msg in messages
  132. ) # user, observation
  133. assert any(
  134. "Now, let's create a new directory." in msg.content[0].text for msg in messages
  135. ) # agent
  136. assert messages[4].content[1].text.startswith('Create a new directory') # agent
  137. assert any(
  138. 'finished with exit code 0' in msg.content[0].text for msg in messages
  139. ) # user, observation
  140. assert (
  141. messages[5].content[0].text.startswith('OBSERVATION:\n\n')
  142. ) # user, observation
  143. # prompt cache is added to the system message
  144. assert messages[0].content[0].cache_prompt
  145. # and the first initial user message
  146. assert messages[1].content[0].cache_prompt
  147. # and to the last two user messages
  148. assert messages[3].content[0].cache_prompt
  149. assert messages[5].content[0].cache_prompt
  150. # reminder is added to the last user message
  151. assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
  152. def test_prompt_caching_headers(codeact_agent, mock_event_stream):
  153. # Setup
  154. mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
  155. mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
  156. mock_short_term_history = MagicMock()
  157. mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'
  158. mock_state = Mock()
  159. mock_state.history = mock_short_term_history
  160. mock_state.max_iterations = 5
  161. mock_state.iteration = 0
  162. codeact_agent.reset()
  163. # Create a mock for litellm_completion
  164. def check_headers(**kwargs):
  165. assert 'extra_headers' in kwargs
  166. assert 'anthropic-beta' in kwargs['extra_headers']
  167. assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
  168. # Create a mock response with the expected structure
  169. mock_response = Mock()
  170. mock_response.choices = [Mock()]
  171. mock_response.choices[0].message = Mock()
  172. mock_response.choices[0].message.content = 'Hello! How can I assist you today?'
  173. return mock_response
  174. # Use patch to replace litellm_completion with our check_headers function
  175. with patch('openhands.llm.llm.litellm_completion', side_effect=check_headers):
  176. # Also patch the action parser to return a MessageAction
  177. with patch.object(
  178. codeact_agent.action_parser,
  179. 'parse',
  180. return_value=MessageAction('Hello! How can I assist you today?'),
  181. ):
  182. # Act
  183. result = codeact_agent.step(mock_state)
  184. # Assert
  185. assert isinstance(result, MessageAction)
  186. assert result.content == 'Hello! How can I assist you today?'