|
|
@@ -1,8 +1,16 @@
|
|
|
+from datetime import datetime
|
|
|
from typing import Optional
|
|
|
+import os
|
|
|
+import sys
|
|
|
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
|
|
|
|
|
-from sqlmodel import Field, SQLModel,create_engine,Session
|
|
|
-
|
|
|
-
|
|
|
+from sqlmodel import Field, SQLModel,create_engine,Session,select,func
|
|
|
+import psycopg2
|
|
|
+from config import DB_URL,logger
|
|
|
+from douyin.access_token import get_access_token
|
|
|
+# from db.common import engine
|
|
|
+from sqlalchemy import UniqueConstraint, Index
|
|
|
+from sqlalchemy.dialects.postgresql import insert
|
|
|
|
|
|
# 定义数据库模型
|
|
|
class UserOAuthToken(SQLModel, table=True):
|
|
|
@@ -12,15 +20,124 @@ class UserOAuthToken(SQLModel, table=True):
|
|
|
open_id:str
|
|
|
refresh_expires_in: Optional[int] = None
|
|
|
refresh_token:str
|
|
|
+ scope: str
|
|
|
+ update_time: datetime = Field(default_factory=datetime.now) # 添加时间戳字段
|
|
|
+ __table_args__ = (UniqueConstraint('open_id'),)
|
|
|
+
|
|
|
+class UserInfo(SQLModel, table=True):
|
|
|
+ id: Optional[int] = Field(default=None, primary_key=True)
|
|
|
+ avatar: str
|
|
|
+ avatar_larger: str
|
|
|
+ client_key: str
|
|
|
+ e_account_role: str = Field(default="")
|
|
|
+ nickname: str
|
|
|
+ open_id: str
|
|
|
+ union_id: str
|
|
|
+ update_time: datetime = Field(default_factory=datetime.now)
|
|
|
+ __table_args__ = (UniqueConstraint('open_id'),)
|
|
|
+
|
|
|
+
|
|
|
+engine = create_engine(DB_URL) # 替换成你的 DB_URL
|
|
|
+SQLModel.metadata.create_all(engine)
|
|
|
+
|
|
|
+class UserInfoRepository:
|
|
|
+ def __init__(self, engine=engine):
|
|
|
+ self.engine = engine
|
|
|
+
|
|
|
+ def create_user_info(self, user_info_data):
|
|
|
+ # 剔除不需要的字段
|
|
|
+ cleaned_data = {k: v for k, v in user_info_data.items() if k not in ["log_id", "error_code"]}
|
|
|
+
|
|
|
+ # 添加或更新时间戳
|
|
|
+ cleaned_data['update_time'] = func.now()
|
|
|
+
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ # 使用 on_conflict_do_update 处理 open_id 的冲突
|
|
|
+ insert_stmt = insert(UserInfo).values(**cleaned_data)
|
|
|
+ update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
+ constraint="open_id", # 使用 open_id 作为冲突约束
|
|
|
+ set_={**{k: cleaned_data[k] for k in cleaned_data if k != "open_id"}, "update_time": func.now()} # 更新其他字段,包括时间戳
|
|
|
+ )
|
|
|
+ result = session.exec(update_stmt)
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def get_user_info_by_open_id(self, open_id):
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ statement = select(UserInfo).where(UserInfo.open_id == open_id)
|
|
|
+ result = session.exec(statement)
|
|
|
+ return result.first()
|
|
|
+
|
|
|
+ def update_user_info(self, user_id, user_info_data):
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ update_user_info = session.get(UserInfo, user_id)
|
|
|
+ if update_user_info:
|
|
|
+ for key, value in user_info_data.items():
|
|
|
+ setattr(update_user_info, key, value)
|
|
|
+ session.commit()
|
|
|
+ return update_user_info
|
|
|
+
|
|
|
+ def delete_user_info(self, user_id):
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ delete_user_info = session.get(UserInfo, user_id)
|
|
|
+ if delete_user_info:
|
|
|
+ session.delete(delete_user_info)
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+# Database manager class
|
|
|
+class UserOAuthRepository:
|
|
|
+ def __init__(self, engine=engine):
|
|
|
+ self.engine = engine
|
|
|
+
|
|
|
+ def add_token(self, data: dict):
|
|
|
+ # 剔除不需要的字段
|
|
|
+ cleaned_data = {
|
|
|
+ k: v for k, v in data.items()
|
|
|
+ if k not in ["log_id", "error_code", "captcha", "desc_url", "description"]
|
|
|
+ }
|
|
|
+
|
|
|
+ # 添加或更新时间戳
|
|
|
+ cleaned_data['update_time'] = func.now()
|
|
|
+
|
|
|
+ # 构造插入语句
|
|
|
+ insert_stmt = insert(UserOAuthToken).values(**cleaned_data)
|
|
|
+ update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
+ index_elements=['open_id'], # 使用 open_id 作为冲突的目标列
|
|
|
+ set_={
|
|
|
+ **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
|
|
|
+ "update_time": func.now() # 更新时间戳
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # 执行插入/更新操作
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ result = session.exec(update_stmt) # 注意:这里应该是 execute 而不是 exec
|
|
|
+ session.commit()
|
|
|
+ logger.debug(f"Record added/updated: Access Token, Open ID - {cleaned_data['open_id']}")
|
|
|
+
|
|
|
|
|
|
+ def delete_token(self, token_id: int):
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ token = session.get(UserOAuthToken, token_id)
|
|
|
+ if token:
|
|
|
+ session.delete(token)
|
|
|
+ session.commit()
|
|
|
+ print(f"Record deleted: ID - {token_id}")
|
|
|
+ else:
|
|
|
+ print(f"Record with ID {token_id} not found")
|
|
|
|
|
|
-test=UserOAuthToken(access_token="dfgfaha",open_id="dasfga",refresh_token="fagsdgas");
|
|
|
-DATABASE_URL = "postgresql:///test.db"
|
|
|
-engine = create_engine(DATABASE_URL)
|
|
|
-SQLModel.metadata.create_all(engine)
|
|
|
+ def display_all_records(self):
|
|
|
+ with Session(self.engine) as session:
|
|
|
+ statement = select(UserOAuthToken)
|
|
|
+ user_tokens = session.exec(statement).all()
|
|
|
+ return user_tokens
|
|
|
|
|
|
|
|
|
-with Session(engine) as session:
|
|
|
- session.add(test)
|
|
|
- session.commit()
|
|
|
+def main():
|
|
|
+ db_manager = UserOAuthRepository()
|
|
|
+ data = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
|
|
|
+ db_manager.add_token(data)
|
|
|
+ res = db_manager.display_all_records()
|
|
|
+ logger.debug(res)
|
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|