"""SQLModel table definitions and small persistence helpers.

Defines the ``user_data`` and ``unread_user_data`` tables plus two thin
wrappers around ``sqlmodel.Session``: ``DataBase`` (model type passed per
call) and ``Table`` (model type bound at construction).

NOTE: importing this module creates the engine and calls
``create_all_table()``, i.e. it talks to the database at import time.
"""
import datetime
import json
import os
import pickle
import sys
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar

import pydantic
from pydantic import Json
from sqlalchemy.dialects.postgresql import ARRAY, JSON
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.elements import BinaryExpression
from sqlmodel import (
    Column,
    Field,
    PickleType,
    Relationship,
    Session,
    SQLModel,
    UniqueConstraint,
    create_engine,
    func,
    select,
    text,
)

# Append the parent of this file's directory to sys.path; presumably this is
# what lets the ``conf`` and ``database`` packages below resolve when the file
# is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from conf.config import logger  # noqa: E402
from database.config import DB_URL  # noqa: E402
from database.s3 import S3Object  # noqa: E402

engine = create_engine(DB_URL)


def create_all_table():
    """Create every table registered on ``SQLModel.metadata`` using ``engine``."""
    SQLModel.metadata.create_all(engine)


T = TypeVar('T', bound="BaseSQLModel")


class BaseSQLModel(SQLModel):
    """Common base for the table models declared in this module."""

    @classmethod
    def dict_to_model(cls, dict_data: dict, model: Optional[Type[T]] = None) -> T:
        """Build a model instance from ``dict_data``, dropping unknown keys.

        Irrelevant fields are filtered automatically: any key of
        ``dict_data`` that is not an attribute of the target class is ignored.

        Args:
            dict_data: Raw field values keyed by attribute name.
            model: Class to instantiate; defaults to ``cls``.

        Returns:
            A new instance of ``model`` (or ``cls``).
        """
        if model is None:
            model = cls
        clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
        return model(**clean_data)

    class Config:
        arbitrary_types_allowed = True


class UserInfoModel(BaseSQLModel, table=True):
    """Row of the ``user_data`` table."""

    __tablename__ = 'user_data'

    id: int = Field(default=None, primary_key=True)
    uid: Optional[str] = Field(nullable=False)
    nickname: Optional[str]
    # Stored in a JSON column.
    avatar_medium: Optional[Json] = Field(sa_column=Column(JSON))
    sec_uid: Optional[str]
    signature: Optional[str]
    city: Optional[str]
    ip_location: Optional[str]
    province: Optional[str]
    school_name: Optional[str]
    follow_status: Optional[str]
    follower_count: Optional[int]
    total_favorited: Optional[int]
    aweme_count: Optional[int]

    # Reverse side of UnReadUserData.user_info.
    unread_user_data: list["UnReadUserData"] = Relationship(back_populates="user_info")


class UnReadUserData(BaseSQLModel, table=True):
    """Row of the ``unread_user_data`` table."""

    __tablename__ = 'unread_user_data'

    id: int = Field(default=None, primary_key=True)
    name: Optional[str] = Field(default=None)
    avator: Optional[str] = Field(default=None)
    msg: Optional[str] = Field(default=None)
    unread_msg_count: Optional[int] = Field(default=None)
    msg_time: Optional[datetime.datetime] = Field(default=None)
    msg_time_txt: Optional[str] = Field(default=None)
    # Stored in a JSON column.
    chat_history: Optional[Json] = Field(sa_column=Column(JSON))
    # NOTE(review): PickleType pickles/unpickles this value; only safe as long
    # as nothing untrusted can write to this column - confirm.
    detail: Optional[S3Object] = Field(sa_column=Column(PickleType))
    # Naive local timestamp taken when the instance is constructed.
    create_time: datetime.datetime = Field(default_factory=datetime.datetime.now)
    is_done: Optional[bool] = Field(default=False)

    user_info_id: Optional[int] = Field(default=None, foreign_key="user_data.id")
    user_info: Optional[UserInfoModel] = Relationship(
        back_populates="unread_user_data",
        sa_relationship_kwargs={"lazy": "joined", "cascade": "all, delete-orphan", "single_parent": True},
    )


# Module-level side effect: tables are created as soon as the models exist.
create_all_table()


