
Add prompt caching (Sonnet, Haiku only) (#3411)

* Add prompt caching

* remove anthropic-version from extra_headers

* change supports_prompt_caching method to attribute

* change caching strat and log cache statistics

* add reminder as a new message to fix caching

* fix unit test

* append reminder to the end of the last message content

* move token logs to post completion function

* fix unit test failure

* fix reminder and prompt caching

* unit tests for prompt caching

* add test

* clean up tests

* separate reminder, use latest two messages

* fix tests

---------

Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Kaushik Deka, 1 year ago
commit 5bb931e4d6
4 changed files with 300 additions and 29 deletions
  1. agenthub/codeact_agent/codeact_agent.py (+36 −25)
  2. openhands/core/message.py (+10 −1)
  3. openhands/llm/llm.py (+44 −3)
  4. tests/unit/test_prompt_caching.py (+210 −0)

agenthub/codeact_agent/codeact_agent.py (+36 −25)

@@ -172,26 +172,44 @@ class CodeActAgent(Agent):
         # prepare what we want to send to the LLM
         messages = self._get_messages(state)
 
-        response = self.llm.completion(
-            messages=[message.model_dump() for message in messages],
-            stop=[
+        params = {
+            'messages': [message.model_dump() for message in messages],
+            'stop': [
                 '</execute_ipython>',
                 '</execute_bash>',
                 '</execute_browse>',
             ],
-            temperature=0.0,
-        )
+            'temperature': 0.0,
+        }
+
+        if self.llm.supports_prompt_caching:
+            params['extra_headers'] = {
+                'anthropic-beta': 'prompt-caching-2024-07-31',
+            }
+
+        response = self.llm.completion(**params)
+
         return self.action_parser.parse(response)
 
     def _get_messages(self, state: State) -> list[Message]:
         messages: list[Message] = [
             Message(
                 role='system',
-                content=[TextContent(text=self.prompt_manager.system_message)],
+                content=[
+                    TextContent(
+                        text=self.prompt_manager.system_message,
+                        cache_prompt=self.llm.supports_prompt_caching,  # Cache system prompt
+                    )
+                ],
             ),
             Message(
                 role='user',
-                content=[TextContent(text=self.prompt_manager.initial_user_message)],
+                content=[
+                    TextContent(
+                        text=self.prompt_manager.initial_user_message,
+                        cache_prompt=self.llm.supports_prompt_caching,  # if the user asks the same query,
+                    )
+                ],
             ),
         ]
 
@@ -214,6 +232,16 @@ class CodeActAgent(Agent):
                 else:
                     messages.append(message)
 
+        # Add caching to the last 2 user messages
+        if self.llm.supports_prompt_caching:
+            user_turns_processed = 0
+            for message in reversed(messages):
+                if message.role == 'user' and user_turns_processed < 2:
+                    message.content[
+                        -1
+                    ].cache_prompt = True  # Last item inside the message content
+                    user_turns_processed += 1
+
         # the latest user message is important:
         # we want to remind the agent of the environment constraints
         latest_user_message = next(
@@ -225,25 +253,8 @@ class CodeActAgent(Agent):
             ),
             None,
         )
-
-        # Get the last user text inside content
         if latest_user_message:
-            latest_user_message_text = next(
-                (
-                    t
-                    for t in reversed(latest_user_message.content)
-                    if isinstance(t, TextContent)
-                )
-            )
-            # add a reminder to the prompt
             reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
-
-            if latest_user_message_text:
-                latest_user_message_text.text = (
-                    latest_user_message_text.text + reminder_text
-                )
-            else:
-                latest_user_message_text = TextContent(text=reminder_text)
-                latest_user_message.content.append(latest_user_message_text)
+            latest_user_message.content.append(TextContent(text=reminder_text))
 
         return messages

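For readers unfamiliar with the feature: the conditional `extra_headers` above opts the request into Anthropic's prompt-caching beta, and the `cache_prompt` flags mark which serialized content blocks carry a cache breakpoint. A rough, standalone sketch of the litellm call these params amount to (model name and prompt text are placeholders; it assumes an Anthropic API key is configured):

    import litellm

    response = litellm.completion(
        model='claude-3-5-sonnet-20240620',
        messages=[
            {
                'role': 'system',
                'content': [
                    {
                        'type': 'text',
                        'text': 'You are a helpful coding agent...',  # placeholder system prompt
                        'cache_control': {'type': 'ephemeral'},  # cache breakpoint
                    }
                ],
            },
            {'role': 'user', 'content': [{'type': 'text', 'text': 'List the files.'}]},
        ],
        temperature=0.0,
        extra_headers={'anthropic-beta': 'prompt-caching-2024-07-31'},
    )
    print(response.usage)  # on cached calls, usage also carries cache write/read token counts
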
openhands/core/message.py (+10 −1)

@@ -11,6 +11,7 @@ class ContentType(Enum):
 
 class Content(BaseModel):
     type: ContentType
+    cache_prompt: bool = False
 
     @model_serializer
     def serialize_model(self):
@@ -23,7 +24,13 @@ class TextContent(Content):
 
     @model_serializer
     def serialize_model(self):
-        return {'type': self.type.value, 'text': self.text}
+        data: dict[str, str | dict[str, str]] = {
+            'type': self.type.value,
+            'text': self.text,
+        }
+        if self.cache_prompt:
+            data['cache_control'] = {'type': 'ephemeral'}
+        return data
 
 
 class ImageContent(Content):
@@ -35,6 +42,8 @@ class ImageContent(Content):
         images: list[dict[str, str | dict[str, str]]] = []
         for url in self.image_urls:
             images.append({'type': self.type.value, 'image_url': {'url': url}})
+        if self.cache_prompt and images:
+            images[-1]['cache_control'] = {'type': 'ephemeral'}
         return images
 
 

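As a quick illustration (not part of the diff), this is roughly how the serializer above behaves, assuming TextContent is importable from openhands.core.message as in the agent code:

    from openhands.core.message import TextContent

    block = TextContent(text='You are a helpful coding agent.', cache_prompt=True)
    print(block.model_dump())
    # expected to contain 'cache_control': {'type': 'ephemeral'} next to the text block,
    # which is the marker Anthropic's prompt caching uses for a cache breakpoint
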
openhands/llm/llm.py (+44 −3)

@@ -35,6 +35,11 @@ __all__ = ['LLM']
 
 message_separator = '\n\n----------\n\n'
 
+cache_prompting_supported_models = [
+    'claude-3-5-sonnet-20240620',
+    'claude-3-haiku-20240307',
+]
+
 
 class LLM:
     """The LLM class represents a Language Model instance.
@@ -58,6 +63,9 @@ class LLM:
         self.config = copy.deepcopy(config)
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
+        self.supports_prompt_caching = (
+            self.config.model in cache_prompting_supported_models
+        )
 
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)
@@ -184,6 +192,7 @@ class LLM:
 
             # log the response
             message_back = resp['choices'][0]['message']['content']
+
             llm_response_logger.debug(message_back)
 
             # post-process to log costs
@@ -421,19 +430,51 @@ class LLM:
     def supports_vision(self):
         return litellm.supports_vision(self.config.model)
 
-    def _post_completion(self, response: str) -> None:
+    def _post_completion(self, response) -> None:
         """Post-process the completion response."""
         try:
             cur_cost = self.completion_cost(response)
         except Exception:
             cur_cost = 0
+
+        stats = ''
         if self.cost_metric_supported:
-            logger.info(
-                'Cost: %.2f USD | Accumulated Cost: %.2f USD',
+            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                 cur_cost,
                 self.metrics.accumulated_cost,
             )
 
+        usage = response.get('usage')
+
+        if usage:
+            input_tokens = usage.get('prompt_tokens')
+            output_tokens = usage.get('completion_tokens')
+
+            if input_tokens:
+                stats += 'Input tokens: ' + str(input_tokens) + '\n'
+
+            if output_tokens:
+                stats += 'Output tokens: ' + str(output_tokens) + '\n'
+
+            model_extra = usage.get('model_extra', {})
+
+            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
+            if cache_creation_input_tokens:
+                stats += (
+                    'Input tokens (cache write): '
+                    + str(cache_creation_input_tokens)
+                    + '\n'
+                )
+
+            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
+            if cache_read_input_tokens:
+                stats += (
+                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
+                )
+
+        if stats:
+            logger.info(stats)
+
     def get_token_count(self, messages):
         """Get the number of tokens in a list of messages.
 

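To make the new log line concrete, here is a standalone sketch of the same accounting applied to a hand-written usage dict (the numbers are illustrative; the field names follow the diff above):

    usage = {
        'prompt_tokens': 1450,
        'completion_tokens': 210,
        'model_extra': {
            'cache_creation_input_tokens': 1200,  # tokens written to the cache on this call
            'cache_read_input_tokens': 0,         # tokens served from the cache on this call
        },
    }

    stats = f"Input tokens: {usage['prompt_tokens']}\n"
    stats += f"Output tokens: {usage['completion_tokens']}\n"
    model_extra = usage.get('model_extra', {})
    if model_extra.get('cache_creation_input_tokens'):
        stats += f"Input tokens (cache write): {model_extra['cache_creation_input_tokens']}\n"
    if model_extra.get('cache_read_input_tokens'):
        stats += f"Input tokens (cache read): {model_extra['cache_read_input_tokens']}\n"
    print(stats)
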
tests/unit/test_prompt_caching.py (+210 −0)

@@ -0,0 +1,210 @@
+from unittest.mock import MagicMock, Mock
+
+import pytest
+
+from agenthub.codeact_agent.codeact_agent import CodeActAgent
+from openhands.core.config import AgentConfig, LLMConfig
+from openhands.events import EventSource, EventStream
+from openhands.events.action import CmdRunAction, MessageAction
+from openhands.events.observation import CmdOutputObservation
+from openhands.llm.llm import LLM
+from openhands.storage import get_file_store
+
+
+@pytest.fixture
+def mock_llm():
+    llm = Mock(spec=LLM)
+    llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
+    llm.supports_prompt_caching = True
+    return llm
+
+
+@pytest.fixture
+def mock_event_stream(tmp_path):
+    file_store = get_file_store('local', str(tmp_path))
+    return EventStream('test_session', file_store)
+
+
+@pytest.fixture
+def codeact_agent(mock_llm):
+    config = AgentConfig()
+    return CodeActAgent(mock_llm, config)
+
+
+def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
+    # Add some events to the stream
+    mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
+    mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
+    mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
+    mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
+    mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
+
+    codeact_agent.reset()
+    messages = codeact_agent._get_messages(
+        Mock(history=mock_event_stream, max_iterations=5, iteration=0)
+    )
+
+    assert (
+        len(messages) == 6
+    )  # System, initial user + user message, agent message, last user message
+    assert messages[0].content[0].cache_prompt
+    assert messages[1].role == 'user'
+    assert messages[1].content[0].text.endswith("LET'S START!")
+    assert messages[1].content[1].text.endswith('Initial user message')
+    assert messages[1].content[0].cache_prompt
+
+    assert messages[3].role == 'user'
+    assert messages[3].content[0].text == ('Hello, agent!')
+    assert messages[4].role == 'assistant'
+    assert messages[4].content[0].text == 'Hello, user!'
+    assert messages[5].role == 'user'
+    assert messages[5].content[0].text.startswith('Laaaaaaaast!')
+    assert messages[5].content[0].cache_prompt
+    assert (
+        messages[5]
+        .content[1]
+        .text.endswith(
+            'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
+        )
+    )
+
+
+def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
+    # Add multiple user and agent messages
+    for i in range(15):
+        mock_event_stream.add_event(
+            MessageAction(f'User message {i}'), EventSource.USER
+        )
+        mock_event_stream.add_event(
+            MessageAction(f'Agent message {i}'), EventSource.AGENT
+        )
+
+    codeact_agent.reset()
+    messages = codeact_agent._get_messages(
+        Mock(history=mock_event_stream, max_iterations=10, iteration=5)
+    )
+
+    # Check that only the last two user messages have cache_prompt=True
+    cached_user_messages = [
+        msg for msg in messages if msg.role == 'user' and msg.content[0].cache_prompt
+    ]
+    assert len(cached_user_messages) == 3  # Including the initial system message
+
+    # Verify that these are indeed the last two user messages
+    assert cached_user_messages[0].content[0].text.startswith('Here is an example')
+    assert cached_user_messages[1].content[0].text == 'User message 13'
+    assert cached_user_messages[2].content[0].text.startswith('User message 14')
+
+
+def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
+    # Add a mix of actions and observations
+    message_action_1 = MessageAction(
+        "Let's list the contents of the current directory."
+    )
+    mock_event_stream.add_event(message_action_1, EventSource.USER)
+
+    cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
+    mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)
+
+    cmd_observation_1 = CmdOutputObservation(
+        content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
+        command_id=cmd_action_1._id,
+        command='ls -l',
+        exit_code=0,
+    )
+    mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
+
+    message_action_2 = MessageAction("Now, let's create a new directory.")
+    mock_event_stream.add_event(message_action_2, EventSource.AGENT)
+
+    cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
+    mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)
+
+    cmd_observation_2 = CmdOutputObservation(
+        content='',
+        command_id=cmd_action_2._id,
+        command='mkdir new_directory',
+        exit_code=0,
+    )
+    mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
+
+    codeact_agent.reset()
+    messages = codeact_agent._get_messages(
+        Mock(history=mock_event_stream, max_iterations=5, iteration=0)
+    )
+
+    # Assert the presence of key elements in the messages
+    assert (
+        messages[1]
+        .content[1]
+        .text.startswith("Let's list the contents of the current directory.")
+    )  # user, included in the initial message
+    assert any(
+        'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
+        in msg.content[0].text
+        for msg in messages
+    )  # agent
+    assert any(
+        'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
+        in msg.content[0].text
+        for msg in messages
+    )  # user, observation
+    assert any(
+        "Now, let's create a new directory." in msg.content[0].text for msg in messages
+    )  # agent
+    assert messages[4].content[1].text.startswith('Create a new directory')  # agent
+    assert any(
+        'finished with exit code 0' in msg.content[0].text for msg in messages
+    )  # user, observation
+    assert (
+        messages[5].content[0].text.startswith('OBSERVATION:\n\n')
+    )  # user, observation
+
+    # prompt cache is added to the system message
+    assert messages[0].content[0].cache_prompt
+    # and the first initial user message
+    assert messages[1].content[0].cache_prompt
+    # and to the last two user messages
+    assert messages[3].content[0].cache_prompt
+    assert messages[5].content[0].cache_prompt
+
+    # reminder is added to the last user message
+    assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
+
+
+def test_prompt_caching_headers(codeact_agent, mock_event_stream):
+    # Setup
+    mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
+    mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
+
+    mock_short_term_history = MagicMock()
+    mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'
+
+    mock_state = Mock()
+    mock_state.history = mock_short_term_history
+    mock_state.max_iterations = 5
+    mock_state.iteration = 0
+
+    codeact_agent.reset()
+
+    # Replace mock LLM completion with a function that checks headers and returns a structured response
+    def check_headers(**kwargs):
+        assert 'extra_headers' in kwargs
+        assert 'anthropic-beta' in kwargs['extra_headers']
+        assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
+
+        # Create a mock response with the expected structure
+        mock_response = Mock()
+        mock_response.choices = [Mock()]
+        mock_response.choices[0].message = Mock()
+        mock_response.choices[0].message.content = 'Hello! How can I assist you today?'
+        return mock_response
+
+    codeact_agent.llm.completion = check_headers
+
+    # Act
+    result = codeact_agent.step(mock_state)
+
+    # Assert
+    assert isinstance(result, MessageAction)
+    assert 'Hello! How can I assist you today?' in result.content
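
The new tests can be run on their own with pytest, assuming the project's dev dependencies are installed: pytest tests/unit/test_prompt_caching.py -q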