qyl 1 рік тому
батько
коміт
5ca3531385
14 змінених файлів з 540 додано та 167 видалено
  1. 2 1
      .gitignore
  2. 33 6
      api/jwt.py
  3. 16 25
      api/login.py
  4. 64 0
      api/updload.py
  5. 24 15
      config.py
  6. 153 0
      db/base.py
  7. 0 6
      db/common.py
  8. 132 0
      db/docs.py
  9. 11 0
      db/engine.py
  10. 71 99
      db/user.py
  11. 5 1
      douyin/access_token.py
  12. 9 6
      main.py
  13. 16 8
      test/config.py
  14. 4 0
      待办事项.md

+ 2 - 1
.gitignore

@@ -1,3 +1,4 @@
 __pycache__
 demo-release
-log
+log
+test

+ 33 - 6
api/jwt.py

@@ -1,8 +1,27 @@
 from fastapi import Depends, HTTPException, status, Header, Security  
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials  
 import jwt  
-from config import JWT_SECRET_KEY
-  
+from config import SECRET_KEY
+
+security = HTTPBearer()
+async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """Resolve the caller from the HTTP Bearer credentials.
+
+    Args:
+        credentials: Bearer credentials extracted by the ``HTTPBearer`` scheme.
+
+    Returns:
+        The value returned by ``verify_jwt_token`` for the bearer token.
+
+    Raises:
+        HTTPException: 401 when no credentials are present, 403 when the
+            token cannot be verified.
+    """
+    if not credentials:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Unauthorized",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+    try:
+        # Pass the raw token explicitly. Called with no argument,
+        # verify_jwt_token would receive its Security(...) default object
+        # instead of the token string and every request would be rejected.
+        return await verify_jwt_token(credentials.credentials)
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="Invalid token",
+            headers={"WWW-Authenticate": "Bearer"},
+        ) from e
+        
 async def get_token_from_header(authorization: str = Header(None)):  
     if not authorization:  
         raise HTTPException(  
@@ -16,11 +35,12 @@ async def get_token_from_header(authorization: str = Header(None)):
             detail="Invalid authentication scheme",  
         )  
     return authorization.replace("Bearer ", "")  
-  
+
+
 async def verify_jwt_token(token: str = Security(get_token_from_header)):  
     try:  
-        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=["HS256"])  
-        return {"sub": payload.get("sub")}  
+        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])  
+        return payload.get("sub")
     except jwt.ExpiredSignatureError:  
         raise HTTPException(  
             status_code=status.HTTP_403_FORBIDDEN,  
@@ -31,4 +51,11 @@ async def verify_jwt_token(token: str = Security(get_token_from_header)):
             status_code=status.HTTP_403_FORBIDDEN,  
             detail="Invalid token",  
         )  
-  
+        
+from db.user import UserOAuthRepository,UserOAuthToken
+async def get_uer_oauth_and_verify(open_id: str = Depends(verify_jwt_token)):
+    """Load the stored OAuth credentials of the authenticated user.
+
+    Args:
+        open_id: Value produced by the ``verify_jwt_token`` dependency.
+
+    Returns:
+        The ``UserOAuthToken`` row found for ``open_id``.
+
+    Raises:
+        HTTPException: 401 ("need login") when no row exists, i.e. the user
+            has to scan the QR code and log in again.
+    """
+    # get_by_open_id is a coroutine function on the async repository; without
+    # ``await`` the result is a coroutine object, which is always truthy, so
+    # the "need login" branch below could never run.
+    db_oauth: UserOAuthToken = await UserOAuthRepository().get_by_open_id(open_id)
+    if not db_oauth:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="need login")
+    return db_oauth

+ 16 - 25
api/login.py

@@ -11,11 +11,13 @@ from fastapi.responses import JSONResponse
 from config import *
 from douyin.access_token import get_access_token
 from douyin.user_info import get_user_info
-from db.user import UserOAuthRepository,UserInfoRepository
-from api.jwt import verify_jwt_token
+from db.user import UserOAuthRepository,UserInfoRepository,UserOAuthToken
+from api.jwt import verify_jwt_token,get_uer_oauth_and_verify
 
 login_router = APIRouter()  
 
+# code=676a1101ea02bc5dTaUVtKg8c5enYaGqB4dT  只能被使用一次,用完失效
+# scopes=user_info,trial.whitelist 用户授权的范围
 class ScanCode(BaseModel):
     code: str
     scopes: str
@@ -28,17 +30,12 @@ class User(BaseModel):
 # 登录端点
 @login_router.post("/login")
 async def login(data: ScanCode):
-    if PRODUCE_ENV:
-        data = await get_access_token(data.code)
-    else:
-        # 测试环境使用。因为每次 get_access_token 的 code 只能使用一次就过期了,为了避免频繁扫码,直接模拟返回请求结果
-        data = {'access_token': 'act.3.UCzqnMwbL7uUTH0PkWbvDvIHcpy417HnfMqymbvBSpo9b1MJ3jOdwCxw-UPstOOjsGDWIdNwTGev4oEp8eUR-vHbU24XU5K4BkhPeOKJW1CLrEUS3XFxpG6SHqoQtvL6qhEgINcvt4V3KQX6C2qTeTkgQ-KwPO6jWi5uoin3YXo5DqwuGk3bbQ9dZoY=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '2024012915260549B5ED1A675515CD573C', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
-    
+    data = await get_access_token(data.code)
     if data.get("error_code") != 0:
-        return data
+        return data, status.HTTP_400_BAD_REQUEST
     
     # 计算过期时间戳(基于北京时间)  
-    expires_in = data.get("expires_in", 0)  # 如果没有 expires_in 键,则默认过期时间为 0  
+    expires_in = data.get("expires_in", 1296000)
     # expires_in = 15
     expiration_time_utc = datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in)  
     beijing_timezone_delta = datetime.timedelta(hours=8)  # 北京时间是UTC+8  
