| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- from datetime import datetime
- from typing import Optional
- import os
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(__file__)))
- from sqlmodel import Field, SQLModel,create_engine,Session,select,func
- import psycopg2
- from config import DB_URL,logger
- from douyin.access_token import get_access_token
- # from db.common import engine
- from sqlalchemy import UniqueConstraint, Index
- from sqlalchemy.dialects.postgresql import insert
# Database models
class UserOAuthToken(SQLModel, table=True):
    """OAuth token row; at most one row per ``open_id`` (see ``__table_args__``)."""

    id: Optional[int] = Field(default=None, primary_key=True)  # primary key; None until the row is inserted
    access_token:str
    expires_in: Optional[int] = None  # presumably a lifetime in seconds (sample payload uses 1296000) — confirm
    open_id:str
    refresh_expires_in: Optional[int] = None  # presumably seconds as well — confirm
    refresh_token:str
    scope: str
    update_time: datetime = Field(default_factory=datetime.now)  # timestamp field; Python-side default is naive local time
    # Unnamed unique constraint on open_id: the database chooses the constraint's name.
    __table_args__ = (UniqueConstraint('open_id'),)
class UserInfo(SQLModel, table=True):
    """User profile row; at most one row per ``open_id`` (see ``__table_args__``)."""

    id: Optional[int] = Field(default=None, primary_key=True)  # primary key; None until the row is inserted
    avatar: str
    avatar_larger: str
    client_key: str
    e_account_role: str = Field(default="")  # defaults to the empty string when not supplied
    nickname: str
    open_id: str
    union_id: str
    update_time: datetime = Field(default_factory=datetime.now)  # Python-side default is naive local time
    # Unnamed unique constraint on open_id: the database chooses the constraint's name.
    __table_args__ = (UniqueConstraint('open_id'),)
-
-
# Shared engine, built from config.DB_URL (set DB_URL to your own database).
# Importing this module creates any missing tables for the models above.
engine = create_engine(DB_URL)
SQLModel.metadata.create_all(bind=engine)
class UserInfoRepository:
    """Data-access helpers for the ``UserInfo`` table."""

    def __init__(self, engine=engine):
        # The default is the module-level engine, bound once when the class body runs.
        self.engine = engine

    def create_user_info(self, user_info_data):
        """Insert a ``UserInfo`` row, or update the existing row with the same ``open_id``.

        Args:
            user_info_data: Mapping of column name to value. The ``log_id`` and
                ``error_code`` keys are dropped; every remaining key must be a
                ``UserInfo`` column and ``open_id`` must be present.
        """
        # Drop response metadata that has no matching column.
        cleaned_data = {k: v for k, v in user_info_data.items() if k not in ["log_id", "error_code"]}
        # Let the database stamp the time.
        cleaned_data['update_time'] = func.now()

        insert_stmt = insert(UserInfo).values(**cleaned_data)
        # Target the open_id column rather than a constraint name. The
        # UniqueConstraint on open_id is unnamed, so no constraint literally
        # called "open_id" exists to reference.
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=['open_id'],
            set_={
                # EXCLUDED is the row that was proposed for insertion.
                **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
                "update_time": func.now(),
            },
        )

        with Session(self.engine) as session:
            session.exec(upsert_stmt)
            session.commit()

    def get_user_info_by_open_id(self, open_id):
        """Return the first ``UserInfo`` whose ``open_id`` matches, or ``None``."""
        with Session(self.engine) as session:
            statement = select(UserInfo).where(UserInfo.open_id == open_id)
            return session.exec(statement).first()

    def update_user_info(self, user_id, user_info_data):
        """Assign ``user_info_data`` attributes on the ``UserInfo`` row with id ``user_id``.

        Returns:
            The updated ``UserInfo``, with attributes loaded so it stays readable
            after the session closes, or ``None`` if no row has that id.
        """
        with Session(self.engine) as session:
            user_info = session.get(UserInfo, user_id)
            if user_info:
                for key, value in user_info_data.items():
                    setattr(user_info, key, value)
                session.commit()
                # commit() expires loaded attributes (expire_on_commit defaults to
                # True). Reload them now; otherwise reading them after the session
                # closes raises DetachedInstanceError.
                session.refresh(user_info)
            return user_info

    def delete_user_info(self, user_id):
        """Delete the ``UserInfo`` row with id ``user_id``; do nothing if it does not exist."""
        with Session(self.engine) as session:
            user_info = session.get(UserInfo, user_id)
            if user_info:
                session.delete(user_info)
                session.commit()
-
# Database manager class
class UserOAuthRepository:
    """Data-access helpers for the ``UserOAuthToken`` table."""

    def __init__(self, engine=engine):
        # The default is the module-level engine, bound once when the class body runs.
        self.engine = engine

    def add_token(self, data: dict):
        """Insert an OAuth token row, or update the existing row with the same ``open_id``.

        Args:
            data: Token payload. The ``log_id``, ``error_code``, ``captcha``,
                ``desc_url`` and ``description`` keys are dropped. The remaining
                keys must be ``UserOAuthToken`` columns and must include ``open_id``.
        """
        # Drop response metadata that has no matching column.
        cleaned_data = {
            k: v for k, v in data.items()
            if k not in ["log_id", "error_code", "captcha", "desc_url", "description"]
        }
        # Let the database stamp the time.
        cleaned_data['update_time'] = func.now()

        insert_stmt = insert(UserOAuthToken).values(**cleaned_data)
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=['open_id'],  # conflict target: the unique open_id column
            set_={
                # EXCLUDED is the row that was proposed for insertion.
                **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
                "update_time": func.now(),
            },
        )

        with Session(self.engine) as session:
            session.exec(upsert_stmt)
            session.commit()
        logger.debug("Record added/updated: Access Token, Open ID - %s", cleaned_data['open_id'])

    def delete_token(self, token_id: int):
        """Delete the token row with id ``token_id`` and log whether it existed."""
        with Session(self.engine) as session:
            token = session.get(UserOAuthToken, token_id)
            if token:
                session.delete(token)
                session.commit()
                logger.info("Record deleted: ID - %s", token_id)
            else:
                logger.warning("Record with ID %s not found", token_id)

    def display_all_records(self):
        """Return every ``UserOAuthToken`` row as a list."""
        with Session(self.engine) as session:
            return session.exec(select(UserOAuthToken)).all()
def main():
    """Manual smoke test: upsert one sample token, then log every stored token."""
    db_manager = UserOAuthRepository()
    # Sample payload shaped like the token endpoint's response. The values are
    # placeholders on purpose: never commit real access/refresh tokens or user
    # identifiers to source control.
    data = {
        'access_token': 'act.EXAMPLE_ACCESS_TOKEN',
        'captcha': '',
        'desc_url': '',
        'description': '',
        'error_code': 0,
        'expires_in': 1296000,
        'log_id': 'EXAMPLE_LOG_ID',
        'open_id': 'EXAMPLE_OPEN_ID',
        'refresh_expires_in': 2592000,
        'refresh_token': 'rft.EXAMPLE_REFRESH_TOKEN',
        'scope': 'user_info,trial.whitelist',
    }
    db_manager.add_token(data)
    res = db_manager.display_all_records()
    logger.debug(res)


if __name__ == "__main__":
    main()
|