| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- from enum import Enum
- from typing import Literal
- from litellm import ChatCompletionMessageToolCall
- from pydantic import BaseModel, Field, model_serializer
class ContentType(Enum):
    """Discriminator values for the ``type`` field of message content items."""

    TEXT = 'text'
    IMAGE_URL = 'image_url'
class Content(BaseModel):
    """Base class for a single item of message content.

    Fields shared by every content kind:
    - type: discriminator string (see ContentType)
    - cache_prompt: mark this item for provider-side prompt caching
    """

    type: str
    cache_prompt: bool = False

    @model_serializer
    def serialize_model(self):
        """Abstract hook: concrete content types define their own serialization."""
        raise NotImplementedError('Subclasses should implement this method.')
class TextContent(Content):
    """A plain-text content item."""

    type: str = ContentType.TEXT.value
    text: str

    @model_serializer
    def serialize_model(self):
        """Serialize to a ``{'type', 'text'[, 'cache_control']}`` dict."""
        serialized: dict[str, str | dict[str, str]] = {'type': self.type, 'text': self.text}
        if self.cache_prompt:
            serialized['cache_control'] = {'type': 'ephemeral'}
        return serialized
class ImageContent(Content):
    """One or more image URLs; serializes to a *list* of content items."""

    type: str = ContentType.IMAGE_URL.value
    image_urls: list[str]

    @model_serializer
    def serialize_model(self):
        """Serialize each URL as its own item; the cache marker goes on the last one."""
        serialized: list[dict[str, str | dict[str, str]]] = [
            {'type': self.type, 'image_url': {'url': url}} for url in self.image_urls
        ]
        if self.cache_prompt and serialized:
            serialized[-1]['cache_control'] = {'type': 'ephemeral'}
        return serialized
class Message(BaseModel):
    """A chat message in the shape expected by LLM provider APIs.

    NOTE: ``role`` is not the same as EventSource — these are the roles
    used in the LLMs' APIs.
    """

    role: Literal['user', 'system', 'assistant', 'tool']
    content: list[TextContent | ImageContent] = Field(default_factory=list)
    # feature flags that select which serialization format is produced
    cache_enabled: bool = False
    vision_enabled: bool = False
    # function calling
    function_calling_enabled: bool = False
    # - tool calls (from LLM)
    tool_calls: list[ChatCompletionMessageToolCall] | None = None
    # - tool execution result (to LLM)
    tool_call_id: str | None = None
    name: str | None = None  # name of the tool

    @property
    def contains_image(self) -> bool:
        """True if any content item is an ImageContent."""
        return any(isinstance(content, ImageContent) for content in self.content)

    @model_serializer
    def serialize_model(self) -> dict:
        """Serialize this message for the LLM API.

        We need two kinds of serializations:
        - into a single string: for providers that don't support a list of
          content items (e.g. no vision, no tool calls)
        - into a list of content items: the new APIs of providers with
          vision/prompt caching/tool calls

        NOTE: remove this when litellm or providers support the new API.
        """
        if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
            return self._list_serializer()
        # some providers, like HF and Groq/llama, don't support a list here,
        # but a single string
        return self._string_serializer()

    def _string_serializer(self) -> dict:
        """Flatten text content into a single newline-joined string.

        Image content is intentionally dropped in this format.
        """
        content = '\n'.join(
            item.text for item in self.content if isinstance(item, TextContent)
        )
        message_dict: dict = {'content': content, 'role': self.role}
        # add tool call keys if we have a tool call or response
        return self._add_tool_call_keys(message_dict)

    def _list_serializer(self) -> dict:
        """Serialize content as a list of provider-API content items."""
        content: list[dict] = []
        role_tool_with_prompt_caching = False
        for item in self.content:
            d = item.model_dump()
            # We have to remove cache_prompt for tool content and move it up
            # to the message level. See discussion here for details:
            # https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472
            if self.role == 'tool' and item.cache_prompt:
                role_tool_with_prompt_caching = True
                # BUGFIX: ImageContent's serializer returns a *list* of dicts,
                # so a bare d.pop('cache_control') would raise TypeError there
                # (list.pop expects an int index). Strip the per-item cache
                # marker from whichever shape we got.
                if isinstance(d, dict):
                    d.pop('cache_control', None)
                else:
                    for entry in d:
                        entry.pop('cache_control', None)
            if isinstance(item, TextContent):
                content.append(d)
            elif isinstance(item, ImageContent) and self.vision_enabled:
                content.extend(d)
        message_dict: dict = {'content': content, 'role': self.role}
        if role_tool_with_prompt_caching:
            # hoist the cache marker to the message level for tool messages
            message_dict['cache_control'] = {'type': 'ephemeral'}
        # add tool call keys if we have a tool call or response
        return self._add_tool_call_keys(message_dict)

    def _add_tool_call_keys(self, message_dict: dict) -> dict:
        """Add tool call keys if we have a tool call or response.

        NOTE: this is necessary for both native and non-native tool calling.
        """
        # an assistant message calling a tool
        if self.tool_calls is not None:
            message_dict['tool_calls'] = [
                {
                    'id': tool_call.id,
                    'type': 'function',
                    'function': {
                        'name': tool_call.function.name,
                        'arguments': tool_call.function.arguments,
                    },
                }
                for tool_call in self.tool_calls
            ]
        # an observation message with tool response
        if self.tool_call_id is not None:
            assert (
                self.name is not None
            ), 'name is required when tool_call_id is not None'
            message_dict['tool_call_id'] = self.tool_call_id
            message_dict['name'] = self.name
        return message_dict
|