class DataBase(Generic[T]):
    """Session-per-call helper; the model class is supplied on each call."""

    def __init__(self, engine=engine) -> None:
        self.engine = engine

    def fine_one(self, model: Type[T], *where: BinaryExpression) -> Optional[T]:
        """Return the first row of ``model`` matching all ``where`` clauses.

        Args:
            model: Table model class to select from.
            *where: SQLAlchemy filter expressions, combined by ``.where()``.

        Returns:
            The first matching instance, or ``None`` when nothing matches.
            The session is closed before returning, so the instance is no
            longer attached to it.
        """
        statement = select(model).where(*where)
        # Context manager so the session (and its connection) is always
        # released; the original left the session open forever.
        with Session(self.engine) as session:
            return session.exec(statement).first()

    def insert(self, data: T | dict | str, model_type: Optional[Type[T]] = None) -> T:
        """Persist ``data`` and return the refreshed model instance.

        Args:
            data: A model instance, a dict of field values, or a JSON string
                encoding such a dict.
            model_type: Model class used when ``data`` is a dict or string.

        Returns:
            The stored instance after ``commit`` and ``refresh``.
        """
        model = self._get_model_from_data(data, model_type)
        with Session(self.engine) as session:
            session.add(model)
            session.commit()
            session.refresh(model)
            return model

    def update(self, model: SQLModel, update_fields: Optional[List[str]] = None):
        """Persist ``model`` by delegating to :meth:`insert`.

        ``update_fields`` is accepted for interface compatibility but is
        currently ignored.
        """
        return self.insert(model)

    def insert_ignore(self, model: T, unique_keys: Optional[List[str]] = None) -> T:
        """Insert ``model`` unless a row with the same unique-key values exists.

        Args:
            model: A model instance (or dict / JSON string, see
                :meth:`_get_model_from_data`).
            unique_keys: Attribute names that identify a duplicate; defaults
                to ``['id']``.

        Returns:
            The already-stored row when one matches, otherwise the newly
            inserted instance.
        """
        if unique_keys is None:
            unique_keys = ['id']
        model = self._get_model_from_data(model)
        model_cls = type(model)
        statement = select(model_cls)
        for key in unique_keys:
            # Compare the column with the model's *value* for that key; the
            # original compared it with the key's name.
            statement = statement.where(getattr(model_cls, key) == getattr(model, key))
        with Session(self.engine) as session:
            result = session.exec(statement).first()
        if result:
            return result
        return self.insert(model)

    @classmethod
    def dict_to_model(cls, dict_data: dict, model: Type[T]) -> T:
        """Instantiate ``model`` from ``dict_data``, dropping unknown keys."""
        clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
        return model(**clean_data)

    def exec(self, query: str):
        '''Run a raw SQL string and return the result's mappings view.

        E.g.:
            db = DataBase()
            db.exec("SELECT * FROM chat_task LIMIT 1").first() # will return dict obj
            db.exec("...").all() # will return list[dict]

        NOTE(review): the mappings object is returned after the session has
        been closed; whether it can still be consumed depends on the
        driver/SQLAlchemy version - confirm. ``query`` is executed verbatim,
        so it must never be built from untrusted input.
        '''
        with Session(self.engine) as session:
            result: CursorResult = session.exec(text(query))
            mappings = result.mappings()
            return mappings

    def _get_model_from_data(self, data: str | dict | T, model_type: Optional[Type[T]] = None) -> T:
        """Normalise ``data`` into a model instance.

        Args:
            data: A model instance (returned unchanged), a dict of field
                values, or a JSON string encoding such a dict.
            model_type: Class to build when ``data`` is a dict or string.

        Raises:
            TypeError: ``data`` is a dict/string but no ``model_type`` given.
            Exception: ``data`` is of an unsupported type.
        """
        if isinstance(data, SQLModel):
            return data
        if isinstance(data, str):
            payload = json.loads(data)
        elif isinstance(data, dict):
            payload = data
        else:
            raise Exception("data type not support")
        if model_type is None:
            # Previously surfaced as "'NoneType' object is not callable".
            raise TypeError("model_type is required when data is a dict or a JSON string")
        return self.dict_to_model(payload, model_type)


class Table(DataBase[T]):
    """``DataBase`` bound to one model class, so callers need not pass it."""

    def __init__(self, model_type: Type[T], engine=engine) -> None:
        self.model_type = model_type
        super().__init__(engine)

    def fine_one(self, *where: BinaryExpression) -> Optional[T]:
        """Return the first row of the bound model matching ``where``."""
        return super().fine_one(self.model_type, *where)

    def insert(self, data: T | dict) -> T:
        """Persist ``data`` as an instance of the bound model."""
        return super().insert(data, self.model_type)

    def update(self, model: T) -> T:
        """Persist ``model``; see ``DataBase.update``."""
        return super().update(model)

    def dict_to_model(self, dict_data: dict) -> T:
        """Instantiate the bound model from ``dict_data``."""
        return super().dict_to_model(dict_data, self.model_type)

    def _get_model_from_data(self, data: str | Dict | T, model_type=None) -> T:
        """Normalise ``data`` using the bound model; ``model_type`` is ignored."""
        return super()._get_model_from_data(data, self.model_type)


# The type parameters are spelled out here so that type checkers can infer
# the model types handled by each helper.
db = DataBase[UnReadUserData | UserInfoModel]()
unread_table = Table[UnReadUserData](UnReadUserData)
user_table = Table[UserInfoModel](UserInfoModel)


def main():
    """Manual smoke test: attach a new unread record to user id 3, if present."""
    import time

    unread_user_data = UnReadUserData(name=f"name{time.time()}", avator="avator", msg="msg", detail=S3Object(path="test", type=tuple))
    exist_user_info = user_table.fine_one(UserInfoModel.id == 3)
    logger.info(f"{exist_user_info}")
    # NOTE(review): the file's indentation was lost; this nesting (assign,
    # link and return inside the ``if``) is the reconstruction assumed here.
    if exist_user_info:
        user_info = exist_user_info
        unread_user_data.user_info = user_info
        return unread_table.update(unread_user_data)
    return

    # NOTE(review): everything below is unreachable scratch code kept from the
    # original; delete it or move it into its own helper once confirmed unused.
    db = DataBase()
    uf = UserInfoModel(nickname="test", uid=12)
    ud = UnReadUserData(name=f"name{time.time()}", avator="avator", msg="msg", detail=S3Object.put({"a":1, "b":2}, 'test122'))
    ud.user_info = uf
    db.insert(ud)
    # res = db.fine_one(UnReadUserData, UnReadUserData.id==2)
    print(f"{ud.detail}")
    print(f"{ud.detail.get()}")
    return
    with Session(engine) as session:
        query = text("SELECT * FROM chat_task LIMIT 1")
        result = session.exec(query)
        print(result)
        print(type(result))
        mappings = result.mappings()
        print("result.mappings() ", mappings)
        print(type(mappings))
        print("result.mappings() first ", mappings.all())
        return result


if __name__ == "__main__":
    main()