Explorar o código

(feat) making prompt caching optional instead of enabled default (#3689)

* (feat) making prompt caching optional instead of enabled default

At present, only the Claude models support prompt caching as a experimental feature, therefore, this feature should be implemented as an optional setting rather than being enabled by default.

Signed-off-by: Yi Lin <teroincn@gmail.com>

* handle the conflict

* fix unittest mock return value

* fix lint error in whitespace

---------

Signed-off-by: Yi Lin <teroincn@gmail.com>
niliy01 hai 1 ano
pai
achega
82a154f7e7

+ 16 - 10
agenthub/codeact_agent/codeact_agent.py

@@ -201,6 +201,12 @@ class CodeActAgent(Agent):
             ],
             'temperature': 0.0,
         }
+
+        if self.llm.is_caching_prompt_active():
+            params['extra_headers'] = {
+                'anthropic-beta': 'prompt-caching-2024-07-31',
+            }
+
         try:
             response = self.llm.completion(**params)
         except Exception:
@@ -217,7 +223,7 @@ class CodeActAgent(Agent):
                 content=[
                     TextContent(
                         text=self.prompt_manager.system_message,
-                        cache_prompt=self.llm.supports_prompt_caching,
+                        cache_prompt=self.llm.is_caching_prompt_active(),  # Cache system prompt
                     )
                 ],
             ),
@@ -226,7 +232,7 @@ class CodeActAgent(Agent):
                 content=[
                     TextContent(
                         text=self.prompt_manager.initial_user_message,
-                        cache_prompt=self.llm.supports_prompt_caching,
+                        cache_prompt=self.llm.is_caching_prompt_active(),  # if the user asks the same query,
                     )
                 ],
             ),
@@ -252,14 +258,14 @@ class CodeActAgent(Agent):
                     messages.append(message)
 
         # Add caching to the last 2 user messages
-        if self.llm.supports_prompt_caching:
-            user_messages = list(
-                islice((m for m in reversed(messages) if m.role == 'user'), 2)
-            )
-            for message in user_messages:
-                message.content[
-                    -1
-                ].cache_prompt = True  # Last item inside the message content
+        if self.llm.is_caching_prompt_active():
+            user_turns_processed = 0
+            for message in reversed(messages):
+                if message.role == 'user' and user_turns_processed < 2:
+                    message.content[
+                        -1
+                    ].cache_prompt = True  # Last item inside the message content
+                    user_turns_processed += 1
 
         # The latest user message is important:
         # we want to remind the agent of the environment constraints

+ 3 - 0
config.template.toml

@@ -141,6 +141,9 @@ model = "gpt-4o"
 # Drop any unmapped (unsupported) params without causing an exception
 #drop_params = false
 
+# Using the prompt caching feature provided by the LLM
+#caching_prompt = false
+
 # Base URL for the OLLAMA API
 #ollama_base_url = ""
 

+ 1 - 0
docs/modules/usage/llms/llms.md

@@ -44,6 +44,7 @@ The following environment variables might be necessary for some LLMs/providers:
 * `LLM_EMBEDDING_DEPLOYMENT_NAME`
 * `LLM_DROP_PARAMS`
 * `LLM_DISABLE_VISION`
+* `LLM_CACHING_PROMPT`
 
 We have a few guides for running OpenHands with specific model providers:
 

+ 2 - 0
openhands/core/config.py

@@ -52,6 +52,7 @@ class LLMConfig:
         ollama_base_url: The base URL for the OLLAMA API.
         drop_params: Drop any unmapped (unsupported) params without causing an exception.
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
+        caching_prompt: Using the prompt caching feature provided by the LLM.
     """
 
     model: str = 'gpt-4o'
@@ -80,6 +81,7 @@ class LLMConfig:
     ollama_base_url: str | None = None
     drop_params: bool | None = None
     disable_vision: bool | None = None
+    caching_prompt: bool = False
 
     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

+ 1 - 0
openhands/core/schema/config.py

@@ -21,6 +21,7 @@ class ConfigType(str, Enum):
     LLM_API_KEY = 'LLM_API_KEY'
     LLM_API_VERSION = 'LLM_API_VERSION'
     LLM_BASE_URL = 'LLM_BASE_URL'
+    LLM_CACHING_PROMPT = 'LLM_CACHING_PROMPT'
     LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER'
     LLM_DROP_PARAMS = 'LLM_DROP_PARAMS'
     LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL'

+ 12 - 6
openhands/llm/llm.py

@@ -70,11 +70,6 @@ class LLM:
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)
 
-        self.supports_prompt_caching = (
-            self.vision_is_active()
-            and self.config.model in cache_prompting_supported_models
-        )
-
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
@@ -190,7 +185,7 @@ class LLM:
                 if debug_str:
                     debug_message += message_separator + debug_str
 
-            if self.supports_prompt_caching:
+            if self.is_caching_prompt_active():
                 # Anthropic-specific prompt caching
                 if 'claude-3' in self.config.model:
                     kwargs['extra_headers'] = {
@@ -467,6 +462,17 @@ class LLM:
         except Exception:
             return False
 
+    def is_caching_prompt_active(self) -> bool:
+        """Check if prompt caching is enabled and supported for current model.
+
+        Returns:
+            boolean: True if prompt caching is active for the given model.
+        """
+        return (
+            self.config.caching_prompt is True
+            and self.config.model in cache_prompting_supported_models
+        )
+
     def _post_completion(self, response) -> None:
         """Post-process the completion response."""
         try:

+ 2 - 2
tests/unit/test_prompt_caching.py

@@ -14,8 +14,8 @@ from openhands.storage import get_file_store
 @pytest.fixture
 def mock_llm():
     llm = Mock(spec=LLM)
-    llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
-    llm.supports_prompt_caching = True
+    llm.config = LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True)
+    llm.is_caching_prompt_active.return_value = True
     return llm