| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- from enum import Enum
- from typing import Union
- from pydantic import BaseModel, Field, model_serializer
- from typing_extensions import Literal
- from openhands.core.logger import openhands_logger as logger
class ContentType(Enum):
    """Kinds of message content; each value is the string used in a part's ``type`` field."""

    TEXT = 'text'  # plain-text content part
    IMAGE_URL = 'image_url'  # image referenced by URL
class Content(BaseModel):
    """Base class for one part of a message's content.

    Concrete subclasses give ``type`` a default and must override
    ``serialize_model``; this base serializer only raises.
    """

    # Tag identifying the kind of content part.
    type: str
    # Opt-in flag; the subclasses in this module add a ``cache_control``
    # marker to their serialized output when it is True.
    cache_prompt: bool = False

    @model_serializer
    def serialize_model(self):
        """Abstract serializer; always raises ``NotImplementedError``."""
        reason = 'Subclasses should implement this method.'
        raise NotImplementedError(reason)
class TextContent(Content):
    """A plain-text content part."""

    type: str = ContentType.TEXT.value
    text: str

    @model_serializer
    def serialize_model(self):
        """Serialize to ``{'type': ..., 'text': ...}``.

        When ``cache_prompt`` is set, a ``cache_control`` entry of
        ``{'type': 'ephemeral'}`` is appended to the dict.
        """
        payload: dict[str, str | dict[str, str]] = dict(type=self.type, text=self.text)
        if not self.cache_prompt:
            return payload
        payload['cache_control'] = {'type': 'ephemeral'}
        return payload
class ImageContent(Content):
    """An image content part that may reference several image URLs."""

    type: str = ContentType.IMAGE_URL.value
    image_urls: list[str]

    @model_serializer
    def serialize_model(self):
        """Serialize to a list with one ``{'type', 'image_url': {'url'}}`` dict per URL.

        Unlike the text part, this returns a list rather than a dict. When
        ``cache_prompt`` is set, only the last entry (if there is one)
        receives the ``cache_control`` marker.
        """
        entries: list[dict[str, str | dict[str, str]]] = [
            {'type': self.type, 'image_url': {'url': image_url}}
            for image_url in self.image_urls
        ]
        if entries and self.cache_prompt:
            entries[-1]['cache_control'] = {'type': 'ephemeral'}
        return entries
class Message(BaseModel):
    """A chat message: a role plus an ordered list of text/image content parts."""

    role: Literal['user', 'system', 'assistant']
    # default_factory gives each instance its own fresh empty list. The
    # previous ``Field(default=list)`` made the default the ``list`` *type*
    # itself, so a Message built without content could not be iterated,
    # serialized, or checked with ``contains_image``.
    content: list[TextContent | ImageContent] = Field(default_factory=list)

    @property
    def contains_image(self) -> bool:
        """Return True if any content part is an ``ImageContent``."""
        return any(isinstance(content, ImageContent) for content in self.content)

    @model_serializer
    def serialize_model(self) -> dict:
        """Serialize to ``{'content': [...], 'role': ...}``.

        Each text part contributes one dict. An image part's ``model_dump()``
        is itself a list (one dict per URL), so it is spliced in with
        ``extend`` to keep ``content`` flat. Parts of any other type are
        skipped.
        """
        content: list[dict[str, str | dict[str, str]]] = []
        for item in self.content:
            if isinstance(item, TextContent):
                content.append(item.model_dump())
            elif isinstance(item, ImageContent):
                content.extend(item.model_dump())
        return {'content': content, 'role': self.role}
def format_messages(
    messages: Union[Message, list[Message]], with_images: bool
) -> list[dict]:
    """Convert one message or a list of messages into role/content dicts.

    Args:
        messages: A single message or a list of them. ``Message`` is the
            annotated type. The text-only path also accepts plain strings,
            which are treated as user messages, and dicts with optional
            ``'role'`` and ``'content'`` keys.
        with_images: If true, each message is serialized with
            ``model_dump()`` so image parts are kept. This path expects
            ``Message`` objects. If false, each message is flattened to a
            single text string and image parts are dropped.

    Returns:
        On the image path, the ``model_dump()`` of each message. On the
        text-only path, a list of ``{'role': ..., 'content': ...}`` dicts.
        Here ``content`` is the message's text pieces, each followed by
        ``'\\n'``. Messages that yield no text, and unsupported types, are
        omitted. Unsupported types are also logged.
    """
    if not isinstance(messages, list):
        messages = [messages]

    if with_images:
        return [message.model_dump() for message in messages]

    converted_messages = []
    for message in messages:
        role = 'user'
        texts: list[str] = []
        if isinstance(message, str):
            # A bare string is a user message. It used to be accumulated and
            # then discarded by a ``continue`` before reaching the append.
            texts.append(message)
        elif isinstance(message, dict):
            # Probe keys only on dicts. On a str, ``'role' in message`` is a
            # substring test, and ``message['role']`` then raises TypeError.
            if 'role' in message:
                role = message['role']
            if 'content' in message:
                texts.append(message['content'])
        elif isinstance(message, Message):
            role = message.role
            for content in message.content:
                # Defensive: nested lists are not expected after pydantic
                # validation, but text inside them is still collected.
                if isinstance(content, list):
                    texts.extend(
                        item.text for item in content if isinstance(item, TextContent)
                    )
                elif isinstance(content, TextContent):
                    texts.append(content.text)
        else:
            logger.error(
                f'>>> `message` is not a string, dict, or Message: {type(message)}'
            )

        # ``+`` (not an f-string) keeps the original TypeError for non-str
        # dict content instead of silently stringifying it.
        content_str = ''.join(text + '\n' for text in texts)
        if content_str:
            converted_messages.append({'role': role, 'content': content_str})
    return converted_messages
|