Procházet zdrojové kódy

Log cache hit/miss for deepseek (#4343)

Engel Nyst před 1 rokem
rodič
revize
caa77cf7a6
1 změnil soubory, kde provedl 17 přidání a 16 odebrání
  1. 17 16
      openhands/llm/llm.py

+ 17 - 16
openhands/llm/llm.py

@@ -9,7 +9,7 @@ from openhands.core.config import LLMConfig
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
     import litellm
-from litellm import ModelInfo
+from litellm import ModelInfo, PromptTokensDetails
 from litellm import completion as litellm_completion
 from litellm import completion_cost as litellm_completion_cost
 from litellm.exceptions import (
@@ -288,22 +288,23 @@ class LLM(RetryMixin, DebugMixin):
                     + '\n'
                 )
 
-            # read the prompt caching status as received from the provider
-            model_extra = usage.get('model_extra', {})
-
-            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
-            if cache_creation_input_tokens:
-                stats += (
-                    'Input tokens (cache write): '
-                    + str(cache_creation_input_tokens)
-                    + '\n'
-                )
+            # read the prompt cache hit, if any
+            prompt_tokens_details: PromptTokensDetails = usage.get(
+                'prompt_tokens_details'
+            )
+            cache_hit_tokens = (
+                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
+            )
+            if cache_hit_tokens:
+                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'
 
-            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
-            if cache_read_input_tokens:
-                stats += (
-                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
-                )
+            # For Anthropic, the cache writes have a different cost than regular input tokens
+            # but litellm doesn't separate them in the usage stats
+            # so we can read it from the provider-specific extra field
+            model_extra = usage.get('model_extra', {})
+            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
+            if cache_write_tokens:
+                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
 
         # log the stats
         if stats: