|
|
@@ -10,18 +10,21 @@ from config import DB_URL,logger
|
|
|
# from db.common import engine
|
|
|
from sqlalchemy import UniqueConstraint, Index
|
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
+from db.base import BaseRepository,DouyinBaseRepository
|
|
|
+from db.engine import engine,create_all
|
|
|
|
|
|
-# 定义数据库模型
|
|
|
+# 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
|
|
|
+# 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
|
|
|
class UserOAuthToken(SQLModel, table=True):
|
|
|
id: Optional[int] = Field(default=None, primary_key=True)
|
|
|
access_token:str
|
|
|
expires_in: Optional[int] = None
|
|
|
- open_id:str
|
|
|
+ open_id:str = Field(index=True)
|
|
|
refresh_expires_in: Optional[int] = None
|
|
|
refresh_token:str
|
|
|
scope: str
|
|
|
- update_time: datetime = Field(default_factory=datetime.now) # 添加时间戳字段
|
|
|
- __table_args__ = (UniqueConstraint('open_id'),)
|
|
|
+ update_time: datetime = Field(default_factory=datetime.now) # 添加时间戳字段
|
|
|
+ __table_args__ = (UniqueConstraint('open_id'),)
|
|
|
|
|
|
class UserInfo(SQLModel, table=True):
|
|
|
id: Optional[int] = Field(default=None, primary_key=True)
|
|
|
@@ -30,121 +33,90 @@ class UserInfo(SQLModel, table=True):
|
|
|
client_key: str
|
|
|
e_account_role: str = Field(default="")
|
|
|
nickname: str
|
|
|
- open_id: str
|
|
|
+ # 外键约束有助于:级联操作、避免冗余、数据完整性
|
|
|
+ open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)
|
|
|
union_id: str
|
|
|
update_time: datetime = Field(default_factory=datetime.now)
|
|
|
__table_args__ = (UniqueConstraint('open_id'),)
|
|
|
|
|
|
-
|
|
|
-engine = create_engine(DB_URL) # 替换成你的 DB_URL
|
|
|
-SQLModel.metadata.create_all(engine)
|
|
|
|
|
|
-class UserInfoRepository:
|
|
|
+class UserInfoRepository(DouyinBaseRepository):
|
|
|
def __init__(self, engine=engine):
|
|
|
- self.engine = engine
|
|
|
-
|
|
|
- def create_user_info(self, user_info_data):
|
|
|
- # 剔除不需要的字段
|
|
|
- cleaned_data = {k: v for k, v in user_info_data.items() if k not in ["log_id", "error_code"]}
|
|
|
-
|
|
|
- # 添加或更新时间戳
|
|
|
- cleaned_data['update_time'] = func.now()
|
|
|
-
|
|
|
- with Session(self.engine) as session:
|
|
|
- # 使用 on_conflict_do_update 处理 open_id 的冲突
|
|
|
- insert_stmt = insert(UserInfo).values(**cleaned_data)
|
|
|
- update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
- constraint="open_id", # 使用 open_id 作为冲突约束
|
|
|
- set_={**{k: cleaned_data[k] for k in cleaned_data if k != "open_id"}, "update_time": func.now()} # 更新其他字段,包括时间戳
|
|
|
- )
|
|
|
- result = session.exec(update_stmt)
|
|
|
- session.commit()
|
|
|
-
|
|
|
- def get_user_info_by_open_id(self, open_id):
|
|
|
- with Session(self.engine) as session:
|
|
|
- statement = select(UserInfo).where(UserInfo.open_id == open_id)
|
|
|
- result = session.exec(statement)
|
|
|
- return result.first()
|
|
|
-
|
|
|
- def update_user_info(self, user_id, user_info_data):
|
|
|
- with Session(self.engine) as session:
|
|
|
- update_user_info = session.get(UserInfo, user_id)
|
|
|
+ super().__init__(UserInfo, engine)
|
|
|
+ self.model:UserInfo
|
|
|
+
|
|
|
+ async def create_user_info(self, user_info_data):
|
|
|
+ return await self.aadd_or_update(user_info_data)
|
|
|
+
|
|
|
+ async def update_user_info(self, user_id, user_info_data):
|
|
|
+ async with self.session_factory() as session:
|
|
|
+ update_user_info = await session.get(UserInfo, user_id)
|
|
|
if update_user_info:
|
|
|
for key, value in user_info_data.items():
|
|
|
setattr(update_user_info, key, value)
|
|
|
- session.commit()
|
|
|
+ await session.commit()
|
|
|
return update_user_info
|
|
|
|
|
|
- def delete_user_info(self, user_id):
|
|
|
- with Session(self.engine) as session:
|
|
|
- delete_user_info = session.get(UserInfo, user_id)
|
|
|
+ async def delete_user_info(self, user_id):
|
|
|
+ async with self.session_factory() as session:
|
|
|
+ delete_user_info = await session.get(UserInfo, user_id)
|
|
|
if delete_user_info:
|
|
|
- session.delete(delete_user_info)
|
|
|
- session.commit()
|
|
|
+ await session.delete(delete_user_info)
|
|
|
+ await session.commit()
|
|
|
|
|
|
# Database manager class
|
|
|
-class UserOAuthRepository:
|
|
|
- def __init__(self, engine=engine):
|
|
|
- self.engine = engine
|
|
|
+class UserOAuthRepository(DouyinBaseRepository):
|
|
|
+ def __init__(self, engine=engine):
|
|
|
+ super().__init__(UserOAuthToken, engine)
|
|
|
+ self.model:UserOAuthToken
|
|
|
|
|
|
- def add_token(self, data: dict):
|
|
|
- # 剔除不需要的字段
|
|
|
- cleaned_data = {
|
|
|
- k: v for k, v in data.items()
|
|
|
- if k not in ["log_id", "error_code", "captcha", "desc_url", "description"]
|
|
|
- }
|
|
|
-
|
|
|
- # 添加或更新时间戳
|
|
|
- cleaned_data['update_time'] = func.now()
|
|
|
-
|
|
|
- # 构造插入语句
|
|
|
- insert_stmt = insert(UserOAuthToken).values(**cleaned_data)
|
|
|
- update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
- index_elements=['open_id'], # 使用 open_id 作为冲突的目标列
|
|
|
- set_={
|
|
|
- **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},
|
|
|
- "update_time": func.now() # 更新时间戳
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- # 执行插入/更新操作
|
|
|
- with Session(self.engine) as session:
|
|
|
- result = session.exec(update_stmt) # 注意:这里应该是 execute 而不是 exec
|
|
|
- session.commit()
|
|
|
- logger.debug(f"Record added/updated: Access Token, Open ID - {cleaned_data['open_id']}")
|
|
|
+ async def add_token(self, data: dict):
|
|
|
+ return await self.aadd_or_update(data)
|
|
|
|
|
|
|
|
|
- def delete_token(self, token_id: int):
|
|
|
- with Session(self.engine) as session:
|
|
|
- token = session.get(UserOAuthToken, token_id)
|
|
|
- if token:
|
|
|
- session.delete(token)
|
|
|
- session.commit()
|
|
|
- print(f"Record deleted: ID - {token_id}")
|
|
|
- else:
|
|
|
- print(f"Record with ID {token_id} not found")
|
|
|
+ async def delete_token(self, token_id: int):
|
|
|
+ async with self.session_factory() as session:
|
|
|
+ statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)
|
|
|
+ token = await session.execute(statement).scalars().first()
|
|
|
+ if token:
|
|
|
+ await session.delete(token)
|
|
|
+ await session.commit()
|
|
|
+ print(f"Record deleted: ID - {token_id}")
|
|
|
+ else:
|
|
|
+ print(f"Record with ID {token_id} not found")
|
|
|
|
|
|
- def display_all_records(self):
|
|
|
- with Session(self.engine) as session:
|
|
|
- statement = select(UserOAuthToken)
|
|
|
- user_tokens = session.exec(statement).all()
|
|
|
- return user_tokens
|
|
|
|
|
|
- # 根据 open_id 获取模型中某一行
|
|
|
- def get_by_open_id(self, open_id):
|
|
|
- with Session(self.engine) as session:
|
|
|
- statement = select(UserOAuthToken).where(UserOAuthToken.open_id == open_id)
|
|
|
- result = session.exec(statement)
|
|
|
- return result.first()
|
|
|
|
|
|
-
|
|
|
-def main():
|
|
|
+async def test_add():
|
|
|
+ await create_all()
|
|
|
+
|
|
|
+ user_oauth = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
|
|
|
+ user_info = {
|
|
|
+ "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
|
|
|
+ "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
|
|
|
+ "captcha": "",
|
|
|
+ "city": "",
|
|
|
+ "client_key": "123",
|
|
|
+ "country": "",
|
|
|
+ "desc_url": "",
|
|
|
+ "description": "",
|
|
|
+ "district": "",
|
|
|
+ "e_account_role": "",
|
|
|
+ "error_code": 0,
|
|
|
+ "gender": 0,
|
|
|
+ "log_id": "202401261424326FE877A6CAB03910C553",
|
|
|
+ "nickname": "程序员马工",
|
|
|
+ "open_id": "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy",
|
|
|
+ "province": "",
|
|
|
+ "union_id": "123-01ae-59bd-978a-1de8566186a8"
|
|
|
+ }
|
|
|
db_manager = UserOAuthRepository()
|
|
|
- data = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
|
|
|
- # db_manager.add_token(data)
|
|
|
- # res = db_manager.display_all_records()
|
|
|
- res = db_manager.get_from_id("_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy", "access_token")
|
|
|
+ res = await db_manager.add_token(user_oauth)
|
|
|
+ db_user_info = UserInfoRepository()
|
|
|
+ res = await db_user_info.create_user_info(user_info)
|
|
|
logger.debug(res)
|
|
|
+ return user_oauth["open_id"]
|
|
|
|
|
|
-if __name__ == "__main__":
|
|
|
- main()
|
|
|
+if __name__ == "__main__":
|
|
|
+ import asyncio
|
|
|
+ asyncio.run(test_add())
|