# message.py
  1. from enum import Enum
  2. from typing import Literal
  3. from litellm import ChatCompletionMessageToolCall
  4. from pydantic import BaseModel, Field, model_serializer
class ContentType(Enum):
    """Discriminator values for message content items (matches the LLM API wire format)."""

    TEXT = 'text'
    IMAGE_URL = 'image_url'
class Content(BaseModel):
    """Base class for a single message content item.

    Subclasses must provide a `@model_serializer` producing the provider
    wire format; `cache_prompt` marks the item for prompt caching.
    """

    # discriminator matching ContentType values; subclasses set a default
    type: str
    # when True, serializers attach a {'type': 'ephemeral'} cache_control marker
    cache_prompt: bool = False

    @model_serializer
    def serialize_model(self):
        # serialization is defined entirely by the concrete subclass
        raise NotImplementedError('Subclasses should implement this method.')
  14. class TextContent(Content):
  15. type: str = ContentType.TEXT.value
  16. text: str
  17. @model_serializer
  18. def serialize_model(self):
  19. data: dict[str, str | dict[str, str]] = {
  20. 'type': self.type,
  21. 'text': self.text,
  22. }
  23. if self.cache_prompt:
  24. data['cache_control'] = {'type': 'ephemeral'}
  25. return data
  26. class ImageContent(Content):
  27. type: str = ContentType.IMAGE_URL.value
  28. image_urls: list[str]
  29. @model_serializer
  30. def serialize_model(self):
  31. images: list[dict[str, str | dict[str, str]]] = []
  32. for url in self.image_urls:
  33. images.append({'type': self.type, 'image_url': {'url': url}})
  34. if self.cache_prompt and images:
  35. images[-1]['cache_control'] = {'type': 'ephemeral'}
  36. return images
  37. class Message(BaseModel):
  38. # NOTE: this is not the same as EventSource
  39. # These are the roles in the LLM's APIs
  40. role: Literal['user', 'system', 'assistant', 'tool']
  41. content: list[TextContent | ImageContent] = Field(default_factory=list)
  42. cache_enabled: bool = False
  43. vision_enabled: bool = False
  44. # function calling
  45. function_calling_enabled: bool = False
  46. # - tool calls (from LLM)
  47. tool_calls: list[ChatCompletionMessageToolCall] | None = None
  48. # - tool execution result (to LLM)
  49. tool_call_id: str | None = None
  50. name: str | None = None # name of the tool
  51. # force string serializer
  52. force_string_serializer: bool = False
  53. @property
  54. def contains_image(self) -> bool:
  55. return any(isinstance(content, ImageContent) for content in self.content)
  56. @model_serializer
  57. def serialize_model(self) -> dict:
  58. # We need two kinds of serializations:
  59. # - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
  60. # - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
  61. # NOTE: remove this when litellm or providers support the new API
  62. if not self.force_string_serializer and (
  63. self.cache_enabled or self.vision_enabled or self.function_calling_enabled
  64. ):
  65. return self._list_serializer()
  66. # some providers, like HF and Groq/llama, don't support a list here, but a single string
  67. return self._string_serializer()
  68. def _string_serializer(self) -> dict:
  69. # convert content to a single string
  70. content = '\n'.join(
  71. item.text for item in self.content if isinstance(item, TextContent)
  72. )
  73. message_dict: dict = {'content': content, 'role': self.role}
  74. # add tool call keys if we have a tool call or response
  75. return self._add_tool_call_keys(message_dict)
  76. def _list_serializer(self) -> dict:
  77. content: list[dict] = []
  78. role_tool_with_prompt_caching = False
  79. for item in self.content:
  80. d = item.model_dump()
  81. # We have to remove cache_prompt for tool content and move it up to the message level
  82. # See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472
  83. if self.role == 'tool' and item.cache_prompt:
  84. role_tool_with_prompt_caching = True
  85. d.pop('cache_control')
  86. if isinstance(item, TextContent):
  87. content.append(d)
  88. elif isinstance(item, ImageContent) and self.vision_enabled:
  89. content.extend(d)
  90. message_dict: dict = {'content': content, 'role': self.role}
  91. if role_tool_with_prompt_caching:
  92. message_dict['cache_control'] = {'type': 'ephemeral'}
  93. # add tool call keys if we have a tool call or response
  94. return self._add_tool_call_keys(message_dict)
  95. def _add_tool_call_keys(self, message_dict: dict) -> dict:
  96. """Add tool call keys if we have a tool call or response.
  97. NOTE: this is necessary for both native and non-native tool calling."""
  98. # an assistant message calling a tool
  99. if self.tool_calls is not None:
  100. message_dict['tool_calls'] = [
  101. {
  102. 'id': tool_call.id,
  103. 'type': 'function',
  104. 'function': {
  105. 'name': tool_call.function.name,
  106. 'arguments': tool_call.function.arguments,
  107. },
  108. }
  109. for tool_call in self.tool_calls
  110. ]
  111. # an observation message with tool response
  112. if self.tool_call_id is not None:
  113. assert (
  114. self.name is not None
  115. ), 'name is required when tool_call_id is not None'
  116. message_dict['tool_call_id'] = self.tool_call_id
  117. message_dict['name'] = self.name
  118. return message_dict