| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- import datetime
- import json
- import pydantic
- from typing import List, Dict, Any, Optional,Tuple
- import os
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(__file__)))
- from conf.config import logger
- from sqlmodel import SQLModel, Field, Relationship, create_engine
- import pickle
- from database.config import DB_URL
- from database.s3 import S3Object
- from sqlmodel import Field, SQLModel,Relationship,Column,Session,select,func,UniqueConstraint,PickleType,text
- from sqlalchemy.dialects.postgresql import ARRAY, JSON
- from sqlalchemy.engine.cursor import CursorResult
- from sqlalchemy.sql.elements import BinaryExpression
- from pydantic import Json
- from typing import TypeVar, Type ,Generic
# Module-level engine built from DB_URL; the helpers below use it by default.
engine = create_engine(DB_URL)


def create_all_table():
    """Create every table registered on ``SQLModel.metadata`` (existing ones are skipped)."""
    SQLModel.metadata.create_all(bind=engine)


# Type variable for the model classes handled by the generic helpers below.
T = TypeVar('T', bound="BaseSQLModel")
class BaseSQLModel(SQLModel):
    """Common base class for the SQLModel models in this module."""

    @classmethod
    def dict_to_model(cls, dict_data: dict, model: Optional[Type[T]] = None) -> T:
        """Build an instance of ``model`` (default: ``cls``) from ``dict_data``.

        Keys for which ``hasattr(model, key)`` is false are silently dropped, so
        unrelated entries in ``dict_data`` never reach the constructor.

        Args:
            dict_data: Source mapping, typically decoded JSON.
            model: Model class to instantiate; defaults to the class this is called on.

        Returns:
            The new (unsaved) model instance.
        """
        target = cls if model is None else model
        clean_data = {k: v for k, v in dict_data.items() if hasattr(target, k)}
        return target(**clean_data)

    class Config:
        # Let pydantic accept field types it cannot validate natively —
        # presumably needed for fields such as UnReadUserData.detail (S3Object).
        arbitrary_types_allowed = True
class UserInfoModel(BaseSQLModel, table=True):
    """ORM model for the ``user_data`` table.

    A single user row can be referenced by many ``UnReadUserData`` rows through
    ``UnReadUserData.user_info_id``.
    """

    __tablename__ = 'user_data'

    # Primary key; stays None until the row is inserted and the database assigns it.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Declared NOT NULL in the database even though the annotation allows None.
    uid: Optional[str] = Field(nullable=False)
    nickname: Optional[str]
    # Stored in a JSON column. NOTE(review): pydantic's ``Json`` type expects a
    # JSON-encoded string when validated — confirm this matches what callers pass.
    avatar_medium: Optional[Json] = Field(sa_column=Column(JSON))
    sec_uid: Optional[str]
    signature: Optional[str]
    city: Optional[str]
    ip_location: Optional[str]
    province: Optional[str]
    school_name: Optional[str]
    follow_status: Optional[str]
    follower_count: Optional[int]
    total_favorited: Optional[int]
    aweme_count: Optional[int]
    # One-to-many counterpart of ``UnReadUserData.user_info``.
    unread_user_data: list["UnReadUserData"] = Relationship(back_populates="user_info")