@@ -50,32 +47,26 @@ async def login(data: ScanCode):
         "sub": data["open_id"],
         "exp": exp  # 添加过期时间戳(北京时间)到 payload  
     }  
-    account_token = jwt.encode(payload, JWT_SECRET_KEY, algorithm="HS256")  
+    account_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")  
     logger.info(f"login success, expires_time:{datetime.datetime.fromtimestamp(exp).strftime('%Y-%m-%d %H:%M:%S') }, token:{account_token}")
     return {"token": account_token}
 
 @login_router.get("/user_info")
-async def user_info(jwt: dict = Depends(verify_jwt_token)):
-    open_id = jwt.get("sub")
-    logger.info(open_id)
-    oauth = UserOAuthRepository().get_by_open_id(open_id)
-    # 没有用户凭证,需要重新登陆
-    if not oauth:
-        return {"error": "need login"}
-    info = await get_user_info(open_id, oauth.access_token)
+async def user_info(db_user_oauth: UserOAuthToken = Depends(get_uer_oauth_and_verify)):
+    info = await get_user_info(db_user_oauth.open_id, db_user_oauth.access_token)
     return info
-    
+
 # 受保护资源示例
 @login_router.get("/account")
-async def read_account(jwt: dict = Depends(verify_jwt_token)): 
-    open_id = jwt.get("aud")
+async def read_account(open_id: str = Depends(verify_jwt_token)): 
     UserOAuthRepository().display_all_records()
-    logger.info(jwt.get("aud"))
-    return {"message": "Account information", "open_id": jwt.get("aud")}
+    return {"message": "Account information", "open_id": open_id}
     # 在这里返回当前用户的信息
     return {"nickname": current_user.username, "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038"}
 
-# 其他受保护的资源...
+@login_router.get("/token")
+async def read_account(open_id: str = Depends(verify_jwt_token)): 
+    pass
 
 # 启动应用
 def main():

+ 64 - 0
api/updload.py

@@ -0,0 +1,64 @@
+import hashlib
+import os
+import pathlib
+import re
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+import jwt
+from fastapi import FastAPI,APIRouter,  File, HTTPException, Depends, Request,Header, UploadFile  
+from fastapi.responses import JSONResponse 
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials  
+import aiofiles
+from starlette.status import HTTP_401_UNAUTHORIZED 
+from api.jwt import verify_jwt_token,get_current_user
+from config import *
+from db.docs import DocumentsRepository
+
+upload_router = APIRouter()  
+security = HTTPBearer()
+@upload_router.post('/upload')
+async def upload(open_id=Depends(verify_jwt_token),file: UploadFile = File(...)  ):
+    """Store an uploaded file in the authenticated user's directory.
+
+    Args:
+        open_id: Identity produced by the ``verify_jwt_token`` dependency.
+        file: The multipart upload.
+
+    Returns:
+        ``{"message": "upload success"}`` when the file was saved.
+
+    Raises:
+        HTTPException: 400 ("upload fail") when saving reports failure.
+    """
+    res = await save_to_user_dir(open_id, file)
+    if not res:
+        # Saving failed: answer 400 Bad Request.
+        raise HTTPException(status_code=400, detail="upload fail")
+
+    # DocumentsRepository requires the owner and the stored file path; calling
+    # it with no arguments raised TypeError after every successful save.
+    file_path = os.path.join(
+        get_user_dir(open_id), "docs", os.path.basename(file.filename or "")
+    )
+    # NOTE(review): this only constructs the repository, as the original did.
+    # Nothing is written to the database until one of its async methods is
+    # awaited - confirm whether the document should be persisted here.
+    DocumentsRepository(open_id, file_path)
+    return {"message": "upload success"}
+
+def is_valid_filename(s):
+    """Return True when ``s`` consists only of allowed file-name characters.
+
+    Allowed characters are ASCII letters, digits, underscore, dot, hyphen and
+    space; adjust the pattern if other characters must be accepted. The empty
+    string is rejected.
+    """
+    # fullmatch instead of match + "$": "$" also matches just before a
+    # trailing newline, so "name\n" used to be accepted.
+    return re.fullmatch(r"[a-zA-Z0-9_.\- ]+", s) is not None
+  
+def get_user_dir(open_id):
+    """Return the data directory for ``open_id``, creating it when missing.
+
+    When ``open_id`` is usable as a directory name it is used directly under
+    ``DATA_DIR``; otherwise the directory is ``hash8_`` followed by the first
+    8 hex characters of the MD5 digest of ``open_id``.
+    """
+    # "." and ".." pass the character whitelist but name the current/parent
+    # directory, so joining them would point outside the per-user folder.
+    if is_valid_filename(open_id) and open_id not in (".", ".."):
+        user_dir = os.path.join(DATA_DIR, open_id)
+    else:
+        # Only the first 8 hex characters of the digest are kept.
+        hex_dig = hashlib.md5(open_id.encode()).hexdigest()[:8]
+        user_dir = os.path.join(DATA_DIR, "hash8_" + hex_dig)
+
+    # exist_ok avoids the check-then-create race between concurrent requests.
+    os.makedirs(user_dir, exist_ok=True)
+    return user_dir
+
+    
+async def save_to_user_dir(open_id, file:UploadFile):
+    """Write an uploaded file into ``<user dir>/docs/``.
+
+    Args:
+        open_id: Owner of the file; selects the user directory.
+        file: The upload to persist.
+
+    Returns:
+        True when the file was written, False when the upload carries no
+        usable file name.
+    """
+    # The file name comes from the client: keep only its last path component
+    # so names such as "../../x" cannot escape the docs directory.
+    filename = os.path.basename((file.filename or "").replace("\\", "/"))
+    if filename in ("", ".", ".."):
+        return False
+
+    # get_user_dir only creates the user directory itself, so the "docs"
+    # sub-directory has to be created here before opening the target file.
+    docs_dir = os.path.join(get_user_dir(open_id), "docs")
+    os.makedirs(docs_dir, exist_ok=True)
+    file_path = os.path.join(docs_dir, filename)
+
+    async with aiofiles.open(file_path, "wb") as buffer:
+        # Copy in 8 KiB chunks rather than reading the whole upload at once.
+        chunk = await file.read(8192)
+        while chunk:
+            await buffer.write(chunk)
+            chunk = await file.read(8192)
+    logger.info(f"{open_id} save to {file_path}")
+    return True

