"""Persistence for Douyin OAuth tokens and user profiles (PostgreSQL via SQLModel)."""

from datetime import datetime
from typing import Optional
import os
import sys

# Put the directory two levels above this file on sys.path so the project-level
# `config` and `douyin` modules below can be imported.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from sqlmodel import Field, SQLModel, create_engine, Session, select, func
import psycopg2  # NOTE(review): not referenced directly; presumably kept to require the PG driver
from config import DB_URL, logger
from douyin.access_token import get_access_token
# from db.common import engine
from sqlalchemy import UniqueConstraint, Index
from sqlalchemy.dialects.postgresql import insert


# Database models
class UserOAuthToken(SQLModel, table=True):
    """OAuth token record; at most one row per ``open_id`` (unique constraint)."""

    id: Optional[int] = Field(default=None, primary_key=True)
    access_token: str
    expires_in: Optional[int] = None  # presumably seconds until expiry — confirm against API docs
    open_id: str
    refresh_expires_in: Optional[int] = None
    refresh_token: str
    scope: str
    # Time of the last write; the repository upserts overwrite it with the DB's now().
    update_time: datetime = Field(default_factory=datetime.now)

    __table_args__ = (UniqueConstraint('open_id'),)


class UserInfo(SQLModel, table=True):
    """User profile record; at most one row per ``open_id`` (unique constraint)."""

    id: Optional[int] = Field(default=None, primary_key=True)
    avatar: str
    avatar_larger: str
    client_key: str
    e_account_role: str = Field(default="")
    nickname: str
    open_id: str
    union_id: str
    # Time of the last write; create_user_info overwrites it with the DB's now().
    update_time: datetime = Field(default_factory=datetime.now)

    __table_args__ = (UniqueConstraint('open_id'),)


engine = create_engine(DB_URL)  # DB_URL comes from the project config module
# Runs at import time: creates any tables above that do not exist yet.
SQLModel.metadata.create_all(engine)


class UserInfoRepository:
    """CRUD helpers for :class:`UserInfo` rows."""

    def __init__(self, engine=engine):
        self.engine = engine

    def create_user_info(self, user_info_data: dict) -> None:
        """Insert a user profile, or update the existing row with the same ``open_id``.

        Args:
            user_info_data: API payload. ``log_id`` and ``error_code`` are
                dropped; every other key is passed straight into the INSERT, so
                it should name a ``UserInfo`` column.
        """
        # Drop response metadata that is not part of the table.
        cleaned_data = {
            k: v for k, v in user_info_data.items() if k not in ("log_id", "error_code")
        }
        # Stamp the write with the database server's clock.
        cleaned_data['update_time'] = func.now()

        insert_stmt = insert(UserInfo).values(**cleaned_data)
        # Target the open_id column, not a constraint name: the UniqueConstraint
        # is unnamed, so PostgreSQL gives it a generated name (typically
        # "userinfo_open_id_key"), and `constraint="open_id"` would not match it.
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=['open_id'],
            set_={
                **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
                "update_time": func.now(),
            },
        )
        with Session(self.engine) as session:
            session.exec(upsert_stmt)
            session.commit()

    def get_user_info_by_open_id(self, open_id: str) -> Optional[UserInfo]:
        """Return the ``UserInfo`` row for *open_id*, or ``None`` if there is none."""
        with Session(self.engine) as session:
            statement = select(UserInfo).where(UserInfo.open_id == open_id)
            return session.exec(statement).first()

    def update_user_info(self, user_id: int, user_info_data: dict) -> Optional[UserInfo]:
        """Apply *user_info_data* as attribute assignments to the row with id *user_id*.

        Returns:
            The updated ``UserInfo`` (detached, attributes loaded), or ``None``
            if no row has that primary key.
        """
        with Session(self.engine) as session:
            user_info = session.get(UserInfo, user_id)
            if user_info is not None:
                for key, value in user_info_data.items():
                    setattr(user_info, key, value)
                session.commit()
                # commit() expires loaded attributes; reload them while the
                # session is still open so the returned detached object is readable.
                session.refresh(user_info)
            return user_info

    def delete_user_info(self, user_id: int) -> None:
        """Delete the ``UserInfo`` row with id *user_id*; no-op if it does not exist."""
        with Session(self.engine) as session:
            user_info = session.get(UserInfo, user_id)
            if user_info is not None:
                session.delete(user_info)
                session.commit()


# Database manager class
class UserOAuthRepository:
    """CRUD helpers for :class:`UserOAuthToken` rows."""

    def __init__(self, engine=engine):
        self.engine = engine

    def add_token(self, data: dict) -> None:
        """Insert a token record, or update the existing row with the same ``open_id``.

        Args:
            data: Token payload. ``log_id``, ``error_code``, ``captcha``,
                ``desc_url`` and ``description`` are dropped; every other key is
                passed straight into the INSERT. Must contain ``open_id``.
        """
        # Drop response metadata that is not part of the table.
        cleaned_data = {
            k: v
            for k, v in data.items()
            if k not in ("log_id", "error_code", "captcha", "desc_url", "description")
        }
        # Stamp the write with the database server's clock.
        cleaned_data['update_time'] = func.now()

        # Build the INSERT ... ON CONFLICT (open_id) DO UPDATE statement.
        insert_stmt = insert(UserOAuthToken).values(**cleaned_data)
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=['open_id'],  # conflict target: the unique open_id column
            set_={
                **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
                "update_time": func.now(),  # refresh the timestamp on update too
            },
        )

        # Execute the upsert. SQLModel's Session.exec forwards non-SELECT
        # statements to SQLAlchemy's Session.execute.
        with Session(self.engine) as session:
            session.exec(upsert_stmt)
            session.commit()
            logger.debug(f"Record added/updated: Access Token, Open ID - {cleaned_data['open_id']}")

    def delete_token(self, token_id: int) -> None:
        """Delete the token row with id *token_id*, logging whether it existed."""
        with Session(self.engine) as session:
            token = session.get(UserOAuthToken, token_id)
            if token is not None:
                session.delete(token)
                session.commit()
                logger.info(f"Record deleted: ID - {token_id}")
            else:
                logger.info(f"Record with ID {token_id} not found")

    def display_all_records(self):
        """Return every ``UserOAuthToken`` row as a list."""
        with Session(self.engine) as session:
            statement = select(UserOAuthToken)
            return session.exec(statement).all()


def main():
    """Manual smoke test: upsert a sample token payload and log all token rows."""
    db_manager = UserOAuthRepository()
    # Sample payload shaped like the OAuth token response this module stores.
    # Tokens and open_id are placeholders: never commit real credentials.
    data = {
        'access_token': 'act.EXAMPLE_ACCESS_TOKEN',
        'captcha': '',
        'desc_url': '',
        'description': '',
        'error_code': 0,
        'expires_in': 1296000,
        'log_id': '20240129123749239735B0529965BC6D93',
        'open_id': 'EXAMPLE_OPEN_ID',
        'refresh_expires_in': 2592000,
        'refresh_token': 'rft.EXAMPLE_REFRESH_TOKEN',
        'scope': 'user_info,trial.whitelist',
    }
    db_manager.add_token(data)
    res = db_manager.display_all_records()
    logger.debug(res)


if __name__ == "__main__":
    main()