class UnReadUserData(BaseSQLModel, table=True):
    """ORM model for the ``unread_user_data`` table.

    Each row optionally points at one ``UserInfoModel`` via ``user_info_id``;
    many rows may point at the same user.
    """

    __tablename__ = 'unread_user_data'

    # Primary key; stays None until the row is inserted and the database assigns it.
    id: Optional[int] = Field(default=None, primary_key=True)
    name: Optional[str] = Field(default=None)
    # (sic) column name kept as-is to match the existing schema.
    avator: Optional[str] = Field(default=None)
    msg: Optional[str] = Field(default=None)
    unread_msg_count: Optional[int] = Field(default=None)
    msg_time: Optional[datetime.datetime] = Field(default=None)
    msg_time_txt: Optional[str] = Field(default=None)
    chat_history: Optional[Json] = Field(sa_column=Column(JSON))
    # PickleType pickles on write and unpickles on read: this column must only
    # ever hold data written by trusted code.
    detail: Optional[S3Object] = Field(sa_column=Column(PickleType))
    # Naive local time at which the Python object was constructed.
    create_time: datetime.datetime = Field(default_factory=datetime.datetime.now)
    is_done: Optional[bool] = Field(default=False)
    user_info_id: Optional[int] = Field(default=None, foreign_key="user_data.id")
    # Many-to-one side of ``UserInfoModel.unread_user_data``, loaded with a JOIN
    # together with this row. Deliberately no "delete"/"delete-orphan" cascade
    # and no single_parent: the user is shared by many unread rows, so deleting
    # one unread row must not delete the user, and a user must be allowed to
    # have several unread rows. The default cascade ("save-update, merge")
    # still saves a new, unsaved user assigned here.
    user_info: Optional[UserInfoModel] = Relationship(
        back_populates="unread_user_data",
        sa_relationship_kwargs={"lazy": "joined"},
    )


create_all_table()
class DataBase(Generic[T]):
    """Small CRUD helper over a SQLModel engine.

    Every method opens its own ``Session`` on ``self.engine`` and closes it before
    returning. Returned model instances are therefore detached: attributes loaded
    during the call stay readable, but lazy relationships that were not loaded
    cannot be fetched afterwards.
    """

    def __init__(self, engine=engine) -> None:
        # Defaults to the module-level engine built from DB_URL.
        self.engine = engine

    def fine_one(self, model: Type[T], *where: BinaryExpression) -> Optional[T]:
        """Return the first ``model`` row matching every ``where`` clause, or None.

        (The name is misspelled but kept for existing callers.)
        """
        statement = select(model).where(*where)
        # Closing the session returns its connection to the pool and detaches the
        # result, so the instance can later be added to another session (e.g. by
        # ``insert``) without an "already attached to session" error.
        with Session(self.engine) as session:
            return session.exec(statement).first()

    def insert(self, data: T | dict | str, model_type: Optional[Type[T]] = None) -> T:
        """Persist ``data`` as a new row and return the refreshed model.

        Args:
            data: A model instance, a dict, or a JSON object string.
            model_type: Model class used to build the instance from a dict/str.

        Returns:
            The saved model, refreshed from the database.

        Raises:
            TypeError: ``data`` is a dict/str and no ``model_type`` was given.
            Exception: ``data`` has an unsupported type.
        """
        model = self._get_model_from_data(data, model_type)
        with Session(self.engine) as session:
            session.add(model)
            session.commit()
            # Reload database-generated values such as the primary key.
            session.refresh(model)
        return model

    def update(self, model: SQLModel, update_fields: Optional[List[str]] = None):
        """Save ``model`` through ``insert`` and return it.

        ``update_fields`` is currently ignored. SQLAlchemy chooses INSERT or UPDATE
        from the instance state: an instance previously loaded from the database is
        updated, a freshly constructed one is inserted.
        """
        return self.insert(model)

    def insert_ignore(self, model: T, unique_keys: Optional[List[str]] = None) -> T:
        """Insert ``model`` unless a row with the same ``unique_keys`` values exists.

        Args:
            model: The instance (or dict/JSON string, if the subclass knows its model).
            unique_keys: Attribute names that identify a duplicate; defaults to ``["id"]``.

        Returns:
            The existing matching row, or the newly inserted model.
        """
        if unique_keys is None:
            unique_keys = ['id']
        model = self._get_model_from_data(model)
        model_cls = type(model)
        # Compare each column with this instance's value for it (a None value
        # becomes ``IS NULL``).
        conditions = [getattr(model_cls, key) == getattr(model, key) for key in unique_keys]
        if conditions:
            with Session(self.engine) as session:
                existing = session.exec(select(model_cls).where(*conditions)).first()
            if existing is not None:
                return existing
        return self.insert(model)

    @classmethod
    def dict_to_model(cls, dict_data: dict, model: Type[T]) -> T:
        """Build a ``model`` instance from ``dict_data``.

        Keys for which ``hasattr(model, key)`` is false are dropped.
        """
        clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
        return model(**clean_data)

    def exec(self, query: str, params: Optional[Dict[str, Any]] = None):
        '''Run a raw SQL string and return its rows as a ``MappingResult``.

        E.g.:
            db = DataBase()
            db.exec("SELECT * FROM chat_task LIMIT 1").first()  # will return dict obj
            db.exec("...").all()  # will return list[dict]
            db.exec("SELECT * FROM user_data WHERE id = :id", {"id": 3}).first()

        ``query`` is executed verbatim, so never format untrusted values into it;
        pass them through ``params`` with ``:name`` placeholders instead. No commit
        is issued, so changes made by data-modifying statements are not persisted.
        '''
        with Session(self.engine) as session:
            result: CursorResult = session.exec(text(query), params=params)
            if not result.returns_rows:
                return result.mappings()
            # Buffer every row while the connection is still checked out so the
            # returned result stays readable after the session is closed.
            return result.freeze()().mappings()

    def _get_model_from_data(self, data: str | dict | T, model_type: Optional[Type[T]] = None) -> T:
        """Normalize ``data`` (model instance, dict, or JSON object string) to a model instance."""
        if isinstance(data, SQLModel):
            return data
        if isinstance(data, str):
            data = json.loads(data)
        if isinstance(data, dict):
            if model_type is None:
                raise TypeError("model_type is required to build a model from dict/str data")
            # Call DataBase's classmethod explicitly: subclasses such as Table
            # override dict_to_model with a different signature.
            return DataBase.dict_to_model(data, model_type)
        raise Exception("data type not support")
