message.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from enum import Enum
  2. from typing import Union
  3. from pydantic import BaseModel, Field, model_serializer
  4. from typing_extensions import Literal
  5. from openhands.core.logger import openhands_logger as logger
  class ContentType(Enum):
      """Kinds of content part a message can carry.

      Each member's value is the string used as the ``type`` of the
      corresponding content model (``TextContent`` / ``ImageContent``).
      """

      # Plain text part.
      TEXT = 'text'
      # Part that references an image by URL.
      IMAGE_URL = 'image_url'
  class Content(BaseModel):
      """Common base for the content parts of a message.

      Declares the fields shared by every part. The serializer defined here
      only raises, so concrete subclasses must supply their own.
      """

      # String identifying the kind of part; subclasses give it a default.
      type: str
      # When True, subclasses attach a cache-control marker to their output.
      cache_prompt: bool = False

      @model_serializer
      def serialize_model(self):
          """Placeholder serializer for the base class.

          Raises:
              NotImplementedError: Always; subclasses override this method.
          """
          raise NotImplementedError('Subclasses should implement this method.')
  class TextContent(Content):
      """A text part of a message."""

      type: str = ContentType.TEXT.value
      text: str

      @model_serializer
      def serialize_model(self):
          """Serialize this part to a dict.

          The result always has ``type`` and ``text``; when ``cache_prompt``
          is set it also carries ``cache_control`` of ``{'type': 'ephemeral'}``.
          """
          payload: dict[str, str | dict[str, str]] = {
              'type': self.type,
              'text': self.text,
          }
          if not self.cache_prompt:
              return payload

          payload['cache_control'] = {'type': 'ephemeral'}
          return payload
  class ImageContent(Content):
      """One or more images, each referenced by URL, within a message."""

      type: str = ContentType.IMAGE_URL.value
      image_urls: list[str]

      @model_serializer
      def serialize_model(self):
          """Serialize to a list holding one dict per image URL.

          Each entry has ``type`` and ``image_url`` (``{'url': ...}``). When
          ``cache_prompt`` is set and there is at least one URL, only the last
          entry receives ``cache_control`` of ``{'type': 'ephemeral'}``.
          """
          entries: list[dict[str, str | dict[str, str]]] = [
              {'type': self.type, 'image_url': {'url': link}}
              for link in self.image_urls
          ]
          if entries and self.cache_prompt:
              entries[-1]['cache_control'] = {'type': 'ephemeral'}
          return entries
  class Message(BaseModel):
      """A chat message: a role plus an ordered list of content parts."""

      role: Literal['user', 'system', 'assistant']
      # default_factory builds a fresh empty list for every instance.
      # Field(default=list) would instead store the `list` type itself as the
      # default, which breaks any iteration over `content`.
      content: list[TextContent | ImageContent] = Field(default_factory=list)

      @property
      def contains_image(self) -> bool:
          """Return True if any content part is an ``ImageContent``."""
          return any(isinstance(content, ImageContent) for content in self.content)

      @model_serializer
      def serialize_model(self) -> dict:
          """Serialize to ``{'content': [...], 'role': ...}``.

          A text part contributes one dict. An image part serializes to a
          list (one dict per URL), so its entries are spliced in with
          ``extend`` to keep the result flat.
          """
          content: list[dict[str, str | dict[str, str]]] = []
          for item in self.content:
              if isinstance(item, TextContent):
                  content.append(item.model_dump())
              elif isinstance(item, ImageContent):
                  content.extend(item.model_dump())
          return {'content': content, 'role': self.role}
  def format_messages(
      messages: Union[Message, list[Message]],
      with_images: bool,
      with_prompt_caching: bool,
  ) -> list[dict]:
      """Convert messages into a list of plain dicts.

      Args:
          messages: A single message or a list of messages. Besides
              ``Message`` objects, the text-only path also accepts plain
              strings and dicts with ``role`` / ``content`` keys.
          with_images: If True, return each message's ``model_dump()``.
          with_prompt_caching: If True, return each message's
              ``model_dump()``.

      Returns:
          When either flag is set, the ``model_dump()`` of every message.
          Otherwise one ``{'role': ..., 'content': ...}`` dict per message,
          where ``content`` is the message's non-empty text parts joined
          with newlines. Non-text parts are ignored, and a message that
          yields no text is left out of the result.
      """
      if not isinstance(messages, list):
          messages = [messages]

      if with_images or with_prompt_caching:
          return [message.model_dump() for message in messages]

      converted_messages = []
      for message in messages:
          content_parts = []
          role = 'user'

          if isinstance(message, str):
              # Check the type first and emptiness second, so an empty string
              # is skipped quietly instead of falling through to the
              # "not a string, dict, or Message" error below.
              if message:
                  content_parts.append(message)
          elif isinstance(message, dict):
              role = message.get('role', 'user')
              if message.get('content'):
                  content_parts.append(message['content'])
          elif isinstance(message, Message):
              role = message.role
              for content in message.content:
                  if isinstance(content, list):
                      # Tolerate a nested list of parts inside `content`.
                      for item in content:
                          if isinstance(item, TextContent) and item.text:
                              content_parts.append(item.text)
                  elif isinstance(content, TextContent) and content.text:
                      content_parts.append(content.text)
          else:
              logger.error(
                  f'>>> `message` is not a string, dict, or Message: {type(message)}'
              )

          if content_parts:
              converted_messages.append(
                  {
                      'role': role,
                      'content': '\n'.join(content_parts),
                  }
              )

      return converted_messages