| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- from typing import Type, TypeVar, Dict, Any, Union, Optional
- from pydantic import BaseModel, model_validator
- # Generic type variable standing for any Pydantic model class (bound to BaseModel).
- ModelType = TypeVar("ModelType", bound=BaseModel)
class ModelField:
    """Callable converter from a plain dict to an instance of one target class.

    Each instance wraps a single class. Calling the instance with a dict
    builds the target class from the dict's items as keyword arguments;
    calling it with an existing instance of the target class returns that
    instance unchanged.
    """

    def __init__(self, model_class: Type[ModelType]):
        # The class that incoming dicts are expanded into.
        self.model_class = model_class

    def __call__(self, value: Any) -> ModelType:
        """Convert *value* into an instance of the wrapped class.

        Args:
            value: Either a dict of keyword arguments for the wrapped class
                or an instance of the wrapped class.

        Returns:
            A new instance built from the dict, or *value* itself when it
            already is an instance of the wrapped class.

        Raises:
            ValueError: If *value* is neither a dict nor an instance of the
                wrapped class.
        """
        # Reject unsupported inputs up front; dict is listed first so it is
        # tested before the wrapped class.
        if not isinstance(value, (dict, self.model_class)):
            raise ValueError(f"Expected dict or {self.model_class}, got {type(value)}")
        # A dict is always expanded, even if it would also count as an
        # instance of the wrapped class.
        return self.model_class(**value) if isinstance(value, dict) else value
class AutoLoadModel(BaseModel):
    """Base model that builds nested Pydantic models from raw dicts up front.

    Before regular field validation runs, raw ``dict`` payloads are replaced
    by model instances for every field whose annotation is a ``BaseModel``
    subclass, either directly, wrapped in ``Optional``/``Union``, as the item
    type of ``List[...]``, or as the value type of ``Dict[..., ...]``.
    """

    @model_validator(mode='before')
    @classmethod
    def auto_load_nested_models(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Replace raw nested payloads with instances of their model classes.

        Args:
            values: The raw input handed to the model. A "before" validator
                can receive non-dict input; anything that is not a dict is
                returned untouched for the regular validation to deal with.

        Returns:
            A shallow copy of *values* in which nested dicts (and dicts held
            inside lists or dict values) are converted according to the field
            annotations. The caller's mapping is never modified.
        """
        if not isinstance(values, dict):
            return values

        # Work on a copy so the mapping passed in by the caller keeps its
        # original raw payloads.
        loaded = dict(values)
        for field_name, field in cls.model_fields.items():
            # Fields are looked up by field name; absent keys are left absent
            # so defaults and "missing" errors behave as usual.
            if field_name in loaded:
                loaded[field_name] = cls._load_nested(field.annotation, loaded[field_name])
        return loaded

    @staticmethod
    def _load_nested(annotation: Any, value: Any) -> Any:
        """Return *value* converted according to *annotation* where possible."""
        # Imported locally so this class does not rely on names beyond the
        # module's existing imports.
        import types
        from typing import get_args, get_origin

        origin = get_origin(annotation)

        # Unwrap Optional[X] / Union[X, ...] and PEP 604 ``X | None`` (the
        # latter only exists on Python 3.10+). The first non-None member is
        # the conversion target.
        pep604_union = getattr(types, "UnionType", None)
        if origin is Union or (pep604_union is not None and origin is pep604_union):
            members = [arg for arg in get_args(annotation) if arg is not type(None)]
            if not members:
                return value
            annotation = members[0]
            origin = get_origin(annotation)

        # Field annotated with a model class: expand a raw dict into it.
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            return annotation(**value) if isinstance(value, dict) else value

        # List[Item]: convert each element against the item annotation.
        if origin is list and isinstance(value, list):
            item_args = get_args(annotation)
            if item_args:
                return [AutoLoadModel._load_nested(item_args[0], item) for item in value]
            return value

        # Dict[Key, Item]: convert each value against the value annotation,
        # building a new dict instead of editing the input in place.
        if origin is dict and isinstance(value, dict):
            pair = get_args(annotation)
            if len(pair) == 2:
                return {
                    key: AutoLoadModel._load_nested(pair[1], item)
                    for key, item in value.items()
                }
            return value

        # Anything else is left for Pydantic's own validation.
        return value
|