|
|
@@ -56,6 +56,7 @@ class Message(BaseModel):
|
|
|
cache_enabled: bool = False
|
|
|
vision_enabled: bool = False
|
|
|
# function calling
|
|
|
+ function_calling_enabled: bool = False
|
|
|
# - tool calls (from LLM)
|
|
|
tool_calls: list[ChatCompletionMessageToolCall] | None = None
|
|
|
# - tool execution result (to LLM)
|
|
|
@@ -72,22 +73,22 @@ class Message(BaseModel):
|
|
|
# - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
|
|
|
# - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
|
|
|
# NOTE: remove this when litellm or providers support the new API
|
|
|
- if (
|
|
|
- self.cache_enabled
|
|
|
- or self.vision_enabled
|
|
|
- or self.tool_call_id is not None
|
|
|
- or self.tool_calls is not None
|
|
|
- ):
|
|
|
+ if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
|
|
|
return self._list_serializer()
|
|
|
+ # some providers, like HF and Groq/llama, don't support a list here, but a single string
|
|
|
return self._string_serializer()
|
|
|
|
|
|
- def _string_serializer(self):
|
|
|
+ def _string_serializer(self) -> dict:
|
|
|
+ # convert content to a single string
|
|
|
content = '\n'.join(
|
|
|
item.text for item in self.content if isinstance(item, TextContent)
|
|
|
)
|
|
|
- return {'content': content, 'role': self.role}
|
|
|
+ message_dict: dict = {'content': content, 'role': self.role}
|
|
|
+
|
|
|
+ # add tool call keys if we have a tool call or response
|
|
|
+ return self._add_tool_call_keys(message_dict)
|
|
|
|
|
|
- def _list_serializer(self):
|
|
|
+ def _list_serializer(self) -> dict:
|
|
|
content: list[dict] = []
|
|
|
role_tool_with_prompt_caching = False
|
|
|
for item in self.content:
|
|
|
@@ -102,24 +103,37 @@ class Message(BaseModel):
|
|
|
elif isinstance(item, ImageContent) and self.vision_enabled:
|
|
|
content.extend(d)
|
|
|
|
|
|
- ret: dict = {'content': content, 'role': self.role}
|
|
|
+ message_dict: dict = {'content': content, 'role': self.role}
|
|
|
+
|
|
|
# pop content if it's empty
|
|
|
if not content or (
|
|
|
len(content) == 1
|
|
|
and content[0]['type'] == 'text'
|
|
|
and content[0]['text'] == ''
|
|
|
):
|
|
|
- ret.pop('content')
|
|
|
+ message_dict.pop('content')
|
|
|
|
|
|
if role_tool_with_prompt_caching:
|
|
|
- ret['cache_control'] = {'type': 'ephemeral'}
|
|
|
+ message_dict['cache_control'] = {'type': 'ephemeral'}
|
|
|
+
|
|
|
+ # add tool call keys if we have a tool call or response
|
|
|
+ return self._add_tool_call_keys(message_dict)
|
|
|
|
|
|
+ def _add_tool_call_keys(self, message_dict: dict) -> dict:
|
|
|
+ """Add tool call keys if we have a tool call or response.
|
|
|
+
|
|
|
+ NOTE: this is necessary for both native and non-native tool calling"""
|
|
|
+
|
|
|
+ # an assistant message calling a tool
|
|
|
+ if self.tool_calls is not None:
|
|
|
+ message_dict['tool_calls'] = self.tool_calls
|
|
|
+
|
|
|
+ # an observation message with tool response
|
|
|
if self.tool_call_id is not None:
|
|
|
assert (
|
|
|
self.name is not None
|
|
|
), 'name is required when tool_call_id is not None'
|
|
|
- ret['tool_call_id'] = self.tool_call_id
|
|
|
- ret['name'] = self.name
|
|
|
- if self.tool_calls:
|
|
|
- ret['tool_calls'] = self.tool_calls
|
|
|
- return ret
|
|
|
+ message_dict['tool_call_id'] = self.tool_call_id
|
|
|
+ message_dict['name'] = self.name
|
|
|
+
|
|
|
+ return message_dict
|