login.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import asyncio
  2. import datetime
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  6. import jwt
  7. from fastapi import FastAPI,APIRouter, HTTPException, Depends, Request,Header
  8. from fastapi import Depends, FastAPI, HTTPException, status
  9. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  10. from fastapi.responses import JSONResponse
  11. from config import *
  12. from douyin.access_token import get_access_token
  13. from douyin.user_info import get_user_info
  14. from db.user_oauth import UserOAuthRepository,UserOAuthToken
  15. from db.user_info import UserInfoRepository,UserInfo
  16. from db.user import User,UserRepo
  17. from db.base import update_from_model
  18. from api.swl_jwt import verify_jwt_token,verify_user
  19. from sqlmodel import Session,select
  20. from db.engine import engine,create_all_table
  21. from pydantic import BaseModel
# Router that the endpoints defined below in this module register themselves on.
login_router = APIRouter()
# code=676a1101ea02bc5dTaUVtKg8c5enYaGqB4dT  can only be used once; it is invalid after use
# scopes=user_info,trial.whitelist  the scopes the user authorized
class ScanCode(BaseModel):
    # Payload accepted by the login endpoint (and built by verify_callback
    # from its query parameters).
    code: str  # one-time authorization code, passed to get_access_token
    scopes: str  # authorized scopes as a single string, e.g. "user_info,trial.whitelist"
  28. async def save_login_data(data:dict):
  29. oauth_model:UserOAuthToken = UserOAuthRepository().dict_to_model(data)
  30. user_info_data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
  31. if not user_info_data.get("error_code"):
  32. info_model = UserInfoRepository().dict_to_model(user_info_data)
  33. else:
  34. info_model = UserInfo()
  35. logger.debug(f"get oauth: {data}")
  36. logger.debug(f"get info: {user_info_data}")
  37. UserRepo().add_or_update(oauth_model, info_model)
  38. # 计算过期时间戳(基于北京时间)
  39. # expires_in = data.get("refresh_expires_in", 1296000)
  40. # # expires_in = 15
  41. # expiration_time_local = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)
  42. # exp = int(expiration_time_local.timestamp())
  43. # db_manager = UserOAuthRepository()
  44. # oauth_model:UserOAuthToken = db_manager.save_login_data(data)
  45. # data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
  46. # if data.get("error_code") != 0:
  47. # raise HTTPException(status_code=400, detail=data)
  48. # db_user = UserInfoRepository()
  49. # user_info_model = db_user.dict_to_model(data)
  50. # db_user.add_or_update(user_info_model)
  51. # 生成并返回 token,包含过期时间
  52. expiration_time_local = datetime.datetime.now() + datetime.timedelta(days=90)
  53. exp = int(expiration_time_local.timestamp())
  54. payload = {
  55. "sub": data["open_id"],
  56. "exp": exp
  57. }
  58. account_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
  59. logger.info(f"login success, expires_time:{datetime.datetime.fromtimestamp(exp).strftime('%Y-%m-%d %H:%M:%S') }, token:{account_token}")
  60. return {"token": account_token}
  61. # 登录端点
  62. @login_router.post("/login")
  63. async def login(data: ScanCode):
  64. logger.info(data)
  65. data = await get_access_token(data.code)
  66. if data.get("error_code") != 0:
  67. raise HTTPException(status_code=400, detail=data)
  68. return await save_login_data(data)
  69. @login_router.get("/user_info")
  70. async def user_info(user: User = Depends(verify_user)) -> UserInfo:
  71. user_info_data = await get_user_info(user.open_id, user.oauth.access_token)
  72. if not user_info_data.get("error_code"):
  73. info_model:UserInfo = UserInfoRepository().dict_to_model(user_info_data)
  74. # 如果用户修改了昵称,同步 open-douyin 的用户数据
  75. if info_model.nickname != user.info.nickname:
  76. with Session(engine) as session:
  77. update_from_model(user.info, info_model)
  78. session.add(user)
  79. session.commit()
  80. return user.info
  81. @login_router.get("/verify_callback")
  82. async def verify_callback(code:str, scopes:str):
  83. return await login(ScanCode(code=code, scopes=scopes))
  84. @login_router.get("/token")
  85. async def read_account(open_id: str = Depends(verify_jwt_token)):
  86. pass
  87. # 启动应用
  88. async def main():
  89. create_all_table()
  90. data = {
  91. "access_token": "1act.f7094fbffab2ecbfc45e9af9c32bc241oYdckvBKe82BPx8T******",
  92. "captcha": "",
  93. "desc_url": "",
  94. "description": "",
  95. "error_code": 0,
  96. "expires_in": 1296000,
  97. "log_id": "20230525105733ED3ED7AC56A******",
  98. "open_id": "b9b71865-7fea-44cc-123",
  99. "refresh_expires_in": 2592000,
  100. "refresh_token": "rft.713900b74edde9f30ec4e246b706da30t******",
  101. "scope": "user_info"
  102. }
  103. user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
  104. res = await save_login_data(user_oauth)
  105. logger.info(f"{res}")
  106. # import jwt
# Run main() on an asyncio event loop only when this file is executed
# directly as a script, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())