message.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from enum import Enum
  2. from typing import Union
  3. from pydantic import BaseModel, Field, model_serializer
  4. from typing_extensions import Literal
  5. from openhands.core.logger import openhands_logger as logger
  class ContentType(Enum):
      """Kinds of message content parts.

      Each member's value is the string that content parts use as their
      serialized ``'type'`` field.
      """

      TEXT = 'text'  # plain text part
      IMAGE_URL = 'image_url'  # image referenced by URL
  class Content(BaseModel):
      """Base class for one part of a message's content.

      Concrete subclasses supply a default ``type`` and their own
      ``model_serializer``; this base serializer refuses to run.
      """

      # String emitted by subclasses as the part's 'type' key.
      type: str
      # When True, subclasses attach an ephemeral cache_control marker on dump.
      cache_prompt: bool = False

      @model_serializer
      def serialize_model(self):
          # Deliberately unimplemented: only a concrete subclass knows its shape.
          raise NotImplementedError('Subclasses should implement this method.')
  class TextContent(Content):
      """A plain-text content part."""

      type: str = ContentType.TEXT.value
      text: str

      @model_serializer
      def serialize_model(self):
          """Dump to ``{'type', 'text'}``, plus ``cache_control`` when caching."""
          # The optional cache marker is appended last, after 'type' and 'text'.
          cache_marker: dict[str, dict[str, str]] = (
              {'cache_control': {'type': 'ephemeral'}} if self.cache_prompt else {}
          )
          return {'type': self.type, 'text': self.text, **cache_marker}
  class ImageContent(Content):
      """An image content part holding one or more image URLs."""

      type: str = ContentType.IMAGE_URL.value
      image_urls: list[str]

      @model_serializer
      def serialize_model(self):
          """Dump to a list with one ``{'type', 'image_url': {'url'}}`` dict per URL.

          When ``cache_prompt`` is set, only the final entry carries the
          cache_control marker. An empty URL list dumps to ``[]``.
          """
          entries: list[dict[str, str | dict[str, str]]] = [
              {'type': self.type, 'image_url': {'url': image_url}}
              for image_url in self.image_urls
          ]
          if entries and self.cache_prompt:
              entries[-1]['cache_control'] = {'type': 'ephemeral'}
          return entries
  class Message(BaseModel):
      """A single chat message: a role plus an ordered list of content parts.

      Serializes (via ``model_dump``) to ``{'content': [...], 'role': ...}``.
      Text parts contribute one dict each, and image parts contribute one dict
      per image URL.
      """

      role: Literal['user', 'system', 'assistant']
      # `default_factory` builds a fresh empty list for each instance. The
      # previous `Field(default=list)` made the default the `list` *type itself*
      # (pydantic does not validate defaults), so a Message created without
      # content broke `contains_image` and serialization.
      content: list[TextContent | ImageContent] = Field(default_factory=list)

      @property
      def contains_image(self) -> bool:
          """Return True if any content part is an ``ImageContent``."""
          return any(isinstance(content, ImageContent) for content in self.content)

      @model_serializer
      def serialize_model(self) -> dict:
          """Flatten all content parts into one list of part dicts."""
          content: list[dict[str, str | dict[str, str]]] = []
          for item in self.content:
              if isinstance(item, TextContent):
                  # TextContent dumps to a single dict.
                  content.append(item.model_dump())
              elif isinstance(item, ImageContent):
                  # ImageContent dumps to a list (one entry per URL); splice it in.
                  content.extend(item.model_dump())
          return {'content': content, 'role': self.role}
  def format_messages(
      messages: Union[Message, list[Message]], with_images: bool
  ) -> list[dict]:
      """Convert messages into ``{'role', 'content'}`` dicts for an LLM call.

      Args:
          messages: One message or a list of them. In text-only mode
              (``with_images=False``), plain strings (treated as user messages)
              and dicts with optional ``'role'``/``'content'`` keys are also
              accepted.
          with_images: If True, return each message's ``model_dump()``, keeping
              the structured text/image parts. If False, flatten each message to
              a single string in which every text part is newline-terminated;
              image parts are dropped.

      Returns:
          A list of message dicts. In text-only mode, messages that yield no
          text are omitted.
      """
      if not isinstance(messages, list):
          messages = [messages]

      if with_images:
          return [message.model_dump() for message in messages]

      converted_messages = []
      for message in messages:
          role, content_str = _to_text_message(message)
          if content_str:
              converted_messages.append({'role': role, 'content': content_str})
      return converted_messages


  def _to_text_message(message) -> tuple[str, str]:
      """Return ``(role, text)`` for one message; text is '' when there is none."""
      # Dispatch on type *before* probing for keys: `'role' in message` on a str
      # is a substring test, and indexing a str with 'role' raises TypeError.
      if isinstance(message, str):
          # Bare strings are user messages. Previously this text was built and
          # then discarded by a stray `continue`.
          return 'user', message + '\n'

      if isinstance(message, dict):
          role = message.get('role', 'user')
          if 'content' in message:
              return role, message['content'] + '\n'
          return role, ''

      if isinstance(message, Message):
          parts: list[str] = []
          for content in message.content:
              # Tolerate one level of nested lists, as earlier code did.
              items = content if isinstance(content, list) else [content]
              parts.extend(
                  item.text + '\n' for item in items if isinstance(item, TextContent)
              )
          return message.role, ''.join(parts)

      logger.error(
          f'>>> `message` is not a string, dict, or Message: {type(message)}'
      )
      return 'user', ''