|
|
@@ -1,121 +1,65 @@
|
|
|
import datetime
|
|
|
+from typing import List
|
|
|
from typing import Optional
|
|
|
import os
|
|
|
import sys
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
|
|
-from db.engine import engine
|
|
|
+from db.engine import engine,create_all_table
|
|
|
|
|
|
-from sqlmodel import Field, SQLModel,create_engine,Session,select,func,Column
|
|
|
+from sqlmodel import Field, SQLModel,Relationship,create_engine,Session,select,func,Column
|
|
|
import psycopg2
|
|
|
from config import DB_URL,logger
|
|
|
# from db.common import engine
|
|
|
from sqlalchemy import UniqueConstraint, Index, asc
|
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
from db.base import DouyinBaseRepository
|
|
|
-
|
|
|
-# 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
|
|
|
-# 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
|
|
|
-class UserOAuthToken(SQLModel, table=True):
|
|
|
- id: Optional[int] = Field(primary_key=True)
|
|
|
- access_token:Optional[str] = Field(nullable=True)
|
|
|
- expires_in: Optional[int] = Field(nullable=True)
|
|
|
- expires_at: Optional[datetime.datetime] = Field(nullable=True)
|
|
|
- open_id:str = Field(index=True, unique=True)
|
|
|
- refresh_token:Optional[str] = Field(nullable=True)
|
|
|
- refresh_expires_in: Optional[int] = Field(nullable=True)
|
|
|
- refresh_expires_at: Optional[datetime.datetime]
|
|
|
- scope: Optional[str] = Field(nullable=True)
|
|
|
- update_time: datetime.datetime = Field(default_factory=datetime.datetime.now,nullable=True) # 添加时间戳字段
|
|
|
-
|
|
|
-class UserInfo(SQLModel, table=True):
|
|
|
- id: Optional[int] = Field(primary_key=True)
|
|
|
- avatar: str
|
|
|
- avatar_larger: str
|
|
|
- client_key: str
|
|
|
- e_account_role: str = Field(default="")
|
|
|
- nickname: str
|
|
|
- # 外键约束有助于:级联操作、避免冗余、数据完整性
|
|
|
- open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)
|
|
|
- union_id: str
|
|
|
- update_time: datetime.datetime = Field(default_factory=datetime.datetime.now) # 添加时间戳字段
|
|
|
- __table_args__ = (UniqueConstraint('open_id'),)
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-class UserInfoRepository(DouyinBaseRepository):
|
|
|
- def __init__(self, engine=engine):
|
|
|
- super().__init__(UserInfo, engine)
|
|
|
- self.model:UserInfo
|
|
|
-
|
|
|
+from db.user_info import UserInfo,UserInfoRepository,UserInfoLink
|
|
|
+from db.user_oauth import UserOAuthToken,UserOAuthTokenLink,UserOAuthRepository
|
|
|
+from db.video_data import VideoUserLink,VideoData
|
|
|
+from db.docs import Documents,DocumentsLink
|
|
|
+# 避免循环导入问题,参考 SQLmodel 官方: https://sqlmodel.tiangolo.com/tutorial/code-structure/
|
|
|
+# from typing import TYPE_CHECKING, Optional
|
|
|
+# if TYPE_CHECKING:
|
|
|
+# from .user_oauth import UserOAuthToken
|
|
|
+# from .user_info import UserInfo,LinkUserInfo,UserInfoRepository
|
|
|
+# from .docs import Documents
|
|
|
|
|
|
+class User(SQLModel, table=True):
|
|
|
+ id: Optional[int] = Field(default=None, primary_key=True)
|
|
|
+ open_id:str = Field(index=True, unique=True)
|
|
|
+ info:Optional[UserInfo] = Relationship(back_populates="user", link_model=UserInfoLink)
|
|
|
+ oauth: Optional[UserOAuthToken] = Relationship(back_populates="user", link_model=UserOAuthTokenLink)
|
|
|
+ docs:List[Documents] = Relationship(back_populates="user", link_model=DocumentsLink)
|
|
|
+ video_data:List[VideoData] = Relationship(back_populates="user", link_model=VideoUserLink)
|
|
|
|
|
|
-# Database manager class
|
|
|
-class UserOAuthRepository(DouyinBaseRepository):
|
|
|
- def __init__(self, engine=engine):
|
|
|
- super().__init__(UserOAuthToken, engine)
|
|
|
- self.model:UserOAuthToken
|
|
|
-
|
|
|
- def save_login_data(self, data: dict):
|
|
|
- access_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("expires_in"))
|
|
|
- refresh_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("refresh_expires_in"))
|
|
|
-
|
|
|
- model:UserOAuthToken = self.dict_to_model(data)
|
|
|
- model.expires_at = access_token_expires_in
|
|
|
- model.refresh_expires_at = refresh_token_expires_in
|
|
|
- return super().add_or_update(model)
|
|
|
-
|
|
|
- # field = UserOAuthToken.expires_at | UserOAuthToken.refresh_expires_at
|
|
|
- def select_nearest_expire(self, field: Column) -> UserOAuthToken:
|
|
|
- with Session(bind=self.engine) as session:
|
|
|
- statement = select(self.model).where(field > datetime.datetime.now()).order_by(asc(field)).limit(1)
|
|
|
- res = session.exec(statement)
|
|
|
- return res.first()
|
|
|
+class UserRepo(DouyinBaseRepository):
|
|
|
+ def __init__(self, model: SQLModel=User):
|
|
|
+ super().__init__(model, engine)
|
|
|
+ self.model:User
|
|
|
|
|
|
- def update_refresh_token(self, open_id:str, refresh_token:str, refresh_expires_in:int):
|
|
|
- refresh_token_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=refresh_expires_in)
|
|
|
- model = UserOAuthToken(open_id=open_id, refresh_expires_in=refresh_expires_in, refresh_token=refresh_token,refresh_expires_at=refresh_token_expires_at)
|
|
|
- return self.add_or_update(model)
|
|
|
-
|
|
|
- def update_access_token(self, open_id:str, access_token:str, expires_in:int):
|
|
|
- access_token_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)
|
|
|
- model = UserOAuthToken(open_id=open_id, access_token=access_token, expires_in=expires_in,expires_at=access_token_expires_at)
|
|
|
- return self.add_or_update(model)
|
|
|
- # def add_token(self, data: dict):
|
|
|
- # clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
|
|
|
- # obj_model = self.model(**clean_data)
|
|
|
- # with Session(bind=self.engine) as session:
|
|
|
- # exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
|
|
|
- # if exist_obj:
|
|
|
- # self.set_update_time(exist_obj)
|
|
|
- # dict_data = self.model_dump_by_field(obj_model, self.unique_constraint_fields)
|
|
|
- # self.set_obj_by_dict(exist_obj,dict_data)
|
|
|
- # session.commit()
|
|
|
- # return exist_obj
|
|
|
- # else:
|
|
|
- # self.create(obj_model)
|
|
|
- # session.commit()
|
|
|
- # return obj_model
|
|
|
+ def delete(self, open_id):
|
|
|
+ with Session(engine) as session:
|
|
|
+ user = session.exec(
|
|
|
+ select(User).where(User.open_id == open_id)
|
|
|
+ ).one()
|
|
|
+ session.delete(user)
|
|
|
+ logger.info(f"del {user}")
|
|
|
+ session.commit()
|
|
|
|
|
|
+ def add(self, open_id, info, oauth):
|
|
|
+ with Session(engine) as session:
|
|
|
+ user = User(open_id=info.open_id, info=info, oauth=oauth2)
|
|
|
+ session.add(user)
|
|
|
+ logger.info(f"{user}")
|
|
|
+ session.commit()
|
|
|
+ session.refresh(user)
|
|
|
+ logger.info(f"{user.info}")
|
|
|
|
|
|
- async def delete_token(self, token_id: int):
|
|
|
- async with self.session_factory() as session:
|
|
|
- statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)
|
|
|
- token = await session.execute(statement).scalars().first()
|
|
|
- if token:
|
|
|
- await session.delete(token)
|
|
|
- await session.commit()
|
|
|
- print(f"Record deleted: ID - {token_id}")
|
|
|
- else:
|
|
|
- print(f"Record with ID {token_id} not found")
|
|
|
|
|
|
|
|
|
-
|
|
|
-def test_add(open_id=None):
|
|
|
- SQLModel.metadata.create_all(engine)
|
|
|
-
|
|
|
- user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
|
|
|
+def main():
|
|
|
+ create_all_table()
|
|
|
+ user_oauth = {'access_token': 'act.3.TIKYfvDL4GEk0_8ol5HNBcDefjYynPt904OdtIiOH8SKquga23fjE1kTkWKqB8oLnCcqRhjPAUMUWq2uECQv0Bhm6m8cgq3Np8EGjDBD-fHI_BQAvIM7EIieNvfec-l-VMRPdxuAI9Cx8ih_59NvEC-JEFfomG18Oj8ICQ==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202403130157319F6A7830924CBE383292', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist,item.comment'}
|
|
|
user_oauth2 = {'access_token': 'act', 'expires_in': 19290, 'open_id': '55test2', 'refresh_expires_in': 1950, 'refresh_token': 'rft', 'scope': 'user_info,trial.whitelist'}
|
|
|
user_info = {
|
|
|
"avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
|
|
|
@@ -136,21 +80,10 @@ def test_add(open_id=None):
|
|
|
"province": "",
|
|
|
"union_id": "123-01ae-59bd-978a-1de8566186a8"
|
|
|
}
|
|
|
- if open_id:
|
|
|
- user_oauth["open_id"] = open_id
|
|
|
- user_info["open_id"] = open_id
|
|
|
- user_info["nickname"] = "user" + open_id[:5]
|
|
|
- else:
|
|
|
- open_id = user_oauth["open_id"]
|
|
|
- db_manager = UserOAuthRepository()
|
|
|
- res = db_manager.save_login_data(user_oauth2)
|
|
|
- res = db_manager.select_nearest_expire(UserOAuthToken.expires_at)
|
|
|
-
|
|
|
- logger.debug(res)
|
|
|
- # db_user_info = UserInfoRepository()
|
|
|
- # res = db_user_info.add_or_update(user_info)
|
|
|
- # logger.debug(db_manager.get_by_open_id(open_id))
|
|
|
- return user_oauth["open_id"]
|
|
|
-
|
|
|
+ db_user = UserRepo()
|
|
|
+ info:UserInfo = db_user.dict_to_model(user_info, UserInfo)
|
|
|
+ oauth:UserOAuthToken = UserOAuthRepository().dict_to_model(user_oauth, UserOAuthToken)
|
|
|
+ oauth2 = UserOAuthRepository().dict_to_model(user_oauth2, UserOAuthToken)
|
|
|
+ db_user.delete("b9b71865-7fea-44cc-123")
|
|
|
if __name__ == "__main__":
|
|
|
- test_add()
|
|
|
+ main()
|