+ 24 - 15
config.py

@@ -3,11 +3,23 @@ import os
 import socket
 import sys
 from loguru import logger
-WORK_DIR = os.path.dirname(__file__)
 # 是否为生产环境, None 则是调试环境(开发环境)
 PRODUCE_ENV = os.environ.get("PRODUCE_ENV", None)
-os.environ["JWT_SECRET_KEY"]="123"
-os.environ["DB_URL"]="postgresql://pg:pg@sv-v:5432/douyin"
+WORK_DIR = os.path.dirname(__file__)
+LOG_FILE = os.path.join(WORK_DIR,"log", "1.log")
+
+
+FORMAT = '<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{file}</cyan>:<cyan>{line}</cyan> :<cyan>{function}</cyan> - {message}'
+LOG_LEVEL = "DEBUG"
+logger.remove()
+# logger.add(sys.stderr, format='<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>')
+logger.add(sys.stderr, format=FORMAT)
+logger.add(LOG_FILE, format=FORMAT)
+logger.info(f"load config:{ __file__}")
+
+# openssl rand -hex 32
+os.environ["SECRET_KEY"]="34581f02dcdbab9dc176d6bb578fb15cc6b8e66159865c14e1fc81cd1d92c2a6"
+os.environ["DB_URL"]="postgresql+asyncpg://pg:pg@sv-v:5432/douyin"
 os.environ["CLIENT_KEY"] = 'aw6aipmfdtplwtyq'
 os.environ["CLIENT_SECRET"] = '53cf3dcd2663629e8a773ab59df0968b'
 DOUYIN_OPEN_API="https://open.douyin.com"
@@ -16,20 +28,17 @@ DOUYIN_OPEN_API="https://open.douyin.com"
 # HOST = socket.gethostbyname(socket.gethostname())
 # 这个网址 https://open-douyin.magong.site 对应这台服务器的 192.168.1.32:8600 端口,因为这台服务器没有公网ip,所以在本地计算机无法通过  http://192.168.1.32:8600/ 访问到 fastapi 接口,只能通过 https://open-douyin.magong.site/ 访问
 HOST = '::'
-PORT = int(os.environ.get("PORT")) if os.environ.get("PORT") else 8600
-JWT_SECRET_KEY = os.environ["JWT_SECRET_KEY"]
-
-DB_URL=os.environ["DB_URL"]
 
+PORT = 8601 if os.environ.get("USER")=="mrh" else 8600
+SECRET_KEY = os.environ.get("SECRET_KEY")
 
+DB_URL=os.environ.get("DB_URL")
+# 生产环境中绝对不能使用硬编码的 DATA 路径。否则后期负载均衡、数据扩容、迁移将会造成很大影响。
+if not PRODUCE_ENV:
+    DATA_DIR="/home/user/code/open-douyin/test/data"
+else:
+    DATA_DIR = os.environ["DATA_DIR"]
+logger.info(f"API URL:{HOST}:{PORT}")
 
 
-LOG_FILE = os.path.join(WORK_DIR,"log", "1.log")
-FORMAT = '<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{file}</cyan>:<cyan>{line}</cyan> :<cyan>{function}</cyan> - {message}'
-LOG_LEVEL = "DEBUG"
-logger.remove()
-# logger.add(sys.stderr, format='<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>')
-logger.add(sys.stderr, format=FORMAT)
-logger.add(LOG_FILE, format=FORMAT)
 
-logger.info("load config:", __file__)

+ 153 - 0
db/base.py

