message.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from enum import Enum
  2. from typing import Literal
  3. from pydantic import BaseModel, Field, model_serializer
  4. class ContentType(Enum):
  5. TEXT = 'text'
  6. IMAGE_URL = 'image_url'
  7. class Content(BaseModel):
  8. type: str
  9. cache_prompt: bool = False
  10. @model_serializer
  11. def serialize_model(self):
  12. raise NotImplementedError('Subclasses should implement this method.')
  13. class TextContent(Content):
  14. type: str = ContentType.TEXT.value
  15. text: str
  16. @model_serializer
  17. def serialize_model(self):
  18. data: dict[str, str | dict[str, str]] = {
  19. 'type': self.type,
  20. 'text': self.text,
  21. }
  22. if self.cache_prompt:
  23. data['cache_control'] = {'type': 'ephemeral'}
  24. return data
  25. class ImageContent(Content):
  26. type: str = ContentType.IMAGE_URL.value
  27. image_urls: list[str]
  28. @model_serializer
  29. def serialize_model(self):
  30. images: list[dict[str, str | dict[str, str]]] = []
  31. for url in self.image_urls:
  32. images.append({'type': self.type, 'image_url': {'url': url}})
  33. if self.cache_prompt and images:
  34. images[-1]['cache_control'] = {'type': 'ephemeral'}
  35. return images
  36. class Message(BaseModel):
  37. role: Literal['user', 'system', 'assistant']
  38. content: list[TextContent | ImageContent] = Field(default=list)
  39. cache_enabled: bool = False
  40. vision_enabled: bool = False
  41. @property
  42. def contains_image(self) -> bool:
  43. return any(isinstance(content, ImageContent) for content in self.content)
  44. @model_serializer
  45. def serialize_model(self) -> dict:
  46. content: list[dict] | str
  47. # two kinds of serializer:
  48. # 1. vision serializer: when prompt caching or vision is enabled
  49. # 2. single text serializer: for other cases
  50. # remove this when liteLLM or providers support this format translation
  51. if self.cache_enabled or self.vision_enabled:
  52. # when prompt caching or vision is enabled, use vision serializer
  53. content = []
  54. for item in self.content:
  55. if isinstance(item, TextContent):
  56. content.append(item.model_dump())
  57. elif isinstance(item, ImageContent):
  58. content.extend(item.model_dump())
  59. else:
  60. # for other cases, concatenate all text content
  61. # into a single string per message
  62. content = '\n'.join(
  63. item.text for item in self.content if isinstance(item, TextContent)
  64. )
  65. return {'content': content, 'role': self.role}