Эх сурвалжийг харах

fix: serialize tool calls (#5553)

Co-authored-by: openhands <openhands@all-hands.dev>
Engel Nyst 11 сар өмнө
parent
commit
d733bc6bdd

+ 1 - 0
openhands/agenthub/codeact_agent/codeact_agent.py

@@ -166,6 +166,7 @@ class CodeActAgent(Agent):
 
             # Add the LLM message (assistant) that initiated the tool calls
             # (overwrites any previous message with the same response_id)
+            logger.debug(f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}')
             pending_tool_call_action_messages[llm_response.id] = Message(
                 role=assistant_msg.role,
                 # tool call content SHOULD BE a string

+ 12 - 2
openhands/core/message.py

@@ -114,11 +114,21 @@ class Message(BaseModel):
     def _add_tool_call_keys(self, message_dict: dict) -> dict:
         """Add tool call keys if we have a tool call or response.
 
-        NOTE: this is necessary for both native and non-native tool calling"""
+        NOTE: this is necessary for both native and non-native tool calling."""
 
         # an assistant message calling a tool
         if self.tool_calls is not None:
-            message_dict['tool_calls'] = self.tool_calls
+            message_dict['tool_calls'] = [
+                {
+                    'id': tool_call.id,
+                    'type': 'function',
+                    'function': {
+                        'name': tool_call.function.name,
+                        'arguments': tool_call.function.arguments,
+                    },
+                }
+                for tool_call in self.tool_calls
+            ]
 
         # an observation message with tool response
         if self.tool_call_id is not None:

+ 89 - 0
openhands/core/message_format.md

@@ -0,0 +1,89 @@
+# OpenHands Message Format and litellm Integration
+
+## Overview
+
+OpenHands uses its own `Message` class (`openhands/core/message.py`) which provides rich content support while maintaining compatibility with litellm's message handling system.
+
+## Class Structure
+
+Our `Message` class (`openhands/core/message.py`):
+```python
+class Message(BaseModel):
+    role: Literal['user', 'system', 'assistant', 'tool']
+    content: list[TextContent | ImageContent] = Field(default_factory=list)
+    cache_enabled: bool = False
+    vision_enabled: bool = False
+    condensable: bool = True
+    function_calling_enabled: bool = False
+    tool_calls: list[ChatCompletionMessageToolCall] | None = None
+    tool_call_id: str | None = None
+    name: str | None = None
+    event_id: int = -1
+```
+
+litellm's `Message` class (`litellm/types/utils.py`):
+```python
+class Message(OpenAIObject):
+    content: Optional[str]
+    role: Literal["assistant", "user", "system", "tool", "function"]
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]]
+    function_call: Optional[FunctionCall]
+    audio: Optional[ChatCompletionAudioResponse] = None
+```
+
+## How It Works
+
+1. **Message Creation**: Our `Message` class is a Pydantic model that supports rich content (text and images) through its `content` field.
+
+2. **Serialization**: The class uses Pydantic's `@model_serializer` to convert messages into dictionaries that litellm can understand. We have two serialization methods:
+   ```python
+   def _string_serializer(self) -> dict:
+       # convert content to a single string
+       content = '\n'.join(item.text for item in self.content if isinstance(item, TextContent))
+       message_dict: dict = {'content': content, 'role': self.role}
+       return self._add_tool_call_keys(message_dict)
+
+   def _list_serializer(self) -> dict:
+       content: list[dict] = []
+       for item in self.content:
+           d = item.model_dump()
+           if isinstance(item, TextContent):
+               content.append(d)
+           elif isinstance(item, ImageContent) and self.vision_enabled:
+               content.extend(d)
+       return {'content': content, 'role': self.role}
+   ```
+
+   The appropriate serializer is chosen based on the message's capabilities:
+   ```python
+   @model_serializer
+   def serialize_model(self) -> dict:
+       if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
+           return self._list_serializer()
+       return self._string_serializer()
+   ```
+
+3. **Tool Call Handling**: Tool calls require special attention in serialization because:
+   - They need to work with litellm's API calls (which accept both dicts and objects)
+   - They need to be properly serialized for token counting
+   - They need to maintain compatibility with different LLM providers' formats
+
+4. **litellm Integration**: When we pass our messages to `litellm.completion()`, litellm doesn't care about the message class type - it works with the dictionary representation. This works because:
+   - litellm's transformation code (e.g., `litellm/llms/anthropic/chat/transformation.py`) processes messages based on their structure, not their type
+   - our serialization produces dictionaries that match litellm's expected format
+   - litellm handles rich content by looking at the message structure, supporting both simple string content and lists of content items
+
+5. **Provider-Specific Handling**: litellm then transforms these messages into provider-specific formats (e.g., Anthropic, OpenAI) through its transformation layers, which know how to handle both simple and rich content structures.
+
+### Token Counting
+
+To use litellm's token counter, we need to make sure that all message components (including tool calls) are properly serialized to dictionaries. This is because:
+- litellm's token counter expects dictionary structures
+- Tool calls need to be included in the token count
+- Different providers may count tokens differently for structured content
+
+## Note
+
+- We don't need to inherit from litellm's `Message` class because litellm works with dictionary representations, not class types
+- Our rich content model is more sophisticated than litellm's basic string content, but litellm handles it correctly through its transformation layers
+- The compatibility is maintained through proper serialization rather than inheritance

+ 54 - 0
tests/unit/test_message_serialization.py

@@ -1,3 +1,5 @@
+from litellm import ChatCompletionMessageToolCall
+
 from openhands.core.message import ImageContent, Message, TextContent
 
 
@@ -114,3 +116,55 @@ def test_message_with_mixed_content_and_vision_disabled():
     assert serialized_message == expected_serialized_message
     # Assert that images exist in the original message
     assert message.contains_image
+
+
+def test_message_tool_call_serialization():
+    """Test that tool calls are properly serialized into dicts for token counting."""
+    # Create a tool call
+    tool_call = ChatCompletionMessageToolCall(
+        id='call_123',
+        type='function',
+        function={'name': 'test_function', 'arguments': '{"arg1": "value1"}'},
+    )
+
+    # Create a message with the tool call
+    message = Message(
+        role='assistant',
+        content=[TextContent(text='Test message')],
+        tool_calls=[tool_call],
+    )
+
+    # Serialize the message
+    serialized = message.model_dump()
+
+    # Check that tool calls are properly serialized
+    assert 'tool_calls' in serialized
+    assert isinstance(serialized['tool_calls'], list)
+    assert len(serialized['tool_calls']) == 1
+
+    tool_call_dict = serialized['tool_calls'][0]
+    assert isinstance(tool_call_dict, dict)
+    assert tool_call_dict['id'] == 'call_123'
+    assert tool_call_dict['type'] == 'function'
+    assert tool_call_dict['function']['name'] == 'test_function'
+    assert tool_call_dict['function']['arguments'] == '{"arg1": "value1"}'
+
+
+def test_message_tool_response_serialization():
+    """Test that tool responses are properly serialized."""
+    # Create a message with tool response
+    message = Message(
+        role='tool',
+        content=[TextContent(text='Function result')],
+        tool_call_id='call_123',
+        name='test_function',
+    )
+
+    # Serialize the message
+    serialized = message.model_dump()
+
+    # Check that tool response fields are properly serialized
+    assert 'tool_call_id' in serialized
+    assert serialized['tool_call_id'] == 'call_123'
+    assert 'name' in serialized
+    assert serialized['name'] == 'test_function'