@@ -0,0 +1,153 @@
+from typing import List, Any
+from sqlmodel import SQLModel
+from sqlalchemy.orm import sessionmaker 
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.declarative import DeclarativeMeta  
+from sqlalchemy.sql.expression import select
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy import UniqueConstraint
+from sqlalchemy.sql import func
+from db.engine import engine
+from config import logger
+
+class BaseRepository:
+    """Async repository helpers for a single SQLModel table.
+
+    Every method opens its own ``AsyncSession`` from ``session_factory`` and
+    commits before returning when it writes.
+    """
+
+    def __init__(self, model: SQLModel, engine=engine):
+        """Bind the repository to ``model`` and build a session factory for ``engine``."""
+        self.model = model
+        self.engine = engine
+        self.session_factory = sessionmaker(
+            bind=engine, class_=AsyncSession, expire_on_commit=False
+        )
+
+    async def aadd(self, instances: List[SQLModel]):
+        """Insert one model instance or a list of them.
+
+        Returns:
+            The inserted instances, always as a list.
+        """
+        if not isinstance(instances, list):
+            instances = [instances]
+        async with self.session_factory() as session:
+            session.add_all(instances)
+            await session.commit()
+            return instances
+
+    async def aget(self, **kwargs):
+        """Fetch a single row.
+
+        Args:
+            **kwargs: Either model field names with the values to match
+                (``aget(open_id="...")``), or the arguments of
+                ``AsyncSession.get`` including ``ident`` for a primary-key
+                lookup (``aget(ident=1)``).
+
+        Returns:
+            The first matching instance, or None.
+        """
+        async with self.session_factory() as session:
+            if "ident" in kwargs:
+                # Primary-key lookup: the only call shape session.get accepts.
+                return await session.get(self.model, **kwargs)
+            # Field filters cannot be passed to session.get (it needs a
+            # primary-key identity), so build a SELECT instead.
+            result = await session.execute(select(self.model).filter_by(**kwargs))
+            return result.scalars().first()
+
+    async def aget_all(self):
+        """Return every row of the table as a list of model instances."""
+        async with self.session_factory() as session:
+            result = await session.execute(select(self.model))
+            return result.scalars().all()
+
+    @staticmethod
+    def _insert_values(instance: SQLModel) -> dict:
+        """Return the column values to INSERT for ``instance``.
+
+        An unset (None) ``id`` is left out so that no explicit NULL is sent
+        for the primary key and the database can generate it.
+        """
+        data = instance.model_dump()
+        if data.get("id") is None:
+            data.pop("id", None)
+        return data
+
+    async def _aexecute_upserts(self, instances, build_stmt):
+        """Run ``build_stmt(values)`` for each instance in one transaction.
+
+        The id returned by each statement is stored on the instance; when the
+        statement returns no row (conflict with DO NOTHING) the instance's
+        current id is kept instead of being overwritten with None.
+        """
+        batch = instances if isinstance(instances, list) else [instances]
+        async with self.session_factory() as session:
+            for instance in batch:
+                res = await session.execute(build_stmt(self._insert_values(instance)))
+                new_id = res.scalar()
+                if new_id is not None:
+                    instance.id = new_id
+            await session.commit()
+        return instances
+
+    async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
+        """INSERT the instance(s), skipping rows that hit the unique constraint.
+
+        Returns:
+            ``instances`` in the shape it was passed in.
+        """
+        # Columns of the model's unique constraint(s) are the conflict target.
+        index_elements = self._get_unique_constraint_fields()
+
+        def build_stmt(values):
+            return (
+                insert(self.model)
+                .values(**values)
+                .on_conflict_do_nothing(index_elements=index_elements)
+                .returning(self.model.id)
+            )
+
+        return await self._aexecute_upserts(instances, build_stmt)
+
+    async def aon_conflict_do_update(
+        self,
+        instances: SQLModel | List[SQLModel],
+        update_fields: List[str],  # Names of the fields to overwrite on conflict
+    ) -> SQLModel | List[SQLModel]:
+        """INSERT the instance(s); on a unique-constraint conflict update ``update_fields``.
+
+        Returns:
+            ``instances`` in the shape it was passed in.
+        """
+        index_elements = self._get_unique_constraint_fields()
+
+        def build_stmt(values):
+            return (
+                insert(self.model)
+                .values(**values)
+                .on_conflict_do_update(
+                    # A list of column names is a conflict *target*
+                    # (index_elements); ``constraint=`` expects a constraint
+                    # name or object.
+                    index_elements=index_elements,
+                    set_={k: values[k] for k in update_fields if k in values},
+                )
+                .returning(self.model.id)
+            )
+
+        return await self._aexecute_upserts(instances, build_stmt)
+
+    def _get_unique_constraint_fields(self) -> List[str]:
+        """Return the column names covered by the model's UniqueConstraints.
+
+        Columns of all unique constraints are merged into one list without
+        duplicates.
+        """
+        constraints = getattr(self.model.__table__, 'constraints', [])
+        names = [
+            column.name
+            for uc in constraints
+            if isinstance(uc, UniqueConstraint)
+            for column in uc.columns
+        ]
+        # De-duplicate while keeping first-seen order.
+        return list(dict.fromkeys(names))
+
+class DouyinBaseRepository(BaseRepository):
+    """Repository for tables that are keyed by a Douyin ``open_id``."""
+
+    def __init__(self, model: DeclarativeMeta, engine=engine):
+        super().__init__(model, engine)
+
+    async def get_by_open_id(self, open_id):
+        """Return the first row whose ``open_id`` equals the argument, or None."""
+        async with self.session_factory() as session:
+            stmt = select(self.model).where(self.model.open_id == open_id)
+            result = await session.execute(stmt)
+            return result.scalars().first()
+
+    async def aadd_or_update(self, data: dict):
+        """Insert ``data``; on a unique-constraint conflict update the existing row.
+
+        Args:
+            data: Column values as a dict (Douyin responses arrive as JSON).
+                An instance of the repository's model is accepted too and is
+                converted with ``model_dump()``. Keys that are not columns of
+                the model are ignored.
+
+        Returns:
+            The id of the inserted/updated row, or None when an
+            IntegrityError occurred (it is logged, not raised).
+
+        Raises:
+            Exception: any other error is logged and re-raised.
+        """
+        # isinstance also covers subclasses of the model, which a strict
+        # type() comparison would leave unconverted.
+        if isinstance(data, self.model):
+            data = data.model_dump()
+        try:
+            index_elements = self._get_unique_constraint_fields()
+            # Filter on real table columns; hasattr() would also accept keys
+            # that merely match a method or class attribute of the model.
+            column_names = set(self.model.__table__.columns.keys())
+            async with self.session_factory() as session:
+                clean_data = {k: v for k, v in data.items() if k in column_names}
+                # Stamp the row with the database's current time when the
+                # model has an update_time column.
+                if 'update_time' in column_names:
+                    clean_data['update_time'] = func.now()
+                # INSERT ... ON CONFLICT (unique columns) DO UPDATE SET <other columns>
+                insert_stmt = insert(self.model).values(**clean_data)
+                update_stmt = insert_stmt.on_conflict_do_update(
+                    index_elements=index_elements,
+                    set_={k: clean_data[k] for k in clean_data if k not in index_elements}
+                ).returning(self.model.id)
+                result = await session.execute(update_stmt)
+                new_id = result.scalar()
+                await session.commit()
+                return new_id
+        except IntegrityError as e:
+            # Best effort: log and report "no id"; the commit was not reached.
+            logger.exception(f"IntegrityError occurred: {e}")
+            return None
+        except Exception as e:
+            logger.exception(f"An unexpected error occurred: {e}")
+            raise

