models.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import datetime
  2. import json
  3. import pydantic
  4. from typing import List, Dict, Any, Optional,Tuple
  5. import os
  6. import sys
  7. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  8. from conf.config import logger
  9. from sqlmodel import SQLModel, Field, Relationship, create_engine
  10. import pickle
  11. from database.config import DB_URL
  12. from database.s3 import S3Object
  13. from sqlmodel import Field, SQLModel,Relationship,Column,Session,select,func,UniqueConstraint,PickleType,text
  14. from sqlalchemy.dialects.postgresql import ARRAY, JSON
  15. from sqlalchemy.engine.cursor import CursorResult
  16. from sqlalchemy.sql.elements import BinaryExpression
  17. from pydantic import Json
  18. from typing import TypeVar, Type ,Generic
  19. engine = create_engine(DB_URL)
  20. def create_all_table():
  21. SQLModel.metadata.create_all(engine)
  22. T = TypeVar('T', bound="BaseSQLModel")
  23. class BaseSQLModel(SQLModel):
  24. # 自动过滤无关字段,任何 dict_data 中多余的字段都会被过滤掉
  25. @classmethod
  26. def dict_to_model(cls, dict_data: dict, model:T=None) -> T:
  27. if not model:
  28. model:T = cls
  29. # print("dict_to_model ", model)
  30. clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
  31. obj_model:T = model(**clean_data)
  32. return obj_model
  33. class Config:
  34. arbitrary_types_allowed = True
  35. class UserInfoModel(BaseSQLModel, table=True):
  36. __tablename__ = 'user_data'
  37. id:int = Field(default=None, primary_key=True)
  38. uid: Optional[str] = Field(nullable=False)
  39. nickname: Optional[str]
  40. avatar_medium:Optional[Json] = Field(sa_column=Column(JSON))
  41. sec_uid: Optional[str]
  42. signature: Optional[str]
  43. city: Optional[str]
  44. ip_location: Optional[str]
  45. province: Optional[str]
  46. school_name: Optional[str]
  47. follow_status: Optional[str]
  48. follower_count: Optional[int]
  49. total_favorited: Optional[int]
  50. aweme_count: Optional[int]
  51. unread_user_data: list["UnReadUserData"] = Relationship(back_populates="user_info")
  52. class UnReadUserData(BaseSQLModel, table=True):
  53. __tablename__ = 'unread_user_data'
  54. id:int = Field(default=None, primary_key=True)
  55. name: Optional[str] = Field(default=None)
  56. avator: Optional[str] = Field(default=None)
  57. msg: Optional[str] = Field(default=None)
  58. unread_msg_count: Optional[int] = Field(default=None)
  59. msg_time:Optional[datetime.datetime] = Field(default=None)
  60. msg_time_txt:Optional[str] = Field(default=None)
  61. chat_history:Optional[Json] = Field(sa_column=Column(JSON))
  62. detail: Optional[S3Object] = Field(sa_column=Column(PickleType))
  63. create_time:datetime.datetime = Field(default_factory=datetime.datetime.now)
  64. is_done:Optional[bool] = Field(default=False)
  65. user_info_id: Optional[int | None] = Field(default=None, foreign_key="user_data.id")
  66. user_info:Optional[UserInfoModel|None] = Relationship(back_populates="unread_user_data", sa_relationship_kwargs={"lazy": "joined","cascade": "all, delete-orphan", "single_parent":True})
