login.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import asyncio
  2. import datetime
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  6. import jwt
  7. from fastapi import FastAPI,APIRouter, HTTPException, Depends, Request,Header
  8. from fastapi import Depends, FastAPI, HTTPException, status
  9. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  10. from pydantic import BaseModel
  11. from fastapi.responses import JSONResponse
  12. from config import *
  13. from douyin.access_token import get_access_token
  14. from douyin.user_info import get_user_info
  15. from db.user_oauth import UserOAuthRepository,UserOAuthToken
  16. from db.user_info import UserInfoRepository,UserInfo
  17. from db.user import User,UserRepo
  18. from db.base import update_from_model
  19. from api.swl_jwt import verify_jwt_token,verify_user
  20. from sqlmodel import Session,select
  21. from db.engine import engine,create_all_table
# Router holding the login/OAuth endpoints defined in this module; presumably
# included into the FastAPI application elsewhere.
login_router = APIRouter()


# Request body for POST /login (also built by /verify_callback): the result of a
# Douyin scan login.
class ScanCode(BaseModel):
    # Authorization code, e.g. "676a1101ea02bc5dTaUVtKg8c5enYaGqB4dT".
    # It can be used only once and becomes invalid after use.
    code: str
    # Scopes the user authorized, comma-separated, e.g. "user_info,trial.whitelist".
    scopes: str
  28. async def save_login_data(data:dict):
  29. access_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("expires_in"))
  30. refresh_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("refresh_expires_in"))
  31. oauth_model:UserOAuthToken = UserOAuthRepository().dict_to_model(data)
  32. oauth_model.expires_at = access_token_expires_in
  33. oauth_model.refresh_expires_at = refresh_token_expires_in
  34. user_info_data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
  35. if not user_info_data.get("error_code"):
  36. info_model = UserInfoRepository().dict_to_model(user_info_data)
  37. else:
  38. info_model = UserInfo()
  39. with Session(engine) as session:
  40. user = session.exec(
  41. select(User).where(User.open_id == oauth_model.open_id)
  42. ).first()
  43. if not user:
  44. user = User(open_id=oauth_model.open_id, oauth=oauth_model, info=info_model)
  45. else:
  46. user.open_id = oauth_model.open_id
  47. update_from_model(user.oauth, oauth_model)
  48. # user.info = info_model
  49. update_from_model(user.info, info_model)
  50. logger.info(f"update: {user.oauth}")
  51. session.add(user)
  52. session.commit()
  53. # 计算过期时间戳(基于北京时间)
  54. # expires_in = data.get("refresh_expires_in", 1296000)
  55. # # expires_in = 15
  56. # expiration_time_local = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)
  57. # exp = int(expiration_time_local.timestamp())
  58. # db_manager = UserOAuthRepository()
  59. # oauth_model:UserOAuthToken = db_manager.save_login_data(data)
  60. # data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
  61. # if data.get("error_code") != 0:
  62. # raise HTTPException(status_code=400, detail=data)
  63. # db_user = UserInfoRepository()
  64. # user_info_model = db_user.dict_to_model(data)
  65. # db_user.add_or_update(user_info_model)
  66. # 生成并返回 token,包含过期时间
  67. expiration_time_local = datetime.datetime.now() + datetime.timedelta(days=90)
  68. exp = int(expiration_time_local.timestamp())
  69. payload = {
  70. "sub": data["open_id"],
  71. "exp": exp
  72. }
  73. account_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
  74. logger.info(f"login success, expires_time:{datetime.datetime.fromtimestamp(exp).strftime('%Y-%m-%d %H:%M:%S') }, token:{account_token}")
  75. return {"token": account_token}
  76. # 登录端点
  77. @login_router.post("/login")
  78. async def login(data: ScanCode):
  79. logger.info(data)
  80. data = await get_access_token(data.code)
  81. if data.get("error_code") != 0:
  82. raise HTTPException(status_code=400, detail=data)
  83. return await save_login_data(data)
  84. @login_router.get("/user_info")
  85. async def user_info(user: User = Depends(verify_user)):
  86. return await get_user_info(user.open_id, user.oauth.access_token)
  87. @login_router.get("/verify_callback")
  88. async def verify_callback(code:str, scopes:str):
  89. return await login(ScanCode(code=code, scopes=scopes))
  90. @login_router.get("/token")
  91. async def read_account(open_id: str = Depends(verify_jwt_token)):
  92. pass
  93. # 启动应用
  94. async def main():
  95. create_all_table()
  96. data = {
  97. "access_token": "1act.f7094fbffab2ecbfc45e9af9c32bc241oYdckvBKe82BPx8T******",
  98. "captcha": "",
  99. "desc_url": "",
  100. "description": "",
  101. "error_code": 0,
  102. "expires_in": 1296000,
  103. "log_id": "20230525105733ED3ED7AC56A******",
  104. "open_id": "b9b71865-7fea-44cc-123",
  105. "refresh_expires_in": 2592000,
  106. "refresh_token": "rft.713900b74edde9f30ec4e246b706da30t******",
  107. "scope": "user_info"
  108. }
  109. user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
  110. res = await save_login_data(user_oauth)
  111. logger.info(f"{res}")
  112. # import jwt
  113. if __name__ == "__main__":
  114. asyncio.run(main())