+ 0 - 6
db/common.py

@@ -1,6 +0,0 @@
-from sqlmodel import SQLModel,create_engine
-from config import DB_URL
-
-# 创建引擎和仓储类实例  
-engine = create_engine(DB_URL)  # 替换成你的 DB_URL  
-SQLModel.metadata.create_all(engine)  

+ 132 - 0
db/docs.py

@@ -0,0 +1,132 @@
+import asyncio
+from datetime import datetime
+import re
+from typing import Optional
+from enum import Enum
+from typing import List, Any
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+from sqlmodel import Field, SQLModel,Column, Integer, Sequence, UniqueConstraint  
+from config import DB_URL,logger
+# from db.common import engine
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.sql.sqltypes import Integer, String, DateTime
+from sqlalchemy.sql.schema import Column
+from sqlalchemy import UniqueConstraint
+from db.base import BaseRepository,DouyinBaseRepository
+from db.engine import engine,create_all
+
+  
+
+class Categories(SQLModel,DouyinBaseRepository, table=True):  
+    # Table of per-user document categories.
+    # NOTE(review): this table model also inherits from DouyinBaseRepository;
+    # the repository's __init__ is presumably never run for model instances -
+    # confirm whether that base class is intended here.
+    id: int = Field(primary_key=True)  # Unique identifier of the category
+    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # Foreign key to the user table
+    name: str = Field(default="default", index=True)  # Category name; indexed to speed up queries
+    update_time: datetime = Field(default_factory=datetime.now)  # Creation time / update time
+    # Composite unique constraint over (open_id, name)
+    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_name'),)
+    
+        
+    
+class DocumentCategories(SQLModel, table=True):  
+    # Association table linking documents to categories; both columns
+    # together form the primary key.
+    document_id: int = Field(foreign_key="documents.id", primary_key=True)  # Foreign key to the documents table
+    category_id: int = Field(foreign_key="categories.id", primary_key=True)  # Foreign key to the categories table
+
+class DocStatus:  
+    # Integer codes stored in Documents.status.
+    UNPROCESSED = 0  # Not processed yet
+    COMPLETED = 100  # Processing finished
+    DISABLED = -1    # Disabled
+    
+class Documents(SQLModel, table=True):  
+    # Table of uploaded documents; ``path`` is unique across the table.
+    id: Optional[int] = Field(primary_key=True)
+    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # Foreign key to the user table
+    path: str = Field(nullable=False, index=True) # Relative path
+    status: int = Field(nullable=False) # Document status (see DocStatus)
+    update_time: datetime = Field(default_factory=datetime.now)  # Creation time / update time
+    __table_args__ = (UniqueConstraint('path'),) 
+    
+class CategoriesRepository(DouyinBaseRepository):  
+    """Repository bound to the ``Categories`` table."""
+    def __init__(self, engine=engine):  
+        super().__init__(Categories, engine)  
+  
+        
+class DocumentCategoriesRepository(DouyinBaseRepository):  
+    """Repository bound to the ``DocumentCategories`` association table."""
+    # NOTE(review): DocumentCategories declares no open_id field, so the
+    # inherited get_by_open_id presumably cannot be used with this repository.
+    def __init__(self, engine=engine):  
+        super().__init__(DocumentCategories, engine)  
+
+  
+  
+class DocumentsRepository(DouyinBaseRepository):
+    """Repository for ``Documents`` that prepares one document and its category.
+
+    ``doc_model`` and ``category_model`` hold the rows to be stored; both are
+    None when ``file_path`` contains no ``docs/`` segment.
+    """
+
+    def __init__(self, open_id, file_path, category_name="default", engine=engine):
+        """Build the pending rows for ``file_path``.
+
+        Args:
+            open_id: Owner of the document.
+            file_path: Full path, expected to look like
+                ``{DATA_DIR}/{open_id}/docs/xxx/example_file.pdf``.
+            category_name: Name of the category to attach the document to.
+            engine: Async engine used for sessions.
+        """
+        # Initialise the base first so the repository is usable (model,
+        # session factory) even when the path turns out to be invalid.
+        super().__init__(Documents, engine)
+        self.doc_model = None
+        self.category_model = None
+
+        relative_path = DocumentsRepository.get_relative_path(file_path)
+        if relative_path is None:
+            return
+        self.doc_model = Documents(
+            open_id=open_id,
+            path=relative_path,
+            status=DocStatus.UNPROCESSED,
+        )
+        self.category_model = Categories(
+            open_id=open_id,
+            name=category_name,
+        )
+
+    @staticmethod
+    def get_relative_path(full_path):
+        """Return the part of ``full_path`` that follows ``docs`` (starting with ``/``).
+
+        Logs an error and returns None when the path has no ``docs/`` segment.
+        """
+        match = re.search(r'docs(/.*?)$', full_path)
+        if match:
+            return match.group(1)
+        logger.error(f"Can not get rel path:{full_path}")
+        return None
+
+    async def add_document_with_categories(self):
+        """Ensure the document's category row exists.
+
+        Returns:
+            None. Nothing is done when the repository was built from an
+            invalid path.
+        """
+        if self.category_model is None:
+            return
+        c = CategoriesRepository()
+        # aon_conflict_do_nothing derives the conflict columns from the
+        # model's unique constraint itself; it accepts no index_elements
+        # argument, so passing one raised TypeError.
+        category_id = await c.aon_conflict_do_nothing(self.category_model)
+        logger.debug(f"category_id:{category_id}")
+        # TODO: store self.doc_model (upsert on "path") and insert the
+        # DocumentCategories row linking the document to its category. The
+        # unreachable draft that followed the early return referenced
+        # undefined names and was removed.
+        return
+  
+# 示例使用  
+async def main():
+    """Manual smoke run: seed a test user, then register one sample document."""
+    from db.user import test_add
+
+    # test_add returns the open_id of the user it seeded.
+    seeded_open_id = await test_add()
+
+    # All three repositories are instantiated; only the documents one is used.
+    category_repository = CategoriesRepository()
+    document_repository = DocumentsRepository(
+        seeded_open_id,
+        "/home/user/code/open-douyin/open_id/docs/readme2.md",
+    )
+    link_repository = DocumentCategoriesRepository()
+
+    await document_repository.add_document_with_categories()
+    # TODO: add sample Documents rows and exercise the remaining operations.
+  
+if __name__ == "__main__":  
+    import asyncio  
+    asyncio.run(main())

