test_prompt_caching.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. from unittest.mock import MagicMock, Mock
  2. import pytest
  3. from agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from openhands.core.config import AgentConfig, LLMConfig
  5. from openhands.events import EventSource, EventStream
  6. from openhands.events.action import CmdRunAction, MessageAction
  7. from openhands.events.observation import CmdOutputObservation
  8. from openhands.llm.llm import LLM
  9. from openhands.storage import get_file_store
  10. @pytest.fixture
  11. def mock_llm():
  12. llm = Mock(spec=LLM)
  13. llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
  14. llm.supports_prompt_caching = True
  15. return llm
  16. @pytest.fixture
  17. def mock_event_stream(tmp_path):
  18. file_store = get_file_store('local', str(tmp_path))
  19. return EventStream('test_session', file_store)
  20. @pytest.fixture
  21. def codeact_agent(mock_llm):
  22. config = AgentConfig()
  23. return CodeActAgent(mock_llm, config)
  24. def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
  25. # Add some events to the stream
  26. mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
  27. mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
  28. mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
  29. mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
  30. mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
  31. codeact_agent.reset()
  32. messages = codeact_agent._get_messages(
  33. Mock(history=mock_event_stream, max_iterations=5, iteration=0)
  34. )
  35. assert (
  36. len(messages) == 6
  37. ) # System, initial user + user message, agent message, last user message
  38. assert messages[0].content[0].cache_prompt
  39. assert messages[1].role == 'user'
  40. assert messages[1].content[0].text.endswith("LET'S START!")
  41. assert messages[1].content[1].text.endswith('Initial user message')
  42. assert messages[1].content[0].cache_prompt
  43. assert messages[3].role == 'user'
  44. assert messages[3].content[0].text == ('Hello, agent!')
  45. assert messages[4].role == 'assistant'
  46. assert messages[4].content[0].text == 'Hello, user!'
  47. assert messages[5].role == 'user'
  48. assert messages[5].content[0].text.startswith('Laaaaaaaast!')
  49. assert messages[5].content[0].cache_prompt
  50. assert (
  51. messages[5]
  52. .content[1]
  53. .text.endswith(
  54. 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
  55. )
  56. )
  57. def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
  58. # Add multiple user and agent messages
  59. for i in range(15):
  60. mock_event_stream.add_event(
  61. MessageAction(f'User message {i}'), EventSource.USER
  62. )
  63. mock_event_stream.add_event(
  64. MessageAction(f'Agent message {i}'), EventSource.AGENT
  65. )
  66. codeact_agent.reset()
  67. messages = codeact_agent._get_messages(
  68. Mock(history=mock_event_stream, max_iterations=10, iteration=5)
  69. )
  70. # Check that only the last two user messages have cache_prompt=True
  71. cached_user_messages = [
  72. msg for msg in messages if msg.role == 'user' and msg.content[0].cache_prompt
  73. ]
  74. assert len(cached_user_messages) == 3 # Including the initial system message
  75. # Verify that these are indeed the last two user messages
  76. assert cached_user_messages[0].content[0].text.startswith('Here is an example')
  77. assert cached_user_messages[1].content[0].text == 'User message 13'
  78. assert cached_user_messages[2].content[0].text.startswith('User message 14')
  79. def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
  80. # Add a mix of actions and observations
  81. message_action_1 = MessageAction(
  82. "Let's list the contents of the current directory."
  83. )
  84. mock_event_stream.add_event(message_action_1, EventSource.USER)
  85. cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
  86. mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)
  87. cmd_observation_1 = CmdOutputObservation(
  88. content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
  89. command_id=cmd_action_1._id,
  90. command='ls -l',
  91. exit_code=0,
  92. )
  93. mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
  94. message_action_2 = MessageAction("Now, let's create a new directory.")
  95. mock_event_stream.add_event(message_action_2, EventSource.AGENT)
  96. cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
  97. mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)
  98. cmd_observation_2 = CmdOutputObservation(
  99. content='',
  100. command_id=cmd_action_2._id,
  101. command='mkdir new_directory',
  102. exit_code=0,
  103. )
  104. mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
  105. codeact_agent.reset()
  106. messages = codeact_agent._get_messages(
  107. Mock(history=mock_event_stream, max_iterations=5, iteration=0)
  108. )
  109. # Assert the presence of key elements in the messages
  110. assert (
  111. messages[1]
  112. .content[1]
  113. .text.startswith("Let's list the contents of the current directory.")
  114. ) # user, included in the initial message
  115. assert any(
  116. 'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
  117. in msg.content[0].text
  118. for msg in messages
  119. ) # agent
  120. assert any(
  121. 'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
  122. in msg.content[0].text
  123. for msg in messages
  124. ) # user, observation
  125. assert any(
  126. "Now, let's create a new directory." in msg.content[0].text for msg in messages
  127. ) # agent
  128. assert messages[4].content[1].text.startswith('Create a new directory') # agent
  129. assert any(
  130. 'finished with exit code 0' in msg.content[0].text for msg in messages
  131. ) # user, observation
  132. assert (
  133. messages[5].content[0].text.startswith('OBSERVATION:\n\n')
  134. ) # user, observation
  135. # prompt cache is added to the system message
  136. assert messages[0].content[0].cache_prompt
  137. # and the first initial user message
  138. assert messages[1].content[0].cache_prompt
  139. # and to the last two user messages
  140. assert messages[3].content[0].cache_prompt
  141. assert messages[5].content[0].cache_prompt
  142. # reminder is added to the last user message
  143. assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
  144. def test_prompt_caching_headers(codeact_agent, mock_event_stream):
  145. # Setup
  146. mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
  147. mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
  148. mock_short_term_history = MagicMock()
  149. mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'
  150. mock_state = Mock()
  151. mock_state.history = mock_short_term_history
  152. mock_state.max_iterations = 5
  153. mock_state.iteration = 0
  154. codeact_agent.reset()
  155. # Replace mock LLM completion with a function that checks headers and returns a structured response
  156. def check_headers(**kwargs):
  157. assert 'extra_headers' in kwargs
  158. assert 'anthropic-beta' in kwargs['extra_headers']
  159. assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
  160. # Create a mock response with the expected structure
  161. mock_response = Mock()
  162. mock_response.choices = [Mock()]
  163. mock_response.choices[0].message = Mock()
  164. mock_response.choices[0].message.content = 'Hello! How can I assist you today?'
  165. return mock_response
  166. codeact_agent.llm.completion = check_headers
  167. # Act
  168. result = codeact_agent.step(mock_state)
  169. # Assert
  170. assert isinstance(result, MessageAction)
  171. assert 'Hello! How can I assist you today?' in result.content