
Small fix and addition for token counting (#5550)

Co-authored-by: openhands <openhands@all-hands.dev>
Engel Nyst · 1 year ago · parent commit 590ebb6e47

config.template.toml (+4 -0)

@@ -172,6 +172,10 @@ model = "gpt-4o"
 # If the model is vision capable, this option allows image processing to be disabled (useful for cost reduction).
 #disable_vision = true
 
+# Custom tokenizer to use for token counting
+# https://docs.litellm.ai/docs/completion/token_usage
+#custom_tokenizer = ""
+
 [llm.gpt4o-mini]
 api_key = "your-api-key"
 model = "gpt-4o"
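
For illustration, a minimal sketch (not part of the commit) of how an uncommented value for the new key would be read; the tokenizer identifier below is a hypothetical HuggingFace repo id, not one the commit prescribes:

# Sketch of the new TOML key in use; requires Python 3.11+ for tomllib.
import tomllib

toml_text = """
[llm]
model = "gpt-4o"
custom_tokenizer = "my-org/my-tokenizer"  # hypothetical HuggingFace id
"""

# Parse the [llm] section the same way a config loader would.
llm_section = tomllib.loads(toml_text)["llm"]
print(llm_section["custom_tokenizer"])  # -> my-org/my-tokenizer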

openhands/core/config/llm_config.py (+2 -0)

@@ -43,6 +43,7 @@ class LLMConfig:
         log_completions: Whether to log LLM completions to the state.
         log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
+        custom_tokenizer: A custom tokenizer to use for token counting.
     """
 
     model: str = 'claude-3-5-sonnet-20241022'
@@ -77,6 +78,7 @@ class LLMConfig:
     log_completions: bool = False
     log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
     draft_editor: Optional['LLMConfig'] = None
+    custom_tokenizer: str | None = None
 
     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

openhands/llm/llm.py (+39 -7)

@@ -25,6 +25,7 @@ from litellm.exceptions import (
     ServiceUnavailableError,
 )
 from litellm.types.utils import CostPerToken, ModelResponse, Usage
+from litellm.utils import create_pretrained_tokenizer
 
 from openhands.core.exceptions import CloudFlareBlockageError
 from openhands.core.logger import openhands_logger as logger
@@ -122,6 +123,13 @@ class LLM(RetryMixin, DebugMixin):
         if self.is_function_calling_active():
             logger.debug('LLM: model supports function calling')
 
+        # if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
+        if self.config.custom_tokenizer is not None:
+            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
+        else:
+            self.tokenizer = None
+
+        # set up the completion function
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
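
For context, create_pretrained_tokenizer is litellm's helper (see the token_usage docs linked in the TOML comment) that loads a HuggingFace tokenizer and wraps it in the structure token_counter accepts. A standalone sketch, assuming network access and a resolvable (here hypothetical) identifier:

# Sketch of the litellm helper used in __init__ above; the identifier is
# a hypothetical HuggingFace repo id and must exist on the Hub to load.
from litellm.utils import create_pretrained_tokenizer

tokenizer = create_pretrained_tokenizer('my-org/my-tokenizer')
# The result is what gets passed as custom_tokenizer= to token_counter.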
@@ -491,19 +499,43 @@ class LLM(RetryMixin, DebugMixin):
 
         return cur_cost
 
-    def get_token_count(self, messages) -> int:
-        """Get the number of tokens in a list of messages.
+    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
+        """Get the number of tokens in a list of messages. Use dicts for better token counting.
 
         Args:
-            messages (list): A list of messages.
-
+            messages (list): A list of messages, either as a list of dicts or as a list of Message objects.
         Returns:
             int: The number of tokens.
         """
+        # attempt to convert Message objects to dicts, since litellm expects dicts
+        if (
+            isinstance(messages, list)
+            and len(messages) > 0
+            and isinstance(messages[0], Message)
+        ):
+            logger.info(
+                'Message objects now include serialized tool calls in token counting'
+            )
+            messages = self.format_messages_for_llm(messages)  # type: ignore
+
+        # try to get the token count with the default litellm tokenizers
+        # or the custom tokenizer if set for this LLM configuration
         try:
-            return litellm.token_counter(model=self.config.model, messages=messages)
-        except Exception:
-            # TODO: this is to limit logspam in case token count is not supported
+            return litellm.token_counter(
+                model=self.config.model,
+                messages=messages,
+                custom_tokenizer=self.tokenizer,
+            )
+        except Exception as e:
+            # limit logspam in case token count is not supported
+            logger.error(
+                f'Error getting token count for\n model {self.config.model}\n{e}'
+                + (
+                    f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
+                    if self.config.custom_tokenizer is not None
+                    else ''
+                )
+            )
             return 0
 
     def _is_local(self) -> bool:
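
To see the counting path end to end, a small sketch of the underlying litellm call (default tokenizer, no custom one configured; exact counts vary by litellm version):

# Sketch of the call get_token_count() delegates to.
import litellm

messages = [{'role': 'user', 'content': 'Hello!'}]
count = litellm.token_counter(
    model='gpt-4o',
    messages=messages,
    custom_tokenizer=None,  # None: litellm picks its default tokenizer
)
print(count)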

openhands/server/session/session.py (+19 -7)

@@ -1,6 +1,6 @@
 import asyncio
-from copy import deepcopy
 import time
+from copy import deepcopy
 
 import socketio
 
@@ -9,7 +9,6 @@ from openhands.core.config import AppConfig
 from openhands.core.const.guide_url import TROUBLESHOOTING_URL
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.schema import AgentState
-from openhands.core.schema.config import ConfigType
 from openhands.events.action import MessageAction, NullAction
 from openhands.events.event import Event, EventSource
 from openhands.events.observation import (
@@ -68,15 +67,28 @@ class Session:
         )
         # Extract the agent-relevant arguments from the request
         agent_cls = session_init_data.agent or self.config.default_agent
-        self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
-        self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
+        self.config.security.confirmation_mode = (
+            self.config.security.confirmation_mode
+            if session_init_data.confirmation_mode is None
+            else session_init_data.confirmation_mode
+        )
+        self.config.security.security_analyzer = (
+            session_init_data.security_analyzer
+            or self.config.security.security_analyzer
+        )
         max_iterations = session_init_data.max_iterations or self.config.max_iterations
         # override default LLM config
 
         default_llm_config = self.config.get_llm_config()
-        default_llm_config.model = session_init_data.llm_model or default_llm_config.model
-        default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
-        default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
+        default_llm_config.model = (
+            session_init_data.llm_model or default_llm_config.model
+        )
+        default_llm_config.api_key = (
+            session_init_data.llm_api_key or default_llm_config.api_key
+        )
+        default_llm_config.base_url = (
+            session_init_data.llm_base_url or default_llm_config.base_url
+        )
 
         # TODO: override other LLM config & agent config groups (#2075)
 

openhands/server/session/session_init_data.py (+1 -2)

@@ -1,5 +1,3 @@
-
-
 from dataclasses import dataclass
 
 
@@ -8,6 +6,7 @@ class SessionInitData:
     """
     Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
     """
+
     language: str | None = None
     agent: str | None = None
     max_iterations: int | None = None

tests/unit/test_config.py (+1 -0)

@@ -428,6 +428,7 @@ def test_api_keys_repr_str():
         'aws_secret_access_key',
         'input_cost_per_token',
         'output_cost_per_token',
+        'custom_tokenizer',
     ]
     for attr_name in dir(LLMConfig):
         if (

tests/unit/test_llm.py (+78 -0)

@@ -11,6 +11,7 @@ from litellm.exceptions import (
 
 from openhands.core.config import LLMConfig
 from openhands.core.exceptions import OperationCancelled
+from openhands.core.message import Message, TextContent
 from openhands.llm.llm import LLM
 from openhands.llm.metrics import Metrics
 
@@ -21,6 +22,7 @@ def mock_logger(monkeypatch):
     mock_logger = MagicMock()
     monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
     monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
+    monkeypatch.setattr('openhands.llm.llm.logger', mock_logger)
     return mock_logger
 
 
@@ -397,3 +399,79 @@ def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):
 
     # Ensure the completion was called
     mock_litellm_completion.assert_called_once()
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
+    mock_token_counter.return_value = 42
+    llm = LLM(default_config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 42
+    mock_token_counter.assert_called_once_with(
+        model=default_config.model, messages=messages, custom_tokenizer=None
+    )
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_with_message_objects(
+    mock_token_counter, default_config, mock_logger
+):
+    llm = LLM(default_config)
+
+    # Create a Message object and its equivalent dict
+    message_obj = Message(role='user', content=[TextContent(text='Hello!')])
+    message_dict = {'role': 'user', 'content': 'Hello!'}
+
+    # Mock token counter to return different values for each call
+    mock_token_counter.side_effect = [42, 42]  # Same value for both cases
+
+    # Get token counts for both formats
+    token_count_obj = llm.get_token_count([message_obj])
+    token_count_dict = llm.get_token_count([message_dict])
+
+    # Verify both formats get the same token count
+    assert token_count_obj == token_count_dict
+    assert mock_token_counter.call_count == 2
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+@patch('openhands.llm.llm.create_pretrained_tokenizer')
+def test_get_token_count_with_custom_tokenizer(
+    mock_create_tokenizer, mock_token_counter, default_config
+):
+    mock_tokenizer = MagicMock()
+    mock_create_tokenizer.return_value = mock_tokenizer
+    mock_token_counter.return_value = 42
+
+    config = copy.deepcopy(default_config)
+    config.custom_tokenizer = 'custom/tokenizer'
+    llm = LLM(config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 42
+    mock_create_tokenizer.assert_called_once_with('custom/tokenizer')
+    mock_token_counter.assert_called_once_with(
+        model=config.model, messages=messages, custom_tokenizer=mock_tokenizer
+    )
+
+
+@patch('openhands.llm.llm.litellm.token_counter')
+def test_get_token_count_error_handling(
+    mock_token_counter, default_config, mock_logger
+):
+    mock_token_counter.side_effect = Exception('Token counting failed')
+    llm = LLM(default_config)
+    messages = [{'role': 'user', 'content': 'Hello!'}]
+
+    token_count = llm.get_token_count(messages)
+
+    assert token_count == 0
+    mock_token_counter.assert_called_once()
+    mock_logger.error.assert_called_once_with(
+        'Error getting token count for\n model gpt-4o\nToken counting failed'
+    )