+ 11 - 0
db/engine.py

@@ -0,0 +1,11 @@
+import asyncio
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine  
+from sqlmodel import Field, SQLModel
+from config import DB_URL,logger
+
+engine = create_async_engine(DB_URL)  # 替换成你的 DB_URL 
+# SQLModel.metadata.create_all() 
+async def create_all():
+    """Create the tables registered on ``SQLModel.metadata`` using the async engine."""
+    # metadata.create_all is a synchronous API, so it is run through
+    # run_sync on a connection taken from engine.begin().
+    async with engine.begin() as connection:
+        await connection.run_sync(SQLModel.metadata.create_all)
+  

+ 71 - 99
db/user.py

@@ -10,18 +10,21 @@ from config import DB_URL,logger
 # from db.common import engine
 from sqlalchemy import UniqueConstraint, Index
 from sqlalchemy.dialects.postgresql import insert
+from db.base import BaseRepository,DouyinBaseRepository
+from db.engine import engine,create_all
 
-# 定义数据库模型  
+# 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
+# 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
 class UserOAuthToken(SQLModel, table=True):  
     id: Optional[int] = Field(default=None, primary_key=True)
     access_token:str
     expires_in: Optional[int] = None
-    open_id:str
+    open_id:str = Field(index=True)
     refresh_expires_in: Optional[int] = None
     refresh_token:str
     scope: str
-    update_time: datetime = Field(default_factory=datetime.now)  # 添加时间戳字段  
-    __table_args__ = (UniqueConstraint('open_id'),) 
+    update_time: datetime = Field(default_factory=datetime.now)  # 添加时间戳字段 
+    __table_args__ = (UniqueConstraint('open_id'),)  
 
 class UserInfo(SQLModel, table=True):  
     id: Optional[int] = Field(default=None, primary_key=True)  
@@ -30,121 +33,90 @@ class UserInfo(SQLModel, table=True):
     client_key: str  
     e_account_role: str = Field(default="")  
     nickname: str  
-    open_id: str  
+    # 外键约束有助于:级联操作、避免冗余、数据完整性
+    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  
     union_id: str  
     update_time: datetime = Field(default_factory=datetime.now)  
     __table_args__ = (UniqueConstraint('open_id'),) 
     
-    
-engine = create_engine(DB_URL)  # 替换成你的 DB_URL  
-SQLModel.metadata.create_all(engine)  
 
-class UserInfoRepository:  
+class UserInfoRepository(DouyinBaseRepository):  
     def __init__(self, engine=engine):  
-        self.engine = engine  
-  
-    def create_user_info(self, user_info_data):  
-        # 剔除不需要的字段  
-        cleaned_data = {k: v for k, v in user_info_data.items() if k not in ["log_id", "error_code"]}  
-          
-        # 添加或更新时间戳  
-        cleaned_data['update_time'] = func.now()  
-  
-        with Session(self.engine) as session:  
-            # 使用 on_conflict_do_update 处理 open_id 的冲突  
-            insert_stmt = insert(UserInfo).values(**cleaned_data)  
-            update_stmt = insert_stmt.on_conflict_do_update(  
-                constraint="open_id",  # 使用 open_id 作为冲突约束  
-                set_={**{k: cleaned_data[k] for k in cleaned_data if k != "open_id"}, "update_time": func.now()}  # 更新其他字段,包括时间戳  
-            )  
-            result = session.exec(update_stmt)  
-            session.commit() 
-  
-    def get_user_info_by_open_id(self, open_id):  
-        with Session(self.engine) as session:  
-            statement = select(UserInfo).where(UserInfo.open_id == open_id)  
-            result = session.exec(statement)  
-            return result.first()  
-  
-    def update_user_info(self, user_id, user_info_data):  
-        with Session(self.engine) as session:  
-            update_user_info = session.get(UserInfo, user_id)  
+        super().__init__(UserInfo, engine)  
+        self.model:UserInfo
+        
+    async def create_user_info(self, user_info_data):  
+        return await self.aadd_or_update(user_info_data)
+    
+    async def update_user_info(self, user_id, user_info_data):  
+        async with self.session_factory() as session:  
+            update_user_info = await session.get(UserInfo, user_id)  
             if update_user_info:  
                 for key, value in user_info_data.items():  
                     setattr(update_user_info, key, value)  