class Table(DataBase[T]):
    """``DataBase`` bound to one model class, so callers need not pass it each call."""

    def __init__(self, model_type: Type[T], engine=engine) -> None:
        # Model class used for queries and for building instances from dict/str data.
        self.model_type = model_type
        super().__init__(engine)

    def fine_one(self, *where: BinaryExpression) -> Optional[T]:
        """Return the first ``model_type`` row matching every ``where`` clause, or None."""
        return super().fine_one(self.model_type, *where)

    def insert(self, data: T | dict | str) -> T:
        """Insert ``data`` (instance, dict, or JSON string) as a ``model_type`` row."""
        return super().insert(data, self.model_type)

    def update(self, model: T) -> T:
        """Save ``model``; see ``DataBase.update``."""
        return super().update(model)

    def dict_to_model(self, dict_data: dict, model_type: Optional[Type[T]] = None) -> T:
        """Build a model from ``dict_data``, defaulting to this table's model class.

        The optional ``model_type`` keeps this override call-compatible with
        ``DataBase.dict_to_model(dict_data, model)``.
        """
        target = self.model_type if model_type is None else model_type
        return super().dict_to_model(dict_data, target)

    def _get_model_from_data(self, data: str | Dict | T, model_type=None) -> T:
        # Always use the bound model class; any model_type argument is ignored.
        return super()._get_model_from_data(data, self.model_type)
# The generic helpers are parameterised explicitly here so that static type
# checkers know which model type each shared instance works with.
db = DataBase[UnReadUserData | UserInfoModel]()
unread_table = Table[UnReadUserData](model_type=UnReadUserData)
user_table = Table[UserInfoModel](model_type=UserInfoModel)
def main():
    """Ad-hoc check run when this module is executed directly.

    Builds a new ``UnReadUserData``, links it to the user with id 3 and saves it.

    Returns:
        The saved ``UnReadUserData``, or None when no user with id 3 exists.
    """
    import time

    unread_user_data = UnReadUserData(
        name=f"name{time.time()}",
        avator="avator",
        msg="msg",
        detail=S3Object(path="test", type=tuple),
    )
    exist_user_info = user_table.fine_one(UserInfoModel.id == 3)
    logger.info(f"{exist_user_info}")
    if exist_user_info is None:
        return None
    unread_user_data.user_info = exist_user_info
    return unread_table.update(unread_user_data)


if __name__ == "__main__":
    main()
|