user.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from datetime import datetime
  2. from typing import Optional
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  6. from sqlmodel import Field, SQLModel,create_engine,Session,select,func
  7. import psycopg2
  8. from config import DB_URL,logger
  9. from douyin.access_token import get_access_token
  10. # from db.common import engine
  11. from sqlalchemy import UniqueConstraint, Index
  12. from sqlalchemy.dialects.postgresql import insert
  13. # 定义数据库模型
  14. class UserOAuthToken(SQLModel, table=True):
  15. id: Optional[int] = Field(default=None, primary_key=True)
  16. access_token:str
  17. expires_in: Optional[int] = None
  18. open_id:str
  19. refresh_expires_in: Optional[int] = None
  20. refresh_token:str
  21. scope: str
  22. update_time: datetime = Field(default_factory=datetime.now) # 添加时间戳字段
  23. __table_args__ = (UniqueConstraint('open_id'),)
  24. class UserInfo(SQLModel, table=True):
  25. id: Optional[int] = Field(default=None, primary_key=True)
  26. avatar: str
  27. avatar_larger: str
  28. client_key: str
  29. e_account_role: str = Field(default="")
  30. nickname: str
  31. open_id: str
  32. union_id: str
  33. update_time: datetime = Field(default_factory=datetime.now)
  34. __table_args__ = (UniqueConstraint('open_id'),)
  35. engine = create_engine(DB_URL) # 替换成你的 DB_URL
  36. SQLModel.metadata.create_all(engine)
  37. class UserInfoRepository:
  38. def __init__(self, engine=engine):
  39. self.engine = engine
  40. def create_user_info(self, user_info_data):
  41. # 剔除不需要的字段
  42. cleaned_data = {k: v for k, v in user_info_data.items() if k not in ["log_id", "error_code"]}
  43. # 添加或更新时间戳
  44. cleaned_data['update_time'] = func.now()
  45. with Session(self.engine) as session:
  46. # 使用 on_conflict_do_update 处理 open_id 的冲突
  47. insert_stmt = insert(UserInfo).values(**cleaned_data)
  48. update_stmt = insert_stmt.on_conflict_do_update(
  49. constraint="open_id", # 使用 open_id 作为冲突约束
  50. set_={**{k: cleaned_data[k] for k in cleaned_data if k != "open_id"}, "update_time": func.now()} # 更新其他字段,包括时间戳
  51. )
  52. result = session.exec(update_stmt)
  53. session.commit()
  54. def get_user_info_by_open_id(self, open_id):
  55. with Session(self.engine) as session:
  56. statement = select(UserInfo).where(UserInfo.open_id == open_id)
  57. result = session.exec(statement)
  58. return result.first()
  59. def update_user_info(self, user_id, user_info_data):
  60. with Session(self.engine) as session:
  61. update_user_info = session.get(UserInfo, user_id)
  62. if update_user_info:
  63. for key, value in user_info_data.items():
  64. setattr(update_user_info, key, value)
  65. session.commit()
  66. return update_user_info
  67. def delete_user_info(self, user_id):
  68. with Session(self.engine) as session:
  69. delete_user_info = session.get(UserInfo, user_id)
  70. if delete_user_info:
  71. session.delete(delete_user_info)
  72. session.commit()
  73. # Database manager class
  74. class UserOAuthRepository:
  75. def __init__(self, engine=engine):
  76. self.engine = engine
  77. def add_token(self, data: dict):
  78. # 剔除不需要的字段
  79. cleaned_data = {
  80. k: v for k, v in data.items()
  81. if k not in ["log_id", "error_code", "captcha", "desc_url", "description"]
  82. }
  83. # 添加或更新时间戳
  84. cleaned_data['update_time'] = func.now()
  85. # 构造插入语句
  86. insert_stmt = insert(UserOAuthToken).values(**cleaned_data)
  87. update_stmt = insert_stmt.on_conflict_do_update(
  88. index_elements=['open_id'], # 使用 open_id 作为冲突的目标列
  89. set_={
  90. **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
  91. "update_time": func.now() # 更新时间戳
  92. }
  93. )
  94. # 执行插入/更新操作
  95. with Session(self.engine) as session:
  96. result = session.exec(update_stmt) # 注意:这里应该是 execute 而不是 exec
  97. session.commit()
  98. logger.debug(f"Record added/updated: Access Token, Open ID - {cleaned_data['open_id']}")
  99. def delete_token(self, token_id: int):
  100. with Session(self.engine) as session:
  101. token = session.get(UserOAuthToken, token_id)
  102. if token:
  103. session.delete(token)
  104. session.commit()
  105. print(f"Record deleted: ID - {token_id}")
  106. else:
  107. print(f"Record with ID {token_id} not found")
  108. def display_all_records(self):
  109. with Session(self.engine) as session:
  110. statement = select(UserOAuthToken)
  111. user_tokens = session.exec(statement).all()
  112. return user_tokens
  113. def main():
  114. db_manager = UserOAuthRepository()
  115. data = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
  116. db_manager.add_token(data)
  117. res = db_manager.display_all_records()
  118. logger.debug(res)
  119. if __name__ == "__main__":
  120. main()