-                session.commit()  
+                await session.commit()  
                 return update_user_info  
   
-    def delete_user_info(self, user_id):  
-        with Session(self.engine) as session:  
-            delete_user_info = session.get(UserInfo, user_id)  
+    async def delete_user_info(self, user_id):  
+        async with self.session_factory() as session:  
+            delete_user_info = await session.get(UserInfo, user_id)  
             if delete_user_info:  
-                session.delete(delete_user_info)  
-                session.commit()  
+                await session.delete(delete_user_info)  
+                await session.commit() 
         
 # Database manager class
-class UserOAuthRepository:
-    def __init__(self, engine=engine):
-        self.engine = engine
+class UserOAuthRepository(DouyinBaseRepository):  
+    def __init__(self, engine=engine):  
+        super().__init__(UserOAuthToken, engine)  
+        self.model:UserOAuthToken
 
-    def add_token(self, data: dict):  
-        # 剔除不需要的字段  
-        cleaned_data = {  
-            k: v for k, v in data.items()  
-            if k not in ["log_id", "error_code", "captcha", "desc_url", "description"]  
-        }  
-          
-        # 添加或更新时间戳  
-        cleaned_data['update_time'] = func.now()  
-          
-        # 构造插入语句  
-        insert_stmt = insert(UserOAuthToken).values(**cleaned_data)  
-        update_stmt = insert_stmt.on_conflict_do_update(  
-            index_elements=['open_id'],  # 使用 open_id 作为冲突的目标列  
-            set_={  
-                **{k: insert_stmt.excluded[k] for k in cleaned_data if k != "open_id"},  
-                "update_time": func.now()  # 更新时间戳  
-            }  
-        )  
-          
-        # 执行插入/更新操作  
-        with Session(self.engine) as session:  
-            result = session.exec(update_stmt)  # 注意:这里应该是 execute 而不是 exec  
-            session.commit()  
-            logger.debug(f"Record added/updated: Access Token, Open ID - {cleaned_data['open_id']}")
+    async def add_token(self, data: dict):  
+        return await self.aadd_or_update(data)
 
 
-    def delete_token(self, token_id: int):
-        with Session(self.engine) as session:
-            token = session.get(UserOAuthToken, token_id)
-            if token:
-                session.delete(token)
-                session.commit()
-                print(f"Record deleted: ID - {token_id}")
-            else:
-                print(f"Record with ID {token_id} not found")
+    async def delete_token(self, token_id: int):
+        """Delete the ``UserOAuthToken`` row whose id is ``token_id``.
+
+        Prints whether a row was deleted or no row with that id exists.
+        """
+        async with self.session_factory() as session:
+            statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)
+            # Await execute() first and only then read from its result;
+            # chaining .scalars() onto the un-awaited call operates on the
+            # coroutine object and fails with AttributeError.
+            result = await session.execute(statement)
+            token = result.scalars().first()
+            if token:
+                await session.delete(token)
+                await session.commit()
+                print(f"Record deleted: ID - {token_id}")
+            else:
+                print(f"Record with ID {token_id} not found")
 
-    def display_all_records(self):
-        with Session(self.engine) as session:
-            statement = select(UserOAuthToken)
-            user_tokens = session.exec(statement).all()
-            return user_tokens
     
-    # 根据 open_id 获取模型中某一行
-    def get_by_open_id(self, open_id):  
-        with Session(self.engine) as session:  
-            statement = select(UserOAuthToken).where(UserOAuthToken.open_id == open_id)  
-            result = session.exec(statement)  
-            return result.first()  
 
-
-def main():
+async def test_add():  
+    await create_all()
+    
+    user_oauth = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
+    user_info = {
+    "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
+    "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
+    "captcha": "",
+    "city": "",
+    "client_key": "123",
+    "country": "",
+    "desc_url": "",
+    "description": "",
+    "district": "",
+    "e_account_role": "",
+    "error_code": 0,
+    "gender": 0,
+    "log_id": "202401261424326FE877A6CAB03910C553",
+    "nickname": "程序员马工",
+    "open_id": "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy",
+    "province": "",
+    "union_id": "123-01ae-59bd-978a-1de8566186a8"
+  }
     db_manager = UserOAuthRepository()
-    data = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
-    # db_manager.add_token(data)
-    # res = db_manager.display_all_records()
-    res = db_manager.get_from_id("_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy", "access_token")
+    res = await db_manager.add_token(user_oauth)
+    db_user_info = UserInfoRepository()
+    res = await db_user_info.create_user_info(user_info)
     logger.debug(res)
+    return user_oauth["open_id"]
 
-if __name__ == "__main__":
-    main()
+if __name__ == "__main__":  
+    import asyncio  
+    asyncio.run(test_add())

+ 5 - 1
douyin/access_token.py

@@ -2,10 +2,10 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 import httpx
-from config import logger
 from typing import Optional
 from pydantic import BaseModel
 from db.user import UserOAuthToken,UserOAuthRepository
+from config import *
 
 class DouyinAccessTokenResponse(BaseModel):
     error_code: int
