user.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from datetime import datetime
  2. from typing import Optional
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  6. from db.engine import engine
  7. from sqlmodel import Field, SQLModel,create_engine,Session,select,func
  8. import psycopg2
  9. from config import DB_URL,logger
  10. # from db.common import engine
  11. from sqlalchemy import UniqueConstraint, Index
  12. from sqlalchemy.dialects.postgresql import insert
  13. from db.base import BaseRepository
  14. # 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
  15. # 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
  16. class UserOAuthToken(SQLModel, table=True):
  17. id: Optional[int] = Field(primary_key=True)
  18. access_token:str
  19. expires_in: Optional[int] = None
  20. open_id:str = Field(index=True)
  21. refresh_expires_in: Optional[int] = None
  22. refresh_token:str
  23. scope: str
  24. update_time: datetime = Field(default_factory=datetime.now) # 添加时间戳字段
  25. __table_args__ = (UniqueConstraint('open_id'),)
  26. class UserInfo(SQLModel, table=True):
  27. id: Optional[int] = Field(primary_key=True)
  28. avatar: str
  29. avatar_larger: str
  30. client_key: str
  31. e_account_role: str = Field(default="")
  32. nickname: str
  33. # 外键约束有助于:级联操作、避免冗余、数据完整性
  34. open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)
  35. union_id: str
  36. update_time: datetime = Field(default_factory=datetime.now)
  37. __table_args__ = (UniqueConstraint('open_id'),)
  38. class DouyinBaseRepository(BaseRepository):
  39. def __init__(self, model: SQLModel, engine=engine):
  40. super().__init__(model, engine)
  41. def add_or_update(self, dict_data: dict) -> SQLModel:
  42. clean_data = {k: v for k, v in dict_data.items() if hasattr(self.model, k)}
  43. obj_model = self.model(**clean_data)
  44. with Session(bind=self.engine) as session:
  45. exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  46. logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
  47. if exist_obj:
  48. self.set_update_time(exist_obj)
  49. dict_data = self.model_dump_by_field(obj_model, self.non_unique_fields)
  50. self.set_obj_by_dict(exist_obj,dict_data)
  51. logger.info(f"modify table '{self.model.__tablename__}' id '{exist_obj.open_id}' from {dict_data}")
  52. session.commit()
  53. return exist_obj
  54. else:
  55. session.commit()
  56. return obj_model
  57. class UserInfoRepository(DouyinBaseRepository):
  58. def __init__(self, engine=engine):
  59. super().__init__(UserInfo, engine)
  60. self.model:UserInfo
  61. def add_or_update_by_unique(self, obj_in: dict) -> SQLModel:
  62. clean_data = {k: v for k, v in obj_in.items() if hasattr(self.model, k)}
  63. obj_model = self.model(**clean_data)
  64. with Session(bind=self.engine) as session:
  65. exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  66. if exist_obj:
  67. self.set_update_time(exist_obj)
  68. dict_data = self.model_dump_by_field(obj_model, self.non_unique_fields)
  69. self.set_obj_by_dict(exist_obj,dict_data)
  70. session.commit()
  71. return exist_obj
  72. else:
  73. self.create(obj_model)
  74. session.commit()
  75. return obj_model
  76. # Database manager class
  77. class UserOAuthRepository(DouyinBaseRepository):
  78. def __init__(self, engine=engine):
  79. super().__init__(UserOAuthToken, engine)
  80. self.model:UserOAuthToken
  81. # def add_token(self, data: dict):
  82. # clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
  83. # obj_model = self.model(**clean_data)
  84. # with Session(bind=self.engine) as session:
  85. # exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  86. # if exist_obj:
  87. # self.set_update_time(exist_obj)
  88. # dict_data = self.model_dump_by_field(obj_model, self.unique_constraint_fields)
  89. # self.set_obj_by_dict(exist_obj,dict_data)
  90. # session.commit()
  91. # return exist_obj
  92. # else:
  93. # self.create(obj_model)
  94. # session.commit()
  95. # return obj_model
  96. async def delete_token(self, token_id: int):
  97. async with self.session_factory() as session:
  98. statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)
  99. token = await session.execute(statement).scalars().first()
  100. if token:
  101. await session.delete(token)
  102. await session.commit()
  103. print(f"Record deleted: ID - {token_id}")
  104. else:
  105. print(f"Record with ID {token_id} not found")
  106. def test_add(open_id=None):
  107. SQLModel.metadata.create_all(engine)
  108. user_oauth = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
  109. user_info = {
  110. "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
  111. "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
  112. "captcha": "",
  113. "city": "",
  114. "client_key": "55",
  115. "country": "",
  116. "desc_url": "",
  117. "description": "",
  118. "district": "",
  119. "e_account_role": "",
  120. "error_code": 0,
  121. "gender": 0,
  122. "log_id": "202401261424326FE877A6CAB03910C553",
  123. "nickname": "程序员马工",
  124. "open_id": "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy",
  125. "province": "",
  126. "union_id": "123-01ae-59bd-978a-1de8566186a8"
  127. }
  128. if open_id:
  129. user_oauth["open_id"] = open_id
  130. user_info["open_id"] = open_id
  131. user_info["nickname"] = "user" + open_id[:5]
  132. db_manager = UserOAuthRepository()
  133. res = db_manager.add_or_update(user_oauth)
  134. # logger.debug(res)
  135. db_user_info = UserInfoRepository()
  136. res = db_user_info.add_or_update(user_info)
  137. # logger.debug(res)
  138. return user_oauth["open_id"]
  139. if __name__ == "__main__":
  140. test_add()