user.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from datetime import datetime
  2. from typing import Optional
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  6. from db.engine import engine
  7. from sqlmodel import Field, SQLModel,create_engine,Session,select,func
  8. import psycopg2
  9. from config import DB_URL,logger
  10. # from db.common import engine
  11. from sqlalchemy import UniqueConstraint, Index
  12. from sqlalchemy.dialects.postgresql import insert
  13. from db.base import BaseRepository
  14. # 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
  15. # 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
  16. class UserOAuthToken(SQLModel, table=True):
  17. id: Optional[int] = Field(primary_key=True)
  18. access_token:str
  19. expires_in: Optional[int] = None
  20. open_id:str = Field(index=True)
  21. refresh_expires_in: Optional[int] = None
  22. refresh_token:str
  23. scope: str
  24. update_time: datetime = Field(default_factory=datetime.now) # 添加时间戳字段
  25. __table_args__ = (UniqueConstraint('open_id'),)
  26. class UserInfo(SQLModel, table=True):
  27. id: Optional[int] = Field(primary_key=True)
  28. avatar: str
  29. avatar_larger: str
  30. client_key: str
  31. e_account_role: str = Field(default="")
  32. nickname: str
  33. # 外键约束有助于:级联操作、避免冗余、数据完整性
  34. open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)
  35. union_id: str
  36. update_time: datetime = Field(default_factory=datetime.now)
  37. __table_args__ = (UniqueConstraint('open_id'),)
  38. class DouyinBaseRepository(BaseRepository):
  39. def __init__(self, model: SQLModel, engine=engine):
  40. super().__init__(model, engine)
  41. def add_or_update(self, dict_data: dict) -> SQLModel:
  42. clean_data = {k: v for k, v in dict_data.items() if hasattr(self.model, k)}
  43. obj_model = self.model(**clean_data)
  44. with Session(bind=self.engine) as session:
  45. exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  46. logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
  47. if exist_obj:
  48. self.set_update_time(exist_obj)
  49. dict_data = self.model_dump_by_field(obj_model, self.non_unique_fields)
  50. self.set_obj_by_dict(exist_obj,dict_data)
  51. logger.info(f"modify table '{self.model.__tablename__}' id '{exist_obj.open_id}' from {dict_data}")
  52. session.commit()
  53. return exist_obj
  54. else:
  55. self.create(obj_model)
  56. logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
  57. return obj_model
  58. def get_by_open_id(self, open_id):
  59. with Session(self.engine) as session:
  60. logger.debug(f"get {open_id}")
  61. base_statement = select(self.model).where(self.model.open_id == open_id)
  62. results = session.exec(base_statement)
  63. return results.first()
  64. class UserInfoRepository(DouyinBaseRepository):
  65. def __init__(self, engine=engine):
  66. super().__init__(UserInfo, engine)
  67. self.model:UserInfo
  68. def add_or_update_by_unique(self, obj_in: dict) -> SQLModel:
  69. clean_data = {k: v for k, v in obj_in.items() if hasattr(self.model, k)}
  70. obj_model = self.model(**clean_data)
  71. with Session(bind=self.engine) as session:
  72. exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  73. if exist_obj:
  74. self.set_update_time(exist_obj)
  75. dict_data = self.model_dump_by_field(obj_model, self.non_unique_fields)
  76. self.set_obj_by_dict(exist_obj,dict_data)
  77. session.commit()
  78. return exist_obj
  79. else:
  80. self.create(obj_model)
  81. session.commit()
  82. return obj_model
  83. # Database manager class
  84. class UserOAuthRepository(DouyinBaseRepository):
  85. def __init__(self, engine=engine):
  86. super().__init__(UserOAuthToken, engine)
  87. self.model:UserOAuthToken
  88. # def add_token(self, data: dict):
  89. # clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
  90. # obj_model = self.model(**clean_data)
  91. # with Session(bind=self.engine) as session:
  92. # exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  93. # if exist_obj:
  94. # self.set_update_time(exist_obj)
  95. # dict_data = self.model_dump_by_field(obj_model, self.unique_constraint_fields)
  96. # self.set_obj_by_dict(exist_obj,dict_data)
  97. # session.commit()
  98. # return exist_obj
  99. # else:
  100. # self.create(obj_model)
  101. # session.commit()
  102. # return obj_model
  103. async def delete_token(self, token_id: int):
  104. async with self.session_factory() as session:
  105. statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)
  106. token = await session.execute(statement).scalars().first()
  107. if token:
  108. await session.delete(token)
  109. await session.commit()
  110. print(f"Record deleted: ID - {token_id}")
  111. else:
  112. print(f"Record with ID {token_id} not found")
  113. def test_add(open_id=None):
  114. SQLModel.metadata.create_all(engine)
  115. user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
  116. user_info = {
  117. "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
  118. "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
  119. "captcha": "",
  120. "city": "",
  121. "client_key": "55",
  122. "country": "",
  123. "desc_url": "",
  124. "description": "",
  125. "district": "",
  126. "e_account_role": "",
  127. "error_code": 0,
  128. "gender": 0,
  129. "log_id": "202401261424326FE877A6CAB03910C553",
  130. "nickname": "程序员马工",
  131. "open_id": "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy",
  132. "province": "",
  133. "union_id": "123-01ae-59bd-978a-1de8566186a8"
  134. }
  135. if open_id:
  136. user_oauth["open_id"] = open_id
  137. user_info["open_id"] = open_id
  138. user_info["nickname"] = "user" + open_id[:5]
  139. else:
  140. open_id = user_oauth["open_id"]
  141. db_manager = UserOAuthRepository()
  142. res = db_manager.add_or_update(user_oauth)
  143. # logger.debug(res)
  144. db_user_info = UserInfoRepository()
  145. res = db_user_info.add_or_update(user_info)
  146. logger.debug(db_manager.get_by_open_id(open_id))
  147. return user_oauth["open_id"]
  148. if __name__ == "__main__":
  149. test_add()