@@ -19,6 +19,10 @@ async def check_access_token_response(response_json: dict):
     return model_data.data
 
 async def get_access_token(code):
+    if not PRODUCE_ENV:
+        # 测试环境使用。因为每次 get_access_token 的 code 只能使用一次就过期了,为了避免频繁扫码,直接模拟返回请求结果
+        return {'access_token': 'act.3.UCzqnMwbL7uUTH0PkWbvDvIHcpy417HnfMqymbvBSpo9b1MJ3jOdwCxw-UPstOOjsGDWIdNwTGev4oEp8eUR-vHbU24XU5K4BkhPeOKJW1CLrEUS3XFxpG6SHqoQtvL6qhEgINcvt4V3KQX6C2qTeTkgQ-KwPO6jWi5uoin3YXo5DqwuGk3bbQ9dZoY=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '2024012915260549B5ED1A675515CD573C', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
+    
     client_key = os.environ.get("CLIENT_KEY")  # 从环境变量中获取 client_key  
     client_secret = os.environ.get("CLIENT_SECRET")  # 从环境变量中获取 client_secret  
     async with httpx.AsyncClient() as client:  

+ 9 - 6
main.py

@@ -14,6 +14,7 @@ import os
 from config import *
 from fastapi.middleware.cors import CORSMiddleware 
 from api.login import login_router
+from api.updload import upload_router
 from contextlib import asynccontextmanager
 
 app = FastAPI()  
@@ -26,6 +27,7 @@ app.add_middleware(
      
 ) 
 app.include_router(login_router)  
+app.include_router(upload_router)
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -40,15 +42,16 @@ def get_session(request: Request):
     
 @app.get("/")  
 async def read_root(request: Request):  
-    return {"message": "ok", "time":time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())}  
+    return {"message": "FastApi server is running.", "time":time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())}  
 
 
 def main():
-    print(f"https://open-douyin-cf.magong.site  公网代理地址,cloudflare dns proxy ,由 caddy 转发到 8600 端口")
-    print(f"https://open-douyin-wk.magong.site  公网代理地址,cloudflare workers 转发到 8600 端口")
-    print(f"http://sv-v2.magong.site:{PORT}  ⭐ 推荐,仅支持 ipv6 ,直连、满速、无延迟。缺点是不支持 https 协议,因为不经过 Caddy 代理,直达 Fastapi 没有配置 https")
-    print(f"https://open-douyin.magong.site  内网穿透隧道,cloudflare tunnel ,经常访问不了")
-    print(f"https://swl-8l9.pages.dev/  访问前端网站")
+    logger.debug(f"https://open-douyin-cf.magong.site  公网代理地址,cloudflare dns proxy ,由 caddy 转发到 8600 端口")
+    logger.debug(f"https://open-douyin-wk.magong.site  公网代理地址,cloudflare workers 转发到 8600 端口")
+    logger.debug(f"http://sv-v2.magong.site:{PORT}  仅支持 ipv6 ,直连、满速、无延迟。缺点是不支持 https 协议,因为不经过 Caddy 代理,直达 Fastapi 没有配置 https")
+    logger.debug(f"https://open-douyin.magong.site  内网穿透隧道,cloudflare tunnel ,经常访问不了")
+    logger.debug(f"http://localhost:{PORT} ⭐ 推荐。 vscode 会自动建立一条本地隧道,可以在本地浏览器直接打开")
+    logger.debug(f"https://swl-8l9.pages.dev/  访问前端网站")
     uvicorn.run(app, host=None, port=PORT, log_level="info")
 
 if __name__ == "__main__":

+ 16 - 8
test/config.py

@@ -1,13 +1,21 @@
 
 import os
-import socket
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from sqlalchemy import UniqueConstraint
+from loguru import logger
+from sqlmodel import Field, SQLModel
 
-os.environ["CLIENT_KEY"] = 'aw6aipmfdtplwtyq'
-os.environ["CLIENT_SECRET"] = '53cf3dcd2663629e8a773ab59df0968b'
-os.environ["JWT_SECRET_KEY"]="123"
-# HOST = socket.gethostbyname(socket.gethostname())
-# 这个网址 https://open-douyin.magong.site 对应这台服务器的 192.168.1.32:8600 端口,因为这台服务器没有公网ip,所以在本地计算机无法通过  http://192.168.1.32:8600/ 访问到 fastapi 接口,只能通过 https://open-douyin.magong.site/ 访问
+# 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
+# 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
+class UserOAuthToken(SQLModel):  
+    access_token:str = Field(default=None, primary_key=True)
+    refresh_expires_in:str = Field(default=None, )
+    refresh_token:str = Field(default=1, )
+    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_name'),)
+    
+user = UserOAuthToken()
 HOST = '::'
-PORT = 8601
-JWT_SECRET_KEY = os.environ["JWT_SECRET_KEY"]
+con:UniqueConstraint = [col.name for col in UserOAuthToken.__table__.constraints['uq_open_id_name'].columns] 
+logger.info(con)
 # print(HOST)

+ 4 - 0
待办事项.md

@@ -14,7 +14,11 @@
 
 
 # mrh
+- [ ] **数据库**
+  - [ ] 向量存储库避免单点故障:主从复制
+  - [ ] 改为 supabase 向量存储数据库,基于 Postgres 构建,支持 embedding 存储,支持 REST API 访问自动鉴权,支持文件存储、边缘函数  https://github.com/supabase/supabase
 - [ ] **登录页**
+
   - [ ] 从数据库获取用户数据,展示到用户登录页中。保存到本地 localStorage
   - [ ] 实现文档上传、下载、更新、删除
   - [ ] 文档预览