# Import-time side effect: create the tables for the models declared above
# as soon as this module is loaded.
create_all_table()
  68. class DataBase(Generic[T]):
  69. def __init__(self, engine=engine) -> None:
  70. self.engine = engine
  71. def fine_one(self, model: Type[T], *where:BinaryExpression) -> T:
  72. session = Session(engine)
  73. statement = select(model).where(*where)
  74. res = session.exec(statement)
  75. if res:
  76. return res.first()
  77. def insert(self, data:T|dict|str, model_type: T=None) -> T:
  78. model = self._get_model_from_data(data, model_type)
  79. with Session(engine) as session:
  80. session.add(model)
  81. session.commit()
  82. session.refresh(model)
  83. return model
  84. def update(self, model: SQLModel, update_fields: List[str] = None):
  85. return self.insert(model)
  86. def insert_ignore(self, model: T, unique_keys: List[str] = ['id']) -> T:
  87. model = self._get_model_from_data(model)
  88. with Session(engine) as session:
  89. statement = select(model.__class__)
  90. for key in unique_keys:
  91. statement = statement.where(getattr(model.__class__, key) == key)
  92. result = session.exec(statement).first()
  93. if result:
  94. return result
  95. return self.insert(model)
  96. @classmethod
  97. def dict_to_model(self, dict_data: dict, model:T) -> T:
  98. # print("dict_to_model ", model)
  99. clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
  100. obj_model:T = model(**clean_data)
  101. return obj_model
  102. def exec(self, query:str):
  103. '''
  104. E.g.:
  105. db = DateBase()
  106. db.exec("SELECT * FROM chat_task LIMIT 1").first() # will return dict obj
  107. db.exec("...").all() # will return list[dict]
  108. '''
  109. with Session(engine) as session:
  110. result:CursorResult = session.exec(text(query))
  111. mappings = result.mappings()
  112. return mappings
  113. def _get_model_from_data(self, data:str|dict|T, model_type) -> T:
  114. if isinstance(data, SQLModel):
  115. return data
  116. if isinstance(data, dict):
  117. return self.dict_to_model(data, model_type)
  118. if isinstance(data, str):
  119. return self.dict_to_model(json.loads(data), model_type)
  120. raise Exception("data type not support")
  121. class Table(DataBase[T]):
  122. def __init__(self, model_type: Type[T], engine=engine, ) -> None:
  123. self.model_type = model_type
  124. super().__init__(engine)
  125. def fine_one(self, *where: BinaryExpression) -> T:
  126. res = super().fine_one(self.model_type, *where)
  127. return res
  128. def insert(self, data: T | dict) -> T:
  129. return super().insert(data, self.model_type)
  130. def update(self, model: T) -> T:
  131. return super().update(model)
  132. def dict_to_model(self, dict_data: dict) -> T:
  133. return super().dict_to_model(dict_data, self.model_type)
  134. def _get_model_from_data(self, data: str | Dict | T, model_type=None) -> T:
  135. return super()._get_model_from_data(data, self.model_type)
# The generic type parameters are spelled out at definition time so that
# type checkers can complete their type checking.
db = DataBase[UnReadUserData|UserInfoModel]()
unread_table = Table[UnReadUserData](UnReadUserData)
user_table = Table[UserInfoModel](UserInfoModel)
  140. def main():
  141. import json
  142. import time
  143. unread_user_data = UnReadUserData(name=f"name{time.time()}", avator="avator", msg="msg", detail=S3Object(path="test", type=tuple))
  144. exist_user_info = user_table.fine_one(UserInfoModel.id == 3)
  145. logger.info(f"{exist_user_info}")
  146. if exist_user_info:
  147. user_info = exist_user_info
  148. unread_user_data.user_info = user_info
  149. return unread_table.update(unread_user_data)
  150. return
  151. db = DataBase()
  152. uf = UserInfoModel(nickname="test", uid=12)
  153. ud = UnReadUserData(name=f"name{time.time()}", avator="avator", msg="msg", detail=S3Object.put({"a":1, "b":2}, 'test122'))
  154. ud.user_info = uf
  155. db.insert(ud)
  156. # res = db.fine_one(UnReadUserData, UnReadUserData.id==2)
  157. print(f"{ud.detail}")
  158. print(f"{ud.detail.get()}")
  159. return
  160. with Session(engine) as session:
  161. query = text("SELECT * FROM chat_task LIMIT 1")
  162. result = session.exec(query)
  163. print(result)
  164. print(type(result))
  165. mappings = result.mappings()
  166. print("result.mappings() ", mappings)
  167. print(type(mappings))
  168. print("result.mappings() first ", mappings.all())
  169. return result
  170. if __name__ == "__main__":
  171. main()