Bläddra i källkod

数据库全部改为关系型语法

qyl 2 år sedan
förälder
incheckning
9652e6bf68

+ 3 - 1
.gitignore

@@ -1,4 +1,6 @@
 __pycache__
 demo-release
 log
-mtest
+mtest
+temp
+bin

+ 2 - 2
api/chat.py

@@ -1,5 +1,5 @@
 from fastapi import FastAPI, Depends, HTTPException  
-from api import jwt  
+from api import swl_jwt  
 from grpc_m.proto import vector_service_pb2, vector_service_pb2_grpc  
 import grpc  
   
@@ -10,7 +10,7 @@ channel = grpc.insecure_channel('localhost:18600')
 stub = vector_service_pb2_grpc.VectorServiceStub(channel)  
   
 @app.post("/chat")  
-async def chat(user_id: str, doc_id: str, message: str, token: str = Depends(jwt.get_token)):  
+async def chat(user_id: str, doc_id: str, message: str, token: str = Depends(swl_jwt.get_token)):  
     # 鉴权逻辑(根据实际情况进行修改)  
     # ...  
       

+ 73 - 16
api/comment.py

@@ -1,3 +1,4 @@
+import hashlib
 import json
 from typing import List
 import os
@@ -6,40 +7,91 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from db.comment import CommentRepository,CommentStatus
 from db.docs import CategoriesRepository,Categories
-from db.user import UserOAuthRepository,UserOAuthToken
-from config import logger
-from grpc_m.send_data_to_vector import langchain_chat
-import douyin.comment
+from db.user_oauth import UserOAuthRepository,UserOAuthToken
+from config import logger,TEMP_DIR
+from grpc_m.send_data_to_vector import langchain_chat,simarity_search
+import douyin.comment_reply
 from fastapi import FastAPI,APIRouter,  File, HTTPException
 from db.comment import CommentContent,Events,EventRepository
+from db.user_info import UserInfo,UserInfoRepository
+from db.docs import DocumentsRepository
+from douyin.video_get_iframe_by_item import download_video
+# def get_video_vector
 
 def gen_prompt(myname, content_models:List[CommentContent]):
     role = """
-你是一个抖音视频的创作者,你的昵称是'{nick_name}'。你发布了一个视频,<vedio_content> 是视频的简要总结。
+你是一个抖音视频的创作者,你的昵称是'{nick_name}'。你发布了一个视频,<vedio_content> 是视频的部分文案片段
 <chat> 是用户在视频评论、或用户各自的讨论。你需要回答最后一条 chat 信息。""".format(nick_name=myname)
     logger.info(f"{role}")
     prompt = """
 {role}
-基于以下聊天回答最后一个 <chat> 问题
 <chat>
 {chat}
 <chat/>
 <docs>
 {docs}
 <docs/>
+<vedio_content>
+{vedio_content}
+<vedio_content/>
 """
 
     chat_record = ''
     for model in content_models:
         chat_record += f"{model.nick_name}: {model.content}\n"
     
-    return prompt.format(chat=chat_record, role=role)
+    return prompt.format(chat=chat_record, role=role, docs = "docs", vedio_content="vedio_content")
+
+
+async def chat_to_langchain(event_model:Events, comment_model:CommentContent, oauth_model=None):
+    user_model:UserInfo = UserInfoRepository().get_by_open_id(event_model.to_user_id)
+    if not oauth_model:
+        oauth_model:UserOAuthToken = UserOAuthRepository().get_by_open_id(event_model.to_user_id)
+    # 递归查找对话
+    comment_replies = CommentRepository().get_comment_and_replies(comment_model.comment_id)
+    DocumentsRepository().select()
+    query = comment_replies[-1].content
+    download_video
+    prompt = gen_prompt(user_model.nickname, comment_replies)
+    logger.info(f"query:{query} prompt: {prompt}")
+    # langchain_res = await langchain_chat(str("4ff71182-5c43-497f-ba16-5b3ba252e478"), prompt)
+    langchain_res = "这是一个示例回复"
+    if not langchain_res:
+        logger.error(f"langchain_chat {langchain_res} ")
+        return
+    response = await douyin.comment_reply.reply_to_comment(oauth_model.open_id,oauth_model.access_token, content=langchain_res, comment_id=comment_model.comment_id, item_id=comment_model.reply_to_item_id)
+    if not response.get('data').get('error_code'):
+        logger.info(f"回复评论成功: {langchain_res}")
+        
+    else:
+        # 一般是秘钥过期、参数错误。还有一种特殊情况,两个账号都授权了思维链,一个账号在另一个账号下是 event_model.to_user_id 恰好授权了思维链,但是他是在别的授权
+        logger.error(f"回复评论失败: {response}")
+
+
+async def save_video_item(open_id, item_id):
+    hash_object = hashlib.md5(item_id.encode())  
+    hex_dig = hash_object.hexdigest()
+    download_dir = os.path.join(TEMP_DIR,"video")  
+    if not os.path.exists(download_dir):  
+        os.makedirs(download_dir)  
+    file_path = os.path.join(download_dir, hex_dig + ".mp4")
+    try:
+        file_path = await download_video(item_id,file_path)
+        if file_path:
+            
+    except Exception as e:
+        logger.exception(e)
+    if not os.path.exists(file_path):  
+        os.remove(file_path)  
+
 '''
 data = {'event': 'item_comment_reply', 'client_key': 'aw6aipmfdtplwtyq', 'from_user_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'to_user_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'content': '', 'log_id': '021708779405655fdbdfdbdfdbdfdbd0000000000000008cb19c7'}
 data_content = json.loads(data.get("content"))
 data_content = {"at_user_id":"","avatar":"https://p26.douyinpic.com/aweme/720x720/aweme-avatar/tos-cn-i-0813_a2afe121cfee43c7856b1ec0d6997690.jpeg?from=3782654143","comment_id":"@9VxS1/qCUc80K2etd8wkUc791mbgPP2DPZV2qA6mLFEQaPT960zdRmYqig357zEBoZm7vZ+ZZZz6H3mOVdTOlw==","comment_user_id":"_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk","content":"测试","create_time":1708779397,"digg_count":0,"level":1,"nick_name":"王孙草爱搞钱","parent_id":"7259290547288870144","reply_comment_total":0,"reply_to_comment_id":"0","reply_to_item_id":"@9VxS1/qCUc80K2etd8wkUc7912DgP/GCPpF2qwKuJ1YTb/X460zdRmYqig357zEBKzkoKYjBMUvl9Bs6h+CwYQ=="}
 
 定义:
+- from_user_id: 在你视频下评论的人
+- to_user_id: 你自己的 open_id ,只有授权了思维链,才会收到自己视频的评论事件
 - 评论: 点开视频评论区就能看到的评论列表,或者在视频评论区中被人评论。
   - 例如:打开抖音视频,在评论区中,任何类似 “头像 昵称 \n 评论” 的内容,都可以看做一条评论。示例值: parent_id 、 comment_id
   - parent_id 父级评论id ,如果该评论没有父级评论,默认使用视频id作为父级 
@@ -54,23 +106,23 @@ data_content = {"at_user_id":"","avatar":"https://p26.douyinpic.com/aweme/720x72
 # 这不仅仅是打广告,也是为了观众区分哪一条是人工回复,哪一条是 AI 回复。
 
 async def item_comment_reply(data):
-    logger.debug(f"new item_comment_reply event, comment_data: {data}")
+    logger.info(f"new item_comment_reply event, comment_data: {data}")
     db_events = EventRepository()
     db_comment = CommentRepository()
     event_model,comment_model = db_events.save_item_comment_reply(data)
     event_model:Events
     comment_model:CommentContent
+    
     '''原则上不能让AI自己回复自己的评论,原因如下
       - 如果 AI 回答不完整,想手动回复该条评论,会导致让AI再次回复你自己手动评论的内容
       - 如果你回复自己的视频,在评论区阐述自己的观点,AI会回复你这条评论。但这很矛盾,你为什么要自己阐述完又让AI补充?
       - 会陷入死循环,自己回复自己。虽然技术上可以做到不让AI回复AI产生的评论,但意义何在?既然你选择手动评论,说明这段对话中已经不需要AI
     '''
-    # 为了特殊情况,或者自己测试用,当 @思维链AI助手 或者 @自己 时,允许回复一次
+    # 任何一条评论都可能 @某个用户,当 @思维链AI助手 或者 @自己 时,允许回复一次
     if comment_model.at_user_id == event_model.to_user_id:
-        # 递归查找对话
-        comment_replies = db_comment.get_comment_and_replies(comment_model.comment_id)
-        prompt = gen_prompt(event_model.)
+        
         return
+    
     if comment_model.comment_user_id == event_model.to_user_id:
         return
     
@@ -79,17 +131,22 @@ async def item_comment_reply(data):
         exist_comment:CommentContent = CommentRepository().select(
             CommentContent.comment_id == comment_model.parent_id,
             CommentContent.comment_user_id != event_model.to_user_id).first()
-        # 如果存在,说明这条评论的回复自己的
+        # 如果存在,说明这条评论自己账号发表
         if exist_comment:
+            logger.info(f"收到 AI 发表评论的回调 {comment_model} exist_comment_id:{exist_comment.id}")
             pass
+        else:
+            await chat_to_langchain(event_model, comment_model)
     # 如果是回复他人的回复
     elif comment_model.reply_to_comment_id:
+        # 查询这个回复事件是回复哪一条已有评论。
         exist_comment:CommentContent = CommentRepository().select(
             CommentContent.comment_id == comment_model.reply_to_comment_id,
-            CommentContent.comment_user_id != event_model.to_user_id).first()
-        # 如果存在,说明这条视频评论是发给自己的
+            CommentContent.comment_user_id == event_model.to_user_id).first()
+        logger.info(f"reply to exist_comment {exist_comment}")
+        # 查询到已存在的评论,是回复自己
         if exist_comment:
-            pass
+            await chat_to_langchain(event_model, comment_model)
     # 判断这条评论是不是子评论发给自己的。如果父级 parent_id 的评论者
     '''❗ 如果自己回复自己,会导致webhook死循环。因此要过滤一下
     comment_id 你抖音视频评论的 id , 评论者可能是自己、别人、LangChain ,每一条评论或回复评论都会产生一个 id

+ 84 - 40
api/login.py

@@ -1,3 +1,4 @@
+import asyncio
 import datetime
 import os
 import sys
@@ -11,8 +12,13 @@ from fastapi.responses import JSONResponse
 from config import *
 from douyin.access_token import get_access_token
 from douyin.user_info import get_user_info
-from db.user import UserOAuthRepository,UserInfoRepository,UserOAuthToken
-from api.jwt import verify_jwt_token,get_uer_oauth_and_verify
+from db.user_oauth import UserOAuthRepository,UserOAuthToken
+from db.user_info import UserInfoRepository,UserInfo
+from db.user import User,UserRepo
+from db.base import update_from_model
+from api.swl_jwt import verify_jwt_token,verify_user
+from sqlmodel import Session,select
+from db.engine import engine,create_all_table
 
 login_router = APIRouter()  
 
@@ -22,64 +28,102 @@ class ScanCode(BaseModel):
     code: str
     scopes: str
 
-class User(BaseModel):
-    nickname: str
-    avatar: str
-
+async def save_login_data(data:dict):
+    access_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("expires_in"))  
+    refresh_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("refresh_expires_in"))  
+    oauth_model:UserOAuthToken = UserOAuthRepository().dict_to_model(data)
+    oauth_model.expires_at = access_token_expires_in
+    oauth_model.refresh_expires_at = refresh_token_expires_in
+    user_info_data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
+    if not user_info_data.get("error_code"):
+       info_model = UserInfoRepository().dict_to_model(user_info_data)
+    else:
+        info_model = UserInfo()
+    with Session(engine) as session:
+        user = session.exec(
+            select(User).where(User.open_id == oauth_model.open_id)
+            ).first()
+        if not user:
+            user = User(open_id=oauth_model.open_id, oauth=oauth_model, info=info_model)
+        else:
+            user.open_id = oauth_model.open_id
+            update_from_model(user.oauth, oauth_model)
+            # user.info = info_model
+            update_from_model(user.info, info_model)
+            logger.info(f"update: {user.oauth}")
+        session.add(user)
+        session.commit()
         
-# 登录端点
-@login_router.post("/login")
-async def login(data: ScanCode):
-    logger.info(data)
-    data = await get_access_token(data.code)
-    if data.get("error_code") != 0:
-        raise HTTPException(status_code=400, detail=data)
-
     # 计算过期时间戳(基于北京时间)  
-    expires_in = data.get("refresh_expires_in", 1296000)
-    # expires_in = 15
-    expiration_time_local = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)  
-    exp = int(expiration_time_local.timestamp())  
+    # expires_in = data.get("refresh_expires_in", 1296000)
+    # # expires_in = 15
+    # expiration_time_local = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)  
+    # exp = int(expiration_time_local.timestamp())  
 
-    db_manager = UserOAuthRepository()
+    # db_manager = UserOAuthRepository()
     
-    oauth_model:UserOAuthToken = db_manager.save_login_data(data)
+    # oauth_model:UserOAuthToken = db_manager.save_login_data(data)
     
-    data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
-    if data.get("error_code") != 0:
-        raise HTTPException(status_code=400, detail=data)  
-    db_user = UserInfoRepository()
-    user_info_model = db_user.dict_to_model(data)
-    db_user.add_or_update(user_info_model)
+    # data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
+    # if data.get("error_code") != 0:
+    #     raise HTTPException(status_code=400, detail=data)  
+    # db_user = UserInfoRepository()
+    # user_info_model = db_user.dict_to_model(data)
+    # db_user.add_or_update(user_info_model)
     
     # 生成并返回 token,包含过期时间  
+    expiration_time_local = datetime.datetime.now() + datetime.timedelta(days=90)
+    exp = int(expiration_time_local.timestamp())  
     payload = {  
         "sub": data["open_id"],
-        "exp": exp  # 添加过期时间戳(北京时间)到 payload  
+        "exp": exp
     }  
     account_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")  
     logger.info(f"login success, expires_time:{datetime.datetime.fromtimestamp(exp).strftime('%Y-%m-%d %H:%M:%S') }, token:{account_token}")
     return {"token": account_token}
+        
+# 登录端点
+@login_router.post("/login")
+async def login(data: ScanCode):
+    logger.info(data)
+    data = await get_access_token(data.code)
+    if data.get("error_code") != 0:
+        raise HTTPException(status_code=400, detail=data)
+    return await save_login_data(data)
 
-@login_router.get("/user_info")
-async def user_info(open_id: str = Depends(verify_jwt_token)):
-    return UserInfoRepository().get_by_open_id(open_id)
 
-# 受保护资源示例
-@login_router.get("/account")
-async def read_account(open_id: str = Depends(verify_jwt_token)): 
-    UserOAuthRepository().display_all_records()
-    return {"message": "Account information", "open_id": open_id}
-    # 在这里返回当前用户的信息
-    return {"nickname": current_user.username, "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038"}
+@login_router.get("/user_info")
+async def user_info(user: User = Depends(verify_user)):
+    return await get_user_info(user.open_id, user.oauth.access_token)
 
+@login_router.get("/verify_callback")
+async def verify_callback(code:str, scopes:str): 
+    return await login(ScanCode(code=code, scopes=scopes))
+    
 @login_router.get("/token")
 async def read_account(open_id: str = Depends(verify_jwt_token)): 
     pass
 
 # 启动应用
-def main():
-    pass
+async def main():
+    create_all_table()
+    data = {
+        "access_token": "1act.f7094fbffab2ecbfc45e9af9c32bc241oYdckvBKe82BPx8T******",
+        "captcha": "",
+        "desc_url": "",
+        "description": "",
+        "error_code": 0,
+        "expires_in": 1296000,
+        "log_id": "20230525105733ED3ED7AC56A******",
+        "open_id": "b9b71865-7fea-44cc-123",
+        "refresh_expires_in": 2592000,
+        "refresh_token": "rft.713900b74edde9f30ec4e246b706da30t******",
+        "scope": "user_info"
+        }
+    user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
 
+    res = await save_login_data(user_oauth)
+    logger.info(f"{res}")
+    # import jwt  
 if __name__ == "__main__":
-    main()
+    asyncio.run(main())

+ 15 - 7
api/jwt.py → api/swl_jwt.py

@@ -1,3 +1,4 @@
+import datetime
 from fastapi import Depends, HTTPException, status, Header, Security  
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials  
 import jwt  
@@ -52,10 +53,17 @@ async def verify_jwt_token(token: str = Security(get_token_from_header)):
             detail="Invalid token",  
         )  
         
-from db.user import UserOAuthRepository,UserOAuthToken
-def get_uer_oauth_and_verify(open_id: str = Depends(verify_jwt_token)):  
-    db_oauth:UserOAuthToken = UserOAuthRepository().get_by_open_id(open_id)  
-    # 没有用户凭证,需要扫码登陆
-    if not db_oauth:  
-        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="need login")  
-    return db_oauth
+from db.user_oauth import UserOAuthRepository,UserOAuthToken
+from db.user import User,UserRepo
+def verify_user(open_id: str = Depends(verify_jwt_token)):  
+    res = UserRepo().select(User.open_id == open_id)
+    user:User = res.first()
+    if not user:
+        return
+    oauth:UserOAuthToken = user.oauth
+    if (oauth.expires_at - datetime.datetime.now()).total_seconds() <= 0:
+        raise HTTPException(  
+            status_code=status.HTTP_403_FORBIDDEN,  
+            detail="open-douyin Token is expired",  
+        )  
+    return user

+ 2 - 2
api/upload.py

@@ -11,9 +11,9 @@ from fastapi.responses import JSONResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials  
 import aiofiles
 from starlette.status import HTTP_401_UNAUTHORIZED 
-from api.jwt import verify_jwt_token,get_current_user
+from api.swl_jwt import verify_jwt_token,get_current_user
 from config import *
-from db.docs import DocumentsRepository,DocumentCategories
+from db.docs import DocumentsRepository
 from grpc_m.send_data_to_vector import send_to_grpc_vetcor
 
 upload_router = APIRouter()  

+ 6 - 2
config.py

@@ -59,5 +59,9 @@ if not PRODUCE_ENV:
 else:
     MNT_DOUYIN_DATA = os.environ["MNT_DOUYIN_DATA"]
 
-
-
+GRPC_VECTOR="10.0.0.32:18600"
+ASR_ADDRESS = "10.0.0.32:10095"
+ASR_EXE = "/home/user/code/open-douyin/bin/funasr_wss_client"
+TEMP_DIR = os.path.join(WORK_DIR, "temp")
+if not os.path.exists(TEMP_DIR):  
+    os.makedirs(TEMP_DIR)  

+ 20 - 0
db/ai_comment.py

@@ -0,0 +1,20 @@
+from sqlmodel import Field, SQLModel,Relationship,Column,create_engine,Session,select,func,UniqueConstraint
+from typing import Optional,List
+from .base import DouyinBaseRepository
+from .comment import CommentContent
+
+class LinkCommentContent(SQLModel, table=True):
+    commentcontent_id: Optional[int] = Field(
+        default=None, foreign_key="commentcontent.id", primary_key=True
+    )
+    aicomment_id: Optional[int] = Field(
+        default=None, foreign_key="aicomment.id", primary_key=True
+    )
+    
+class AIComment(SQLModel, table=True):  
+    id: int = Field(primary_key=True)  
+    prompt:str
+    reply:str
+    chunk_ids:List[str]
+    tokens:Optional[int]
+    commentcontent: "CommentContent" = Relationship(back_populates="aicomment", link_model=LinkCommentContent)

+ 11 - 3
db/base.py

@@ -11,6 +11,12 @@ from config import logger
 from sqlalchemy.sql.elements import BinaryExpression
 from sqlmodel.orm.session import ScalarResult
 
+def update_from_model(target: SQLModel, source: SQLModel, exclude: list[str] = ["id"]):  
+    for key, value in source.model_dump().items():  
+        if key not in exclude:  
+            setattr(target, key, value)  
+
+
 class BaseRepository:
     def __init__(self, model: SQLModel, engine=engine):
         self.model = model
@@ -187,9 +193,11 @@ class DouyinBaseRepository(BaseRepository):
     def __init__(self, model: SQLModel, engine=engine):  
         super().__init__(model, engine)  
 
-    def dict_to_model(self, dict_data: dict) -> SQLModel:
-        clean_data = {k: v for k, v in dict_data.items() if hasattr(self.model, k)}
-        obj_model = self.model(**clean_data)
+    def dict_to_model(self, dict_data: dict, model=None) -> SQLModel:
+        if not model:
+            model = self.model
+        clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
+        obj_model = model(**clean_data)
         return obj_model
     
     def get_by_open_id(self, open_id):

+ 12 - 15
db/comment.py

@@ -11,6 +11,11 @@ from db.base import BaseRepository
 from db.base import DouyinBaseRepository
 from db.engine import engine
 from config import logger
+# 避免循环导入问题,参考 SQLmodel 官方: https://sqlmodel.tiangolo.com/tutorial/code-structure/
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from .ai_comment import LinkCommentContent,AIComment
+
 
 # https://developer.open-douyin.com/docs/resource/zh-CN/dop/develop/openapi/interaction-management/comment-management-user/accept-comment-reply-event
 
@@ -19,16 +24,6 @@ class CommentStatus(Enum):
     PROCESSED = "processed"    # 已处理完成(可选:APPROVED)  
     REJECTED = "rejected"      # 抖音评论被拒绝或未能通过处理
     ERROR = "error"            # 程序内部错误,例如程序报错、数据库异常连接不上、找不到该用户的文档数据、LangChain连接不上、LangChain返回错误
-    
-# class Comment(SQLModel, table=True):
-#     id: Optional[int] = Field(primary_key=True)
-#     comment_id: str = Field(nullable=False,unique=True, max_length=255)
-#     content: str = Field(nullable=False)
-#     from_user_id: str = Field(nullable=False, max_length=255)
-#     reply_to_item_id: str = Field(nullable=False, max_length=255)
-#     reply_content:str = Field(nullable=True)
-#     status: str = Field(default=CommentStatus.PROCESSING.value, max_length=50)
-#     update_time: datetime.datetime = Field(default_factory=datetime.datetime.now)
 
 
 class EventCommentLink(SQLModel, table=True):  
@@ -60,6 +55,7 @@ class CommentContent(SQLModel, table=True):
     reply_to_item_id: Optional[str] = Field(index=True, default=None) 
     nick_name: Optional[str] = Field(default=None)
     event: "Events" = Relationship(back_populates="content", link_model=EventCommentLink)
+    # aicomment:"AIComment" = Relationship(back_populates="commentcontent", link_model=LinkCommentContent)
 
 class Events(SQLModel, table=True):  
     id: int = Field(default=None, primary_key=True)  
@@ -73,15 +69,14 @@ class Events(SQLModel, table=True):
     content: Optional[CommentContent] = Relationship(back_populates="event", link_model=EventCommentLink)
     log_id:str
 
-  
-
-
-
 class EventRepository(DouyinBaseRepository):
     def __init__(self, engine=engine):  
         super().__init__(CommentContent, engine)  
         self.model:Events
 
+    # data = {'event': 'item_comment_reply', 'client_key': 'aw6aipmfdtplwtyq', 'from_user_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'to_user_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'content': '', 'log_id': '021708779405655fdbdfdbdfdbdfdbd0000000000000008cb19c7'}
+    # content = json.loads(data.get("content"))
+    # content = {"at_user_id":"","avatar":"https://p26.douyinpic.com/aweme/720x720/aweme-avatar/tos-cn-i-0813_a2afe121cfee43c7856b1ec0d6997690.jpeg?from=3782654143","comment_id":"@9VxS1/qCUc80K2etd8wkUc791mbgPP2DPZV2qA6mLFEQaPT960zdRmYqig357zEBoZm7vZ+ZZZz6H3mOVdTOlw==","comment_user_id":"_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk","content":"测试","create_time":1708779397,"digg_count":0,"level":1,"nick_name":"王孙草爱搞钱","parent_id":"7259290547288870144","reply_comment_total":0,"reply_to_comment_id":"0","reply_to_item_id":"@9VxS1/qCUc80K2etd8wkUc7912DgP/GCPpF2qwKuJ1YTb/X460zdRmYqig357zEBKzkoKYjBMUvl9Bs6h+CwYQ=="}
     def save_item_comment_reply(self, data:dict) -> Tuple[Events,CommentContent]:
         with Session(self.engine) as session:
             content = json.loads(data.pop("content"))
@@ -90,6 +85,8 @@ class EventRepository(DouyinBaseRepository):
             event_model.content = content_model
             session.add(event_model)
             session.commit()
+            session.refresh(event_model)
+            session.refresh(content_model)
             return event_model,content_model
 
             
@@ -99,7 +96,7 @@ class CommentRepository(DouyinBaseRepository):
         super().__init__(CommentContent, engine)  
         self.model:CommentContent
     
-    def get_comment_and_replies(self, comment_id):
+    def get_comment_and_replies(self, comment_id) -> list[CommentContent]:
         comment_models = []
         with Session(self.engine) as session:
             search_id = comment_id

+ 50 - 35
db/docs.py

@@ -8,30 +8,37 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from sqlmodel import Field, SQLModel,Session, Integer, Sequence, UniqueConstraint,select
+from sqlmodel import Field, SQLModel,Session,Relationship, Integer, Sequence, UniqueConstraint,select
 from config import DB_URL,logger
 # from db.common import engine
 from pydantic import UUID4
 import uuid
 from db.base import BaseRepository
 from db.engine import engine
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from .user import User
+from db.video_data import VideoDocLink,VideoData
+
+class DocumentsLink(SQLModel, table=True):
+    user_id: Optional[int] = Field(
+        default=None, foreign_key="user.id", primary_key=True
+    )
+    doc_id: Optional[int] = Field(
+        default=None, foreign_key="documents.id", primary_key=True
+    )
 
   
+class DocumentCategoriesLink(SQLModel, table=True):
+    documents_id: int = Field(default=None, foreign_key="documents.id", primary_key=True)
+    category_id: int = Field(default=None, foreign_key="categories.id", primary_key=True)
+
 
 class Categories(SQLModel, table=True):  
-    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True)  # 使用 UUID v1 作为主键 
-    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
+    id: int = Field(default=None, primary_key=True)  # 使用 UUID v1 作为主键 
     name: str = Field(default="default", index=True)  # 分类的名称,添加索引以优化查询性能  
     update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
-    # 添加联合唯一约束  
-    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
-
-
-    
-class DocumentCategories(SQLModel, table=True):
-    id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True)  # 关联到文档表的外键  
-    category_id: UUID4 = Field(foreign_key="categories.id",index=True)  # 关联到分类表的外键  
-    __table_args__ = (UniqueConstraint('id', 'category_id', ),)
+    docs: "Documents" = Relationship(back_populates="categories", link_model=DocumentCategoriesLink)
     
 class DocStatus:  
     UNPROCESSED = 0  # 未处理  
@@ -39,13 +46,15 @@ class DocStatus:
     DISABLED = -1    # 禁用  
     
 class Documents(SQLModel, table=True):  
-    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True,index=True)  # 使用 UUID v1 作为主键 
-    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
-    path: str = Field(nullable=False, index=True) # 相对路径
+    id: int = Field(default=None, primary_key=True,index=True)  # 使用 UUID v1 作为主键 
+    # open_id: str = Field(index=True)  # 关联到用户表的外键  
+    path: str = Field(nullable=False) # 相对路径
     status: int = Field(nullable=False) # 文档状态  
     update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
-    __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),) 
-
+    user:"User" = Relationship(back_populates="docs", link_model=DocumentsLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
+    categories: List[Categories] = Relationship(back_populates="docs", link_model=DocumentCategoriesLink)
+    video_data:VideoData = Relationship(back_populates="doc", link_model=VideoDocLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
+    
 class DocumentBase(BaseRepository):
     def __init__(self, model: Documents, engine=...):
         super().__init__(model, engine)
@@ -99,28 +108,34 @@ class CategoriesRepository(DocumentBase):
             ret.remove("update_time")
         return ret
 
-class DocumentCategoriesRepository(DocumentBase):  
-    def __init__(self, engine=engine):  
-        super().__init__(DocumentCategories, engine)  
-
-
 class DocumentsRepository(DocumentBase):  
     def __init__(self, engine=engine):  
         super().__init__(Documents, engine)  
-
-    def add_document_with_categories(self, open_id, file_path, category_name="default") -> DocumentCategories:
+        
+    def add_document_with_categories(self, user, file_path, category_name="default"):
         with Session(bind=self.engine) as session:
-            doc_model:Documents = self.exec_add_or_update_file(open_id, file_path, session)
-            cr = CategoriesRepository()
-            category_model:Categories = cr.add_or_update(Categories(open_id=open_id, name=category_name),session)
-            dcr = DocumentCategoriesRepository()
-            doc_categ_model = dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
+            doc_model:Documents = self.exec_add_or_update_file(user.open_id, file_path, session)
+            doc_model.categories.append(category_name)
+            user.docs.append(doc_model)
+            session.add(user)
             session.commit()
-            # 强制刷新,让 model 从数据库总获取最新状态
-            session.refresh(doc_model)
-            session.refresh(category_model)
-            session.refresh(doc_categ_model)
-            return (doc_model, category_model, doc_categ_model)
+            session.refresh(user)
+            return user
+    
+    # def update_document(self, user:User, ):
+    # def add_document_with_categories(self, open_id, file_path, category_name="default") -> DocumentCategories:
+    #     with Session(bind=self.engine) as session:
+    #         doc_model:Documents = self.exec_add_or_update_file(open_id, file_path, session)
+    #         cr = CategoriesRepository()
+    #         category_model:Categories = cr.add_or_update(Categories(open_id=open_id, name=category_name),session)
+    #         dcr = DocumentCategoriesRepository()
+    #         doc_categ_model = dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
+    #         session.commit()
+    #         # 强制刷新,让 model 从数据库总获取最新状态
+    #         session.refresh(doc_model)
+    #         session.refresh(category_model)
+    #         session.refresh(doc_categ_model)
+    #         return (doc_model, category_model, doc_categ_model)
     
     def exec_add_or_update_file(self, open_id, file_path, session):
         # file_path = {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
@@ -181,7 +196,7 @@ class DocumentsRepository(DocumentBase):
 
 # 示例使用  
 def main():  
-    from db.user import test_add
+    from db.user_oauth import test_add
     open_id = test_add()
     # 创建实例  
     documents_repo = DocumentsRepository()  

+ 49 - 116
db/user.py

@@ -1,121 +1,65 @@
 import datetime
+from typing import List
 from typing import Optional
 import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from db.engine import engine
+from db.engine import engine,create_all_table
 
-from sqlmodel import Field, SQLModel,create_engine,Session,select,func,Column
+from sqlmodel import Field, SQLModel,Relationship,create_engine,Session,select,func,Column
 import psycopg2
 from config import DB_URL,logger
 # from db.common import engine
 from sqlalchemy import UniqueConstraint, Index, asc  
 from sqlalchemy.dialects.postgresql import insert
 from db.base import DouyinBaseRepository
-
-# 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
-# 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
-class UserOAuthToken(SQLModel, table=True):  
-    id: Optional[int] = Field(primary_key=True)
-    access_token:Optional[str] = Field(nullable=True)
-    expires_in: Optional[int] = Field(nullable=True)
-    expires_at: Optional[datetime.datetime] = Field(nullable=True)
-    open_id:str = Field(index=True, unique=True)
-    refresh_token:Optional[str] = Field(nullable=True)
-    refresh_expires_in: Optional[int] = Field(nullable=True)
-    refresh_expires_at: Optional[datetime.datetime] 
-    scope: Optional[str] = Field(nullable=True)
-    update_time: datetime.datetime = Field(default_factory=datetime.datetime.now,nullable=True)  # 添加时间戳字段 
-
-class UserInfo(SQLModel, table=True):  
-    id: Optional[int] = Field(primary_key=True)  
-    avatar: str  
-    avatar_larger: str  
-    client_key: str  
-    e_account_role: str = Field(default="")  
-    nickname: str  
-    # 外键约束有助于:级联操作、避免冗余、数据完整性
-    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  
-    union_id: str  
-    update_time: datetime.datetime = Field(default_factory=datetime.datetime.now)  # 添加时间戳字段 
-    __table_args__ = (UniqueConstraint('open_id'),) 
-
-
-
-
-
-class UserInfoRepository(DouyinBaseRepository):  
-    def __init__(self, engine=engine):  
-        super().__init__(UserInfo, engine)  
-        self.model:UserInfo
-    
+from db.user_info import UserInfo,UserInfoRepository,UserInfoLink
+from db.user_oauth import UserOAuthToken,UserOAuthTokenLink,UserOAuthRepository
+from db.video_data import VideoUserLink,VideoData
+from db.docs import Documents,DocumentsLink
+# 避免循环导入问题,参考 SQLmodel 官方: https://sqlmodel.tiangolo.com/tutorial/code-structure/
+# from typing import TYPE_CHECKING, Optional
+# if TYPE_CHECKING:
+#     from .user_oauth import UserOAuthToken
+#     from .user_info import UserInfo,LinkUserInfo,UserInfoRepository
+#     from .docs import Documents
     
+class User(SQLModel, table=True):  
+    id: Optional[int] = Field(default=None, primary_key=True)
+    open_id:str = Field(index=True, unique=True)
+    info:Optional[UserInfo] = Relationship(back_populates="user", link_model=UserInfoLink)
+    oauth: Optional[UserOAuthToken] = Relationship(back_populates="user", link_model=UserOAuthTokenLink)
+    docs:List[Documents] = Relationship(back_populates="user", link_model=DocumentsLink)
+    video_data:List[VideoData] = Relationship(back_populates="user", link_model=VideoUserLink)
     
-# Database manager class
-class UserOAuthRepository(DouyinBaseRepository):  
-    def __init__(self, engine=engine):  
-        super().__init__(UserOAuthToken, engine)  
-        self.model:UserOAuthToken
-
-    def save_login_data(self, data: dict):  
-        access_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("expires_in"))  
-        refresh_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("refresh_expires_in"))  
-
-        model:UserOAuthToken = self.dict_to_model(data)
-        model.expires_at = access_token_expires_in
-        model.refresh_expires_at = refresh_token_expires_in
-        return super().add_or_update(model)
-    
-    # field = UserOAuthToken.expires_at | UserOAuthToken.refresh_expires_at 
-    def select_nearest_expire(self, field: Column) -> UserOAuthToken:  
-        with Session(bind=self.engine) as session:
-            statement = select(self.model).where(field > datetime.datetime.now()).order_by(asc(field)).limit(1)
-            res = session.exec(statement)
-            return res.first()
+class UserRepo(DouyinBaseRepository):
+    def __init__(self, model: SQLModel=User):
+        super().__init__(model, engine)
+        self.model:User
         
-    def update_refresh_token(self, open_id:str, refresh_token:str, refresh_expires_in:int):
-        refresh_token_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=refresh_expires_in)
-        model = UserOAuthToken(open_id=open_id, refresh_expires_in=refresh_expires_in, refresh_token=refresh_token,refresh_expires_at=refresh_token_expires_at)
-        return self.add_or_update(model)
-    
-    def update_access_token(self, open_id:str, access_token:str, expires_in:int):
-        access_token_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)
-        model = UserOAuthToken(open_id=open_id, access_token=access_token, expires_in=expires_in,expires_at=access_token_expires_at)
-        return self.add_or_update(model)
-    # def add_token(self, data: dict):  
-    #     clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
-    #     obj_model = self.model(**clean_data)
-    #     with Session(bind=self.engine) as session:
-    #         exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
-    #         if exist_obj:
-    #             self.set_update_time(exist_obj)
-    #             dict_data = self.model_dump_by_field(obj_model, self.unique_constraint_fields)
-    #             self.set_obj_by_dict(exist_obj,dict_data)
-    #             session.commit()
-    #             return exist_obj
-    #         else:
-    #             self.create(obj_model)
-    #             session.commit()
-    #             return obj_model
+    def delete(self, open_id):
+        with Session(engine) as session:
+            user = session.exec(
+                select(User).where(User.open_id == open_id)
+            ).one()
+            session.delete(user)
+            logger.info(f"del {user}")
+            session.commit()
 
+    def add(self, open_id, info, oauth):
+        with Session(engine) as session:
+            user = User(open_id=info.open_id, info=info, oauth=oauth2)
+            session.add(user)
+            logger.info(f"{user}")
+            session.commit()
+            session.refresh(user)
+            logger.info(f"{user.info}")
 
-    async def delete_token(self, token_id: int):  
-        async with self.session_factory() as session:  
-            statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)  
-            token = await session.execute(statement).scalars().first()  
-            if token:  
-                await session.delete(token)  
-                await session.commit()  
-                print(f"Record deleted: ID - {token_id}")  
-            else:  
-                print(f"Record with ID {token_id} not found") 
 
     
-
-def test_add(open_id=None):  
-    SQLModel.metadata.create_all(engine)
-    
-    user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
+def main():
+    create_all_table()
+    user_oauth = {'access_token': 'act.3.TIKYfvDL4GEk0_8ol5HNBcDefjYynPt904OdtIiOH8SKquga23fjE1kTkWKqB8oLnCcqRhjPAUMUWq2uECQv0Bhm6m8cgq3Np8EGjDBD-fHI_BQAvIM7EIieNvfec-l-VMRPdxuAI9Cx8ih_59NvEC-JEFfomG18Oj8ICQ==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202403130157319F6A7830924CBE383292', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist,item.comment'}
     user_oauth2 = {'access_token': 'act', 'expires_in': 19290, 'open_id': '55test2', 'refresh_expires_in': 1950, 'refresh_token': 'rft', 'scope': 'user_info,trial.whitelist'}
     user_info = {
     "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
@@ -136,21 +80,10 @@ def test_add(open_id=None):
     "province": "",
     "union_id": "123-01ae-59bd-978a-1de8566186a8"
   }
-    if open_id:
-        user_oauth["open_id"] = open_id
-        user_info["open_id"] = open_id
-        user_info["nickname"] = "user" + open_id[:5]
-    else:
-        open_id = user_oauth["open_id"]
-    db_manager = UserOAuthRepository()
-    res = db_manager.save_login_data(user_oauth2)
-    res = db_manager.select_nearest_expire(UserOAuthToken.expires_at)
-    
-    logger.debug(res)
-    # db_user_info = UserInfoRepository()
-    # res = db_user_info.add_or_update(user_info)
-    # logger.debug(db_manager.get_by_open_id(open_id))
-    return user_oauth["open_id"]
-
+    db_user = UserRepo()
+    info:UserInfo = db_user.dict_to_model(user_info, UserInfo)
+    oauth:UserOAuthToken = UserOAuthRepository().dict_to_model(user_oauth, UserOAuthToken)
+    oauth2 = UserOAuthRepository().dict_to_model(user_oauth2, UserOAuthToken)
+    db_user.delete("b9b71865-7fea-44cc-123")
 if __name__ == "__main__":
-    test_add()
+    main()

+ 88 - 0
db/user_info.py

@@ -0,0 +1,88 @@
+import datetime
+from typing import Optional
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from db.engine import engine
+
+from sqlmodel import Field, SQLModel,Relationship,create_engine,Session,select,func,Column
+import psycopg2
+from config import DB_URL,logger
+# from db.common import engine
+from sqlalchemy import UniqueConstraint, Index, asc  
+from sqlalchemy.dialects.postgresql import insert
+from db.base import DouyinBaseRepository
+
+'''
+https://developer.open-douyin.com/docs/resource/zh-CN/dop/develop/openapi/account-management/get-account-open-info
+理论上无需将用户信息存入数据库,因为抖音已经存储了这些信息,我本地再存储一份无疑造成数据冗余。
+
+直接从抖音获取:
+- 可以获得很好的实时性。比如用户更新了头像,在我的网站应用中也能实时更新
+- 避免数据冗余,不增加无意义的代码量
+- 没有请求次数限制
+
+存储在本地:
+- 我有大量的本地数据分析的需求
+- 避免延迟和其他网络问题,但通常抖音服务器不会有网络问题
+
+综上,我认为不需要将用户信息存储在本地数据库。因为抖音并没有限制该接口请求次数。
+
+'''
+
+# 避免循环导入问题,参考 SQLmodel 官方: https://sqlmodel.tiangolo.com/tutorial/code-structure/
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from .user import User
+
+
+class UserInfoLink(SQLModel, table=True):
+    user_id: Optional[int] = Field(
+        default=None, foreign_key="user.id", primary_key=True
+    )
+    info_id: Optional[int] = Field(
+        default=None, foreign_key="userinfo.id", primary_key=True
+    )
+
+class UserInfo(SQLModel, table=True):  
+    id: Optional[int] = Field(default=None, primary_key=True)  
+    avatar: Optional[str] = Field(default=None)  
+    avatar_larger: Optional[str] = Field(default=None)  
+    client_key:  Optional[str] = Field(default=None)  
+    e_account_role:  Optional[str] = Field(default=None)    
+    nickname:  Optional[str] = Field(default=None)   
+    # 外键约束有助于:级联操作、避免冗余、数据完整性
+    open_id: Optional[str] = Field(index=True, unique=True)
+    union_id:  Optional[str] = Field(default=None)  
+    update_time: datetime.datetime = Field(default_factory=datetime.datetime.now)  # 添加时间戳字段 
+    user:"User" = Relationship(back_populates="info", link_model=UserInfoLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
+
+
+class UserInfoRepository(DouyinBaseRepository):  
+    def __init__(self, engine=engine):  
+        super().__init__(UserInfo, engine)  
+        self.model:UserInfo
+    
+async def main():
+    from douyin.user_info import get_user_info
+    from db.user_oauth import UserOAuthRepository
+    db_oauth = UserOAuthRepository()
+    oauth_model = db_oauth.get_by_open_id()
+    data = await get_user_info(oauth_model.open_id, oauth_model.access_token)
+    if data.get("error_code") != 0:
+        raise HTTPException(status_code=400, detail=data)  
+    db_user = UserInfoRepository()
+    user_info_model = db_user.dict_to_model(data)
+    db_user.add_or_update(user_info_model)
+    
+    # 生成并返回 token,包含过期时间  
+    payload = {  
+        "sub": data["open_id"],
+        "exp": exp  # 添加过期时间戳(北京时间)到 payload  
+    }  
+    account_token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")  
+    logger.info(f"login success, expires_time:{datetime.datetime.fromtimestamp(exp).strftime('%Y-%m-%d %H:%M:%S') }, token:{account_token}")
+
+
+if __name__ == "__main__":
+    main()

+ 131 - 0
db/user_oauth.py

@@ -0,0 +1,131 @@
+import datetime
+from typing import Optional
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from db.engine import engine
+
+from sqlmodel import Field, SQLModel,create_engine,Session,select,func,Column,Relationship
+import psycopg2
+from config import DB_URL,logger
+# from db.common import engine
+from sqlalchemy import UniqueConstraint, Index, asc  
+from sqlalchemy.dialects.postgresql import insert
+from db.base import DouyinBaseRepository
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from .user import User
+
+class UserOAuthTokenLink(SQLModel, table=True):
+    user_id: Optional[int] = Field(
+        default=None, foreign_key="user.id", primary_key=True
+    )
+    oauth_id: Optional[int] = Field(
+        default=None, foreign_key="useroauthtoken.id", primary_key=True
+    )
+    
+# 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
+# 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
+class UserOAuthToken(SQLModel, table=True):  
+    id: Optional[int] = Field(default=None, primary_key=True)
+    access_token:Optional[str] = Field(nullable=True)
+    expires_in: Optional[int] = Field(nullable=True)
+    expires_at: Optional[datetime.datetime] = Field(nullable=True)
+    open_id:Optional[str] = Field(index=True, unique=True)
+    refresh_token:Optional[str] = Field(nullable=True)
+    refresh_expires_in: Optional[int] = Field(nullable=True)
+    refresh_expires_at: Optional[datetime.datetime] 
+    scope: Optional[str] = Field(nullable=True)
+    update_time: datetime.datetime = Field(default_factory=datetime.datetime.now,nullable=True)  # 添加时间戳字段 
+    user:"User" = Relationship(back_populates="oauth", link_model=UserOAuthTokenLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
+    
+
+
+# Database manager class
+class UserOAuthRepository(DouyinBaseRepository):  
+    def __init__(self, engine=engine):  
+        super().__init__(UserOAuthToken, engine)  
+        self.model:UserOAuthToken
+
+    def save_login_data(self, data: dict):  
+        access_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("expires_in"))  
+        refresh_token_expires_in = datetime.datetime.now() + datetime.timedelta(seconds=data.get("refresh_expires_in"))  
+
+        model:UserOAuthToken = self.dict_to_model(data)
+        model.expires_at = access_token_expires_in
+        model.refresh_expires_at = refresh_token_expires_in
+        return super().add_or_update(model)
+    
+    # field = UserOAuthToken.expires_at | UserOAuthToken.refresh_expires_at 
+    def select_nearest_expire(self, field: Column) -> UserOAuthToken:  
+        with Session(bind=self.engine) as session:
+            statement = select(self.model).where(field > datetime.datetime.now()).order_by(asc(field)).limit(1)
+            res = session.exec(statement)
+            return res.first()
+        
+    def update_refresh_token(self, open_id:str, refresh_token:str, refresh_expires_in:int):
+        refresh_token_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=refresh_expires_in)
+        model = UserOAuthToken(open_id=open_id, refresh_expires_in=refresh_expires_in, refresh_token=refresh_token,refresh_expires_at=refresh_token_expires_at)
+        return self.add_or_update(model)
+    
+    def update_access_token(self, open_id:str, access_token:str, expires_in:int):
+        access_token_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=expires_in)
+        model = UserOAuthToken(open_id=open_id, access_token=access_token, expires_in=expires_in,expires_at=access_token_expires_at)
+        return self.add_or_update(model)
+
+
+    async def delete_token(self, token_id: int):  
+        async with self.session_factory() as session:  
+            statement = select(UserOAuthToken).where(UserOAuthToken.id == token_id)  
+            token = await session.execute(statement).scalars().first()  
+            if token:  
+                await session.delete(token)  
+                await session.commit()  
+                print(f"Record deleted: ID - {token_id}")  
+            else:  
+                print(f"Record with ID {token_id} not found") 
+
+    
+
+def test_add(open_id=None):  
+    SQLModel.metadata.create_all(engine)
+    
+    user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
+    user_oauth2 = {'access_token': 'act', 'expires_in': 19290, 'open_id': '55test2', 'refresh_expires_in': 1950, 'refresh_token': 'rft', 'scope': 'user_info,trial.whitelist'}
+    user_info = {
+    "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
+    "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
+    "captcha": "",
+    "city": "",
+    "client_key": "55",
+    "country": "",
+    "desc_url": "",
+    "description": "",
+    "district": "",
+    "e_account_role": "",
+    "error_code": 0,
+    "gender": 0,
+    "log_id": "202401261424326FE877A6CAB03910C553",
+    "nickname": "程序员马工",
+    "open_id": "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy",
+    "province": "",
+    "union_id": "123-01ae-59bd-978a-1de8566186a8"
+  }
+    if open_id:
+        user_oauth["open_id"] = open_id
+        user_info["open_id"] = open_id
+        user_info["nickname"] = "user" + open_id[:5]
+    else:
+        open_id = user_oauth["open_id"]
+    db_manager = UserOAuthRepository()
+    res = db_manager.save_login_data(user_oauth2)
+    res = db_manager.select_nearest_expire(UserOAuthToken.expires_at)
+    
+    logger.debug(res)
+    # db_user_info = UserInfoRepository()
+    # res = db_user_info.add_or_update(user_info)
+    # logger.debug(db_manager.get_by_open_id(open_id))
+    return user_oauth["open_id"]
+
+if __name__ == "__main__":
+    test_add()

+ 110 - 0
db/video_data.py

@@ -0,0 +1,110 @@
+import json
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from enum import Enum  
+import datetime
+from sqlmodel import Field, SQLModel,Relationship,Column,create_engine,Session,select,func,UniqueConstraint
+from sqlalchemy.dialects.postgresql import JSONB,JSON
+from typing import Optional,Dict,Tuple
+from db.base import BaseRepository
+from db.base import DouyinBaseRepository
+
+from db.engine import engine,create_all_table
+from config import logger
+from uuid import UUID  
+# 避免循环导入问题,参考 SQLmodel 官方: https://sqlmodel.tiangolo.com/tutorial/code-structure/
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from .user import User
+    from db.docs import Documents,DocumentsRepository
+
+class VideoStatisticsLinks(SQLModel, table=True): 
+    video_data_id: Optional[int] = Field(  
+        default=None,   
+        foreign_key="videodata.id",
+        primary_key=True
+    )
+    statistics_id: Optional[int] = Field(  
+        default=None,   
+        foreign_key="statistics.id",
+        primary_key=True
+    )
+    
+class VideoUserLink(SQLModel, table=True):  
+    __tablename__ = "video_data_link"  
+    video_data_id: Optional[int] = Field(  
+        default=None,   
+        foreign_key="videodata.id",
+        primary_key=True
+    )
+    user_id: Optional[int] = Field(  
+        default=None,   
+        foreign_key="user.id",
+        primary_key=True,
+    )
+
+class VideoDocLink(SQLModel, table=True):  
+    video_data_id: Optional[int] = Field(  
+        default=None,   
+        foreign_key="videodata.id",
+        primary_key=True,
+    )
+    doc_id: Optional[int] = Field(  
+        default=None,   
+        foreign_key="documents.id",
+        primary_key=True,
+    )
+    
+class Statistics(SQLModel, table=True):  
+    id: int = Field(default=None, primary_key=True)  
+    digg_count: Optional[int] = Field(default=None)  
+    download_count: Optional[int] = Field(default=None)  
+    play_count: Optional[int] = Field(default=None)  
+    share_count: Optional[int] = Field(default=None)  
+    forward_count: Optional[int] = Field(default=None)  
+    comment_count: Optional[int] = Field(default=None)  
+    video_data: Optional["VideoData"] = Relationship(back_populates="statistics", link_model=VideoStatisticsLinks, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})  
+  
+class VideoData(SQLModel, table=True):  
+    id: int = Field(default=None, primary_key=True,index=True)  
+    title: Optional[str] = Field(default=None)  
+    create_time: Optional[int] = Field(default=None)  
+    video_status: Optional[int] = Field(default=None)  
+    share_url: Optional[str] = Field(default=None)  
+    cover: Optional[str] = Field(default=None)  
+    is_top: Optional[bool] = Field(default=None)  
+    item_id: str = Field(index=True, unique=True)  
+    is_reviewed: Optional[bool] = Field(default=None)  
+    media_type: Optional[int] = Field(default=None)  
+    
+    user:Optional["User"] = Relationship(back_populates="video_data", link_model=VideoUserLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
+    statistics: Optional[Statistics] = Relationship(back_populates="video_data", link_model=VideoStatisticsLinks)
+    doc:Optional["Documents"] = Relationship(back_populates="video_data", link_model=VideoDocLink)
+
+
+class VideoItemDocRepo(DouyinBaseRepository):
+    def __init__(self, model: VideoData=VideoData, engine=...):
+        super().__init__(model, engine)
+    
+    def add_vedio_item_doc(self, item_id, open_id, file_path, category_name="_video"):
+        with Session(bind=self.engine) as session:
+            db_doc = DocumentsRepository()
+            doc_model = db_doc.add_document_with_categories(open_id, file_path, category_name)
+            logger.info(f"doc_model: {doc_model}")
+            # video_model = VideoItemDoc(item_id=item_id, doc=doc_model)
+            # logger.info(f"{video_model}")
+            # session.add(video_model)
+            # session.commit()
+            # session.refresh(video_model)
+            # logger.info(f"{video_model}")
+            # logger.info(f"{video_model.doc}")
+            return video_model
+            
+def main():
+    create_all_table()
+    db = VideoItemDocRepo()
+    db.add_vedio_item_doc("item_123", "_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk", "_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk/docs/_video/item_id_1.txt")
+
+if __name__ == "__main__":
+    main()

+ 0 - 0
db/video_item_doc.py


+ 68 - 131
docs/gpt.md

@@ -1,141 +1,78 @@
-## 项目简介 swl-douyin
-你当前在 open-douyin 这个目录下面。这个项目是一个基于 Fastapi 框架的网站后端。用户实现授权第三方平台抖音登录,注册到我的网站中。网站包含了文档上传、删除文档、文档向量转换、抖音 API 获取用户信息、视频评论回复等的功能。
-
-git ls-files
-.gitignore
-
-api/jwt.py
-Fastapi 用户鉴权登录,扫码后会通过 jwt token 解析到用户的 open_id,查询 open_id 是否在数据库中,如果存在说明登录成功,允许用户访问后台数据
-
-api/login.py
-抖音扫码登录后,前端通过扫码结果的 code 信息请求 login 到这个端点上,然后 login.py 通过 code 访问 抖音 API 如果返回成功,则登录成功,并记录和返回携带用户 open_id 的 Token
-
-
-api/readme.md :
-说明文档
-
-api/swl.http
-http 请求草稿
-
-api/upload.py
-前端请求上传文件的路径
-
-api/weixin_pay copy.py
-api/weixin_pay.py:
-用户通过微信支付扫码充值。充值后才能使用网站的付费应用
-
-config.py :
-本项目的配置文件
-
-db/base.py :
-数据库增删改查基类
-
-db/docs.py :
-用户上传的文档 docs 数据库模型定义,操作数据库的实现
-
-db/engine.py
-数据库引擎,用法: from db.engine import engine,create_all_table
-
-db/readme.md
-本目录的说明文档
-
-db/user.py
-用户登录的相关数据库定义和模型
-
-douyin/access_token.py
-抖音 API 接口,请求用户 access_toekn ,用户获取抖音用户公开信息
-
-douyin/user_info.py
-抖音 API 接口,获取抖音用户公开信息
-
-grpc_m/client.py
-gRPC 客户端,用于将文档发送到我的另一个 微服务 vector server 并转换成向量数据存放在该服务中
-
-grpc_m/send_data_to_vector.py
-发送文档数据到向量服务器
-
-grpc_m/vector_service.proto
-gRPC 协议
-
-grpc_m/vector_service_pb2.py
-grpc_m/vector_service_pb2.pyi
-grpc_m/vector_service_pb2_grpc.py
-
-
-main.py
-主程序入口,已经完成代码
+## 
+项目简介
+项目名称: swl-douyin
+
+项目简介: 一个基于 FastAPI 框架的网站后端,用于用户登录、文档上传、文档向量转换、抖音 API 获取用户信息、视频评论回复等功能。
+
+项目结构如下:
+
+api/ 目录下包含了一系列处理用户登录鉴权、上传文档、微信支付等相关功能的 Python 脚本。
+douyin/ 目录下是与抖音 API 相关的接口,用于获取 access_token 和用户信息。
+grpc_m/ 目录下为 gRPC 相关的客户端和服务端代码,其中微服务 VectorService 用于接收文档并将其转换为向量数据存储。
+
+├── api
+│   ├── jwt.py
+│   ├── login.py
+│   ├── readme.md
+│   ├── swl.http
+│   ├── upload.py
+│   ├── weixin_pay.py
+├── config.py
+├── db
+│   ├── base.py
+│   ├── docs.py
+│   ├── engine.py
+│   ├── readme.md
+│   ├── user.py
+├── douyin
+│   ├── access_token.py
+│   ├── user_info.py
+├── grpc_m
+│   ├── client.py
+│   ├── send_data_to_vector.py
+│   ├── vector_service.proto
+│   ├── vector_service_pb2.py
+│   ├── vector_service_pb2_grpc.py
+├── main.py
+├── readme.md
+项目功能:
+
+用户通过抖音扫码登录
+用户上传文档
+将文档转换为向量
+获取抖音用户公开信息
+评论抖音视频
 
-readme.md
-```
-
-在另一个微服务 LangChain 项目中 main.py 如下:
 ```python
-import os
-import sys
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from concurrent import futures
-import time  
-import grpc  
-import logging  
-import uuid  
-from db_vector.vector_for_douyin import save_user_doc_to_vector  # 假设这里有必要的导入  
-from grpc_m import vector_service_pb2, vector_service_pb2_grpc  
-from config import MNT_DOUYIN_DATA
-
-
-class VectorService(vector_service_pb2_grpc.VectorServiceServicer):  
-    def SaveDocToVector(self, request, context):  
-        category_id = request.category_id
-        user_doc_relative_path = request.user_doc_relative_path
-        mnt_user_docs_path = os.path.join(MNT_DOUYIN_DATA, user_doc_relative_path)
-        if not mnt_user_docs_path:
-            return vector_service_pb2.SaveDocToVectorResponse(status=vector_service_pb2.ErrorCode.MNT_DOUYIN_DATA_ERROR)  
-        status = save_user_doc_to_vector(category_id, mnt_user_docs_path)  # 这个函数需要根据实际情况来实现  
-        return vector_service_pb2.SaveDocToVectorResponse(status=status)  
-
-
-def serve():
-    port = "18600"
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=30))  
-    vector_service_pb2_grpc.add_VectorServiceServicer_to_server(VectorService(), server)
-    server.add_insecure_port("0.0.0.0:" + port)
-    server.start()
-    print("Server started, listening on " + port)
-    server.wait_for_termination()
-  
-if __name__ == '__main__':  
-    serve()
+class Documents(SQLModel, table=True):  
+    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True,index=True)  # 使用 UUID v1 作为主键 
+    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
+    path: str = Field(nullable=False) # 相对路径
+    status: int = Field(nullable=False) # 文档状态  
+    update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
+    __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),) 
 ```
+如上向量文档数据库模型,用来存储用户自定义上传的文档,以便于后续向量化处理。
+同时用户自己也有许多视频,我需要把视频也转化成文本,每一个视频 item_id ,对应一个文档 Documents.id ,并且进行向量化。
+对于自定义文档和视频文本两者的数据关系,我应该如何定义表结构?
+如果自定义文档我假设有一个分类表,可以按不同文档分类:
+```python
+class Categories(SQLModel, table=True):  
+    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True)  # 使用 UUID v1 作为主键 
+    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
+    name: str = Field(default="default", index=True)  # 分类的名称,添加索引以优化查询性能  
+    update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
+    # 添加联合唯一约束  
+    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
 
-gRPC grpc_m/vector_service.proto 文件:
-```
-syntax = "proto3";  
-  
-package grpc_m;  
-// 定义错误枚举类型
-enum ErrorCode {
-    SUCCESS = 0;
-    DOC_CONVERT_ERROR = 1;
-    VECTOR_SERVER_ERROR = 2;
-}
 
-service VectorService {  
-    rpc SaveDocToVector (SaveDocToVectorRequest) returns (SaveDocToVectorResponse) {}  
-}  
-  
-message SaveDocToVectorRequest {  
-    string category_id = 1;  
-    string user_doc_relative_path = 2; 
-}  
+class DocumentCategories(SQLModel, table=True):
+    id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True)  # 关联到文档表的外键  
+    category_id: UUID4 = Field(foreign_key="categories.id",index=True)  # 关联到分类表的外键  
+    __table_args__ = (UniqueConstraint('id', 'category_id', ),)
 
-message SaveDocToVectorResponse {  
-    ErrorCode status = 1;  
-}
 ```
-
-具体逻辑:前端用户上传文档 网站后端服务 Fastapi 接收 - 在 PostgreSQL 数据库查询用户是否注册:鉴权成功 - Fastapi 根据已经鉴权的用户请求 LangChain 微服务将文档保存到向量数据库 - Fastapi 将收到的向量数据库 id 保存在后端 PostgreSQL 数据库 - 前端用户发起文档对话 - 网站后端服务 Fastapi 接收鉴权成功 - 后端Fastapi请求 LangChain 微服务 进行向量相似度计算,文本片段组合成 prompt 请求 LLM api ,将 LLM 回复的结果给后端 - 后端服务 Fastapi将结果返回给前端用户
-
-现在的业务逻辑我感觉有点冗余。似乎前端向 Fastapi 后端获得鉴权jwt 认证后,前端直接向 PostgreSQL 数据库服务器和向量数据库获取文档以及向量数据搜索转换,岂不是不需要 Fastapi 多写一层中转代码?我这种思路符合程序最佳架构吗?
+那么视频对应的一个文档,又该如何关联这些表结构或分类?给我最佳合理的程序设计、表结构设计架构。
 
 # 级联数据结构
 ```python

+ 1 - 1
douyin/access_token.py

@@ -5,7 +5,7 @@ import httpx
 from typing import Optional
 from enum import Enum
 from pydantic import BaseModel
-from db.user import UserOAuthToken,UserOAuthRepository
+from db.user_oauth import UserOAuthToken,UserOAuthRepository
 from config import *
 
 class DouyinAccessTokenResponse(BaseModel):

+ 0 - 0
douyin/comment.py → douyin/comment_reply.py


+ 3 - 3
douyin/manage_user.py

@@ -5,10 +5,10 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from db.user import UserInfo,UserOAuthRepository,UserOAuthToken,UserInfoRepository
-from db.docs import DocStatus,DocumentCategories,DocumentCategoriesRepository,Documents,DocumentBase,DocumentsRepository
+from db.user_oauth import UserOAuthRepository,UserOAuthToken
+from db.user_info import UserInfo,UserInfoRepository
 from douyin.access_token import get_access_token,refresh_access_token,renew_refresh_token
-from douyin.comment import reply_to_comment
+from douyin.comment_reply import reply_to_comment
 import loguru
 from config import LOG_DIR,logger
 

+ 10 - 0
douyin/search_video.py

@@ -0,0 +1,10 @@
+'''
+接口说明
+抖音的 OpenAPI 以 https://open.douyin.com/  开头。
+基本信息
+名称	描述
+HTTP URL	
+https://open.douyin.com/api/douyin/v1/video/video_list/
+HTTP Method	GET 
+video.list.bind
+'''

+ 12 - 2
douyin/user_info.py

@@ -27,9 +27,19 @@ async def get_user_info(open_id, access_token):
     logger.debug(res)
     return res
     '''return 
-    {'data': {'avatar': 'https://p6.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038', 'avatar_larger': 'https://p11.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038', 'captcha': '', 'city': '', 'client_key': 'aw6aipmfdtplwtyq', 'country': '', 'desc_url': '', 'description': '', 'district': '', 'e_account_role': '', 'error_code': 0, 'gender': 0, 'log_id': '20240129142818189D643B12E3055CE271', 'nickname': '程序员马工', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'province': '', 'union_id': 'b138db97-01ae-59bd-978a-1de8566186a8'}, 'message': 'success'}
+    {'data': 
+        {'avatar': 'https://p6.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038',
+        'avatar_larger': 'https://p11.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038',
+        'captcha': '', 'city': '', 'client_key': 'aw6aipmfdtplwtyq', 'country': '',
+        'desc_url': '', 'description': '', 'district': '',
+        'e_account_role': '', 'error_code': 0, 'gender': 0, 'log_id': '20240129142818189D643B12E3055CE271',
+        'nickname': '程序员马工', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'province': '',
+        'union_id': 'b138db97-01ae-59bd-978a-1de8566186a8'}, 'message': 'success'}
     '''
-    
+
+# async def get_userinfo_from_db_or_request()
+
+
 async def main():
     open_id = "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy"
     access_token = "act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8="

+ 5 - 0
douyin/video_data.py

@@ -0,0 +1,5 @@
+# https://developer.open-douyin.com/docs/resource/zh-CN/dop/develop/openapi/video-management/douyin/search-video/video-data
+# 查询特定视频的数据。
+# 由于抖音开放平台网站应用没有正式上线,无法通过用户授权查询视频列表,只能通过评论回调查询到视频 item_id ,然后根据 item_id 查询到特定视频数据
+
+

+ 169 - 0
douyin/video_get_iframe_by_item.py

@@ -0,0 +1,169 @@
+'''
+通过ItemID获取IFrame代码
+更新时间 2023-07-06 19:38:15
+该接口用于通过视频 ItemID获取 IFrame 代码。视频 ItemID可以通过【查询授权账号视频列表】、【查询特定视频的视频数据】、【查询视频分享结果及数据】等能力获取。
+
+该接口以 https://open.douyin.com/ 开头
+请求地址: GET /api/douyin/v1/video/get_iframe_by_item
+
+参数名称
+item_id 视频ItemID(请注意使用Base64URL编码):
+QDcyTTBXQ3F0eHVZR2l3dGtWVVlGdkgxbnR1TTBSZlF3VkgrUmZUMzBXVVY4Q3RMUXJVOU1wSlNZRVZWNi8yaEsrTm8zQVBUYy84T1U3VkJiQkhJM3NnPT0=
+
+client_key 应用标识 awexxxxxxx
+
+请求样例
+curl --location --request GET 'https://open.douyin.com/api/douyin/v1/video/get_iframe_by_item?item_id=QDcyTTBXQ3F0eHVZR2l3dGtWVVlGdkgxbnR1TTBSZlF3VkgrUmZUMzBXVVY4Q3RMUXJVOU1wSlNZRVZWNi8yaEsrTm8zQVBUYy84T1U3VkJiQkhJM3NnPT0=&client_key=awexxxxxxx'
+
+响应样例
+{
+  "data": {
+      'iframe_code': '<iframe width="1080" height="1920" frameborder="0" src="https://open.douyin.com/player/video?vid=7259290547288870144&autoplay=0" referrerpolicy="unsafe-url" allowfullscreen></iframe>',
+      'video_height': 1920, 
+      'video_title': '全国肥胖率胖子多的省份排名,有你们省吗? #减肥 #肥胖率 #广西',
+      'video_width': 1080
+      },
+  "err_msg": "",
+  "err_no": 0,
+  "log_id": "20221025205044010225243125063FAD62"
+}
+
+响应错误样例
+{
+  "err_msg": "系统内部错误,请重试",
+  "err_no": 28001005,
+  "log_id": "20221025205044010225243125063FAD62"
+}
+'''
+import json
+import os
+import re  
+import httpx  
+import base64  
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+from config import logger,ASR_EXE,ASR_ADDRESS,TEMP_DIR
+
+async def get_iframe_by_item_id(item_id):  
+    client_key = os.environ.get("CLIENT_KEY")  # 从环境变量中获取 client_key  
+    base64_encoded_item_id = base64.urlsafe_b64encode(item_id.encode()).decode('utf-8')  
+      
+    async with httpx.AsyncClient() as client:  
+        response = await client.get(  
+            f"https://open.douyin.com/api/douyin/v1/video/get_iframe_by_item?item_id={base64_encoded_item_id}&client_key={client_key}"  
+        )  
+          
+        res_json = response.json()  
+        err_no = res_json.get("err_no")  
+          
+        if err_no != 0:  
+            raise Exception(f"Error fetching IFrame code: {res_json.get('err_msg')}")  
+        return res_json.get("data")
+    
+
+async def get_video_id(item_id):  
+    data = await get_iframe_by_item_id(item_id)
+    # 'iframe_code': '<iframe width="1080" height="1920" frameborder="0" src="https://open.douyin.com/player/video?vid=7259290547288870144&autoplay=0" referrerpolicy="unsafe-url" allowfullscreen></iframe>',
+    iframe = data.get("iframe_code")
+    match = re.search(r'vid=(\d+)&', iframe)  
+    
+    if match:  
+        vid = match.group(1)  
+        logger.info(f"提取的vid: {vid}")
+        return vid
+    else:  
+        logger.info("未找到vid")   
+
+async def fetch_video_data(url, video_type="", platform=""):  
+    payload = {  
+        "url": url,  
+        "type": video_type,  
+        "platform": platform  
+    }  
+      
+    async with httpx.AsyncClient() as client:  
+        logger.info(f"{json.dumps(payload)  }")
+        response = await client.post(  
+            "http://10.0.0.12:9082/video-get",  
+            headers={"Content-Type": "application/json"},  
+            content=json.dumps(payload)  
+        )  
+          
+        res_json = response.json()  
+          
+        # 假设响应中包含'err_no'和可能的'err_msg'字段来表示错误  
+        err_no = res_json.get("code")  
+          
+        if err_no != 0:  
+            logger.error(f"Error fetching video data: {res_json}")  
+          
+        return res_json.get("data")
+
+
+async def download_video(item_id, file_path):
+    data = await get_iframe_by_item_id(item_id)
+    # iframe = '<iframe width="1080" height="1920" frameborder="0" src="https://open.douyin.com/player/video?vid=7259290547288870144&autoplay=0" referrerpolicy="unsafe-url" allowfullscreen></iframe>',
+    iframe = data.get("iframe_code")
+    url = await fetch_video_data(iframe)
+    async with httpx.AsyncClient() as client:  
+        response = await client.get(url[0])  
+
+        # 检查是否成功获取响应  
+        if response.status_code == 200:  
+            # 以二进制写模式打开一个文件  
+            with open(file_path, 'wb') as file:  
+                # 异步迭代响应的内容块并写入文件  
+                async for chunk in response.aiter_bytes():  
+                    file.write(chunk)  
+        else:  
+            logger.error(f"Failed to download video: {response.status_code}")    
+            return
+        return file_path
+
+async def video_to_txt(file_path):
+    host, port = ASR_ADDRESS.split(":")  
+    output_dir = os.path.join(TEMP_DIR,"asr")  
+    if not os.path.exists(output_dir):  
+        os.makedirs(output_dir)  
+
+    cmd = [  
+        ASR_EXE,  
+        "--host", host,  
+        "--port", port,  
+        "--mode", "offline",  
+        "--output_dir", output_dir,  
+        "--ssl", "0",  
+        "--audio_in", file_path  
+    ]  
+    logger.info(f"cmd: {' '.join(cmd)}")
+    process = await asyncio.create_subprocess_exec(  
+        *cmd,  
+        stdout=asyncio.subprocess.PIPE,  
+        stderr=asyncio.subprocess.PIPE
+    )  
+      
+    stdout, stderr = await process.communicate()  
+    # 处理结果,例如检查返回值、输出等  
+    # 注意:这里只是简单地打印了输出,您可能需要根据实际情况进行处理  
+    logger.info(f"{stdout.decode()}")  
+    logger.info(f"{stderr.decode()}")
+       
+async def task():
+    item_id = "@9VxS1/qCUc80K2etd8wkUc7912DgP/GCPpF2qwKuJ1YTb/X460zdRmYqig357zEBKzkoKYjBMUvl9Bs6h+CwYQ=="
+    # res = await get_iframe_by_item_id()
+    file_path = "/home/user/code/open-douyin/log/t.mp4"
+    # vid = await download_video(item_id, file_path)
+    # print(f"{vid}")
+    await video_to_txt("/home/user/code/open-douyin/log/t.mp4")
+    
+import asyncio
+import aiofiles
+import os
+import sys
+def main():
+    asyncio.run(task())
+
+if __name__ == "__main__":
+    main()

+ 25 - 18
grpc_m/client.py

@@ -1,10 +1,14 @@
+import asyncio
 import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-import grpc  
-from grpc_m import vector_service_pb2, vector_service_pb2_grpc  
+import asyncio
+
+from grpclib.client import Channel
+
+from grpc_m.vector import vector_grpc, vector_pb2  
 from db.docs import DocumentsRepository  
-from db.user import test_add  # 假设这里有必要的导入来获取open_id  
+from db.user_oauth import test_add  # 假设这里有必要的导入来获取open_id  
 from config import logger,GRPC_VECTOR
 
 async def load_user_docs_async(open_id: str, file_path: list[str]):  
@@ -21,24 +25,27 @@ async def load_user_docs_async(open_id: str, file_path: list[str]):
         logger.info("Response: success = {}".format(response.success))  
         return response  
 
-def run():  
+async def run():  
     # 创建gRPC通道和存根  
-    with grpc.insecure_channel('localhost:18600') as channel:  
-        stub = vector_service_pb2_grpc.VectorServiceStub(channel)  
-          
-        # 获取open_id(这里只是示例,具体实现可能会有所不同)  
-        open_id = "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy"
+    async with Channel('localhost',18600) as channel:  
+        stub = vector_grpc.VectorServiceStub(channel)  
           
-        # 创建DocumentsRepository实例并获取用户文件路径列表  
-        documents_repo = DocumentsRepository()  # 假设这里有正确的导入和实例化逻辑  
-        res = documents_repo.get_user_files_path(open_id)  
-        logger.info(f"get docs {res}")
         # 准备gRPC请求并发送  
-        request = vector_service_pb2.LoadUserDocRequest(open_id=open_id, data=[vector_service_pb2.Document(path=path, id=str(uuid)) for path, uuid in res])  
-        response = stub.LoadUserDoc(request)  
-          
+        request =  vector_pb2.SearchRequest(collection_name="some-uuid-1", query="价格多少")  
+        responses:vector_pb2.SearchResponses = await stub.SimilaritySearch(request)  
+        data = responses.data
         # 输出响应结果  
-        logger.info("Response: success = {}".format(response.success))  
+        logger.info("Response: success = {}".format(data))  
+        for item in data:
+            chunk_data = item.chunk
+            metadata_dict = dict(item.metadata)
+            score_value = item.score
+            uuid_value = item.uuid
+
+            print(f"Chunk: {chunk_data}")
+            print(f"Metadata: {metadata_dict}")
+            print(f"Score: {score_value}")
+            print(f"UUID: {uuid_value}")
   
 if __name__ == '__main__':  
-    run()
+    asyncio.run(run())

+ 8 - 5
grpc_m/send_data_to_vector.py

@@ -6,7 +6,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from config import logger,GRPC_VECTOR_HOST,GRPC_VECTOR_PORT
 from grpc_m.vector import vector_pb2, vector_grpc
-from db.docs import DocumentsRepository,DocumentCategories,Documents,DocStatus
+from db.docs import DocumentsRepository,Documents,DocStatus
 
 def send_to_grpc_vetcor(category_id,doc_model:Documents):
     loop = asyncio.get_running_loop()
@@ -40,10 +40,13 @@ async def langchain_chat(collection_name:str, prompt:str):
     async with Channel(GRPC_VECTOR_HOST, GRPC_VECTOR_PORT) as channel:
         vector = vector_grpc.VectorServiceStub(channel)  
         request = vector_pb2.DocChatRequest(collection_name=collection_name, prompt=prompt)
-        reply = await vector.DocChat(request)
-        logger.debug(f"{collection_name, reply, prompt}")
-        return reply
-        
+        res = await vector.DocChat(request)
+        logger.debug(f"{collection_name, res.reply, prompt}")
+        return res.reply
+    
+async def simarity_search():
+    pass
+    
 async def main():
     dm = Documents(open_id="user123", path="思维链-文档说明.md")
     # send_to_grpc_vetcor("123", dm)

+ 27 - 2
grpc_m/vector/vector.proto

@@ -1,4 +1,4 @@
-// python3 -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpclib_python_out=. grpc_m/vector/vector.proto
+// python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpclib_python_out=. grpc_m/vector/vector.proto
 // https://github.com/vmagamedov/grpclib
 // https://github.com/danielgtaylor/python-betterproto
 // python -m grpc_tools.protoc -I. --python_out=. --grpclib_python_out=. langchain_service.proto
@@ -20,8 +20,10 @@ enum ErrorCode {
 service VectorService {  
     rpc SaveDocToVector (SaveDocToVectorRequest) returns (SaveDocToVectorResponse) {}  
     rpc DocChat (DocChatRequest) returns (DocChatResponse) {}  
+    rpc SearchWithIds (SearchWithIdsRequests) returns (SearchResponses) {}  
+    rpc SimilaritySearch (SearchRequest) returns (SearchResponses) {}  
 }  
-  
+
 message SaveDocToVectorRequest {  
     string collection_name = 1;  
     // {open_id}/docs/exzample.pdf
@@ -39,4 +41,27 @@ message DocChatRequest {
 
 message DocChatResponse {
     string reply = 1;
+}
+
+message SearchRequest {
+    string collection_name = 1;
+    string query = 2;
+}
+
+message SearchWithIdsRequest {
+    string uuid = 1;
+}
+message SearchWithIdsRequests {
+    repeated SearchWithIdsRequest request = 1;  
+}
+
+message SearchResponse {  
+    string chunk = 1;  
+    map<string, string> metadata = 2;  
+    float score = 3;  
+    string uuid = 4;  
+}  
+
+message SearchResponses {  
+    repeated SearchResponse data = 1;  
 }

+ 32 - 0
grpc_m/vector/vector_grpc.py

@@ -22,6 +22,14 @@ class VectorServiceBase(abc.ABC):
     async def DocChat(self, stream: 'grpclib.server.Stream[grpc_m.vector.vector_pb2.DocChatRequest, grpc_m.vector.vector_pb2.DocChatResponse]') -> None:
         pass
 
+    @abc.abstractmethod
+    async def SearchWithIds(self, stream: 'grpclib.server.Stream[grpc_m.vector.vector_pb2.SearchWithIdsRequests, grpc_m.vector.vector_pb2.SearchResponses]') -> None:
+        pass
+
+    @abc.abstractmethod
+    async def SimilaritySearch(self, stream: 'grpclib.server.Stream[grpc_m.vector.vector_pb2.SearchRequest, grpc_m.vector.vector_pb2.SearchResponses]') -> None:
+        pass
+
     def __mapping__(self) -> typing.Dict[str, grpclib.const.Handler]:
         return {
             '/langchain_service.VectorService/SaveDocToVector': grpclib.const.Handler(
@@ -36,6 +44,18 @@ class VectorServiceBase(abc.ABC):
                 grpc_m.vector.vector_pb2.DocChatRequest,
                 grpc_m.vector.vector_pb2.DocChatResponse,
             ),
+            '/langchain_service.VectorService/SearchWithIds': grpclib.const.Handler(
+                self.SearchWithIds,
+                grpclib.const.Cardinality.UNARY_UNARY,
+                grpc_m.vector.vector_pb2.SearchWithIdsRequests,
+                grpc_m.vector.vector_pb2.SearchResponses,
+            ),
+            '/langchain_service.VectorService/SimilaritySearch': grpclib.const.Handler(
+                self.SimilaritySearch,
+                grpclib.const.Cardinality.UNARY_UNARY,
+                grpc_m.vector.vector_pb2.SearchRequest,
+                grpc_m.vector.vector_pb2.SearchResponses,
+            ),
         }
 
 
@@ -54,3 +74,15 @@ class VectorServiceStub:
             grpc_m.vector.vector_pb2.DocChatRequest,
             grpc_m.vector.vector_pb2.DocChatResponse,
         )
+        self.SearchWithIds = grpclib.client.UnaryUnaryMethod(
+            channel,
+            '/langchain_service.VectorService/SearchWithIds',
+            grpc_m.vector.vector_pb2.SearchWithIdsRequests,
+            grpc_m.vector.vector_pb2.SearchResponses,
+        )
+        self.SimilaritySearch = grpclib.client.UnaryUnaryMethod(
+            channel,
+            '/langchain_service.VectorService/SimilaritySearch',
+            grpc_m.vector.vector_pb2.SearchRequest,
+            grpc_m.vector.vector_pb2.SearchResponses,
+        )

+ 19 - 5
grpc_m/vector/vector_pb2.py

@@ -14,15 +14,17 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1agrpc_m/vector/vector.proto\x12\x11langchain_service\"Q\n\x16SaveDocToVectorRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\x1e\n\x16user_doc_relative_path\x18\x02 \x01(\t\"G\n\x17SaveDocToVectorResponse\x12,\n\x06status\x18\x01 \x01(\x0e\x32\x1c.langchain_service.ErrorCode\"9\n\x0e\x44ocChatRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\" \n\x0f\x44ocChatResponse\x12\r\n\x05reply\x18\x01 \x01(\t*H\n\tErrorCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x15\n\x11\x44OC_CONVERT_ERROR\x10\x01\x12\x17\n\x13VECTOR_SERVER_ERROR\x10\x02\x32\xcf\x01\n\rVectorService\x12j\n\x0fSaveDocToVector\x12).langchain_service.SaveDocToVectorRequest\x1a*.langchain_service.SaveDocToVectorResponse\"\x00\x12R\n\x07\x44ocChat\x12!.langchain_service.DocChatRequest\x1a\".langchain_service.DocChatResponse\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1agrpc_m/vector/vector.proto\x12\x11langchain_service\"Q\n\x16SaveDocToVectorRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\x1e\n\x16user_doc_relative_path\x18\x02 \x01(\t\"G\n\x17SaveDocToVectorResponse\x12,\n\x06status\x18\x01 \x01(\x0e\x32\x1c.langchain_service.ErrorCode\"9\n\x0e\x44ocChatRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\" \n\x0f\x44ocChatResponse\x12\r\n\x05reply\x18\x01 \x01(\t\"7\n\rSearchRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\r\n\x05query\x18\x02 \x01(\t\"$\n\x14SearchWithIdsRequest\x12\x0c\n\x04uuid\x18\x01 \x01(\t\"Q\n\x15SearchWithIdsRequests\x12\x38\n\x07request\x18\x01 \x03(\x0b\x32\'.langchain_service.SearchWithIdsRequest\"\xb0\x01\n\x0eSearchResponse\x12\r\n\x05\x63hunk\x18\x01 \x01(\t\x12\x41\n\x08metadata\x18\x02 \x03(\x0b\x32/.langchain_service.SearchResponse.MetadataEntry\x12\r\n\x05score\x18\x03 \x01(\x02\x12\x0c\n\x04uuid\x18\x04 \x01(\t\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"B\n\x0fSearchResponses\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.langchain_service.SearchResponse*H\n\tErrorCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x15\n\x11\x44OC_CONVERT_ERROR\x10\x01\x12\x17\n\x13VECTOR_SERVER_ERROR\x10\x02\x32\x8c\x03\n\rVectorService\x12j\n\x0fSaveDocToVector\x12).langchain_service.SaveDocToVectorRequest\x1a*.langchain_service.SaveDocToVectorResponse\"\x00\x12R\n\x07\x44ocChat\x12!.langchain_service.DocChatRequest\x1a\".langchain_service.DocChatResponse\"\x00\x12_\n\rSearchWithIds\x12(.langchain_service.SearchWithIdsRequests\x1a\".langchain_service.SearchResponses\"\x00\x12Z\n\x10SimilaritySearch\x12 .langchain_service.SearchRequest\x1a\".langchain_service.SearchResponses\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_m.vector.vector_pb2', _globals)
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
-  _globals['_ERRORCODE']._serialized_start=298
-  _globals['_ERRORCODE']._serialized_end=370
+  _globals['_SEARCHRESPONSE_METADATAENTRY']._options = None
+  _globals['_SEARCHRESPONSE_METADATAENTRY']._serialized_options = b'8\001'
+  _globals['_ERRORCODE']._serialized_start=723
+  _globals['_ERRORCODE']._serialized_end=795
   _globals['_SAVEDOCTOVECTORREQUEST']._serialized_start=49
   _globals['_SAVEDOCTOVECTORREQUEST']._serialized_end=130
   _globals['_SAVEDOCTOVECTORRESPONSE']._serialized_start=132
@@ -31,6 +33,18 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_DOCCHATREQUEST']._serialized_end=262
   _globals['_DOCCHATRESPONSE']._serialized_start=264
   _globals['_DOCCHATRESPONSE']._serialized_end=296
-  _globals['_VECTORSERVICE']._serialized_start=373
-  _globals['_VECTORSERVICE']._serialized_end=580
+  _globals['_SEARCHREQUEST']._serialized_start=298
+  _globals['_SEARCHREQUEST']._serialized_end=353
+  _globals['_SEARCHWITHIDSREQUEST']._serialized_start=355
+  _globals['_SEARCHWITHIDSREQUEST']._serialized_end=391
+  _globals['_SEARCHWITHIDSREQUESTS']._serialized_start=393
+  _globals['_SEARCHWITHIDSREQUESTS']._serialized_end=474
+  _globals['_SEARCHRESPONSE']._serialized_start=477
+  _globals['_SEARCHRESPONSE']._serialized_end=653
+  _globals['_SEARCHRESPONSE_METADATAENTRY']._serialized_start=606
+  _globals['_SEARCHRESPONSE_METADATAENTRY']._serialized_end=653
+  _globals['_SEARCHRESPONSES']._serialized_start=655
+  _globals['_SEARCHRESPONSES']._serialized_end=721
+  _globals['_VECTORSERVICE']._serialized_start=798
+  _globals['_VECTORSERVICE']._serialized_end=1194
 # @@protoc_insertion_point(module_scope)

+ 47 - 1
grpc_m/vector/vector_pb2.pyi

@@ -1,7 +1,8 @@
+from google.protobuf.internal import containers as _containers
 from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
-from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
+from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
 
 DESCRIPTOR: _descriptor.FileDescriptor
 
@@ -41,3 +42,48 @@ class DocChatResponse(_message.Message):
     REPLY_FIELD_NUMBER: _ClassVar[int]
     reply: str
     def __init__(self, reply: _Optional[str] = ...) -> None: ...
+
+class SearchRequest(_message.Message):
+    __slots__ = ("collection_name", "query")
+    COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
+    QUERY_FIELD_NUMBER: _ClassVar[int]
+    collection_name: str
+    query: str
+    def __init__(self, collection_name: _Optional[str] = ..., query: _Optional[str] = ...) -> None: ...
+
+class SearchWithIdsRequest(_message.Message):
+    __slots__ = ("uuid",)
+    UUID_FIELD_NUMBER: _ClassVar[int]
+    uuid: str
+    def __init__(self, uuid: _Optional[str] = ...) -> None: ...
+
+class SearchWithIdsRequests(_message.Message):
+    __slots__ = ("request",)
+    REQUEST_FIELD_NUMBER: _ClassVar[int]
+    request: _containers.RepeatedCompositeFieldContainer[SearchWithIdsRequest]
+    def __init__(self, request: _Optional[_Iterable[_Union[SearchWithIdsRequest, _Mapping]]] = ...) -> None: ...
+
+class SearchResponse(_message.Message):
+    __slots__ = ("chunk", "metadata", "score", "uuid")
+    class MetadataEntry(_message.Message):
+        __slots__ = ("key", "value")
+        KEY_FIELD_NUMBER: _ClassVar[int]
+        VALUE_FIELD_NUMBER: _ClassVar[int]
+        key: str
+        value: str
+        def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
+    CHUNK_FIELD_NUMBER: _ClassVar[int]
+    METADATA_FIELD_NUMBER: _ClassVar[int]
+    SCORE_FIELD_NUMBER: _ClassVar[int]
+    UUID_FIELD_NUMBER: _ClassVar[int]
+    chunk: str
+    metadata: _containers.ScalarMap[str, str]
+    score: float
+    uuid: str
+    def __init__(self, chunk: _Optional[str] = ..., metadata: _Optional[_Mapping[str, str]] = ..., score: _Optional[float] = ..., uuid: _Optional[str] = ...) -> None: ...
+
+class SearchResponses(_message.Message):
+    __slots__ = ("data",)
+    DATA_FIELD_NUMBER: _ClassVar[int]
+    data: _containers.RepeatedCompositeFieldContainer[SearchResponse]
+    def __init__(self, data: _Optional[_Iterable[_Union[SearchResponse, _Mapping]]] = ...) -> None: ...

+ 2 - 2
main.py

@@ -38,8 +38,7 @@ app.add_middleware(
     allow_credentials=True,  
     allow_methods=["*"],  
     allow_headers=["*"], 
-
-) 
+)
 app.include_router(login_router)  
 app.include_router(upload_router)
 app.include_router(webhook_route)
@@ -59,6 +58,7 @@ def main():
     logger.debug(f"http://sv-v2.magong.site:{PORT}  仅支持 ipv6 ,直连、满速、无延迟。缺点是不支持 https 协议,因为不经过 Caddy 代理,直达 Fastapi 没有配置 https")
     logger.debug(f"https://open-douyin.magong.site  内网穿透隧道,cloudflare tunnel ,经常访问不了")
     logger.info(f"http://localhost:{PORT} ⭐ 推荐。 vscode 会自动建立一条本地隧道,可以在本地浏览器直接打开")
+    logger.info(f"扫码登录 https://open.douyin.com/platform/oauth/connect/?client_key=aw6aipmfdtplwtyq&response_type=code&scope=user_info,renew_refresh_token,trial.whitelist,item.comment&redirect_uri=https://api.magong.site/swl/douyin/verify_callback")
     logger.info(f"https://swl-8l9.pages.dev/  访问前端网站")
     uvicorn.run(app, host=None, port=PORT, log_level="info")
 

+ 1 - 1
readme.md

@@ -16,7 +16,7 @@ https://open.douyin.com/platform/oauth/connect/?client_key=aw6aipmfdtplwtyq&resp
 ```shell
 # WEB扫码接入 参考 https://developer.open-douyin.com/docs/resource/zh-CN/dop/develop/sdk/web-app/web/permission
 # 打开链接,扫码登录 
-GET https://open.douyin.com/platform/oauth/connect/?client_key=aw6aipmfdtplwtyq&response_type=code&scope=user_info,renew_refresh_token,trial.whitelist,item.comment&redirect_uri=https://api.magong.site/verify_callback HTTP/1.1  
+GET https://open.douyin.com/platform/oauth/connect/?client_key=aw6aipmfdtplwtyq&response_type=code&scope=user_info,renew_refresh_token,trial.whitelist,item.comment,video.list&redirect_uri=https://api.magong.site/verify_callback HTTP/1.1  
 
 
 Host: open.douyin.com

+ 0 - 10
tloguru.log

@@ -1,10 +0,0 @@
-2024-02-26 04:14:28 | INFO     | __main__ tloguru/ tloguru.py:22 :<module> - logger info
-2024-02-26 04:14:28 | INFO     | __main__ tloguru/ tloguru.py:27 :get_logger - /home/user/code/open-douyin
-2024-02-26 04:14:28 | INFO     | __main__ tloguru/ tloguru.py:28 :get_logger - /home/user/code/open-douyin
-2024-02-26 04:14:28 | INFO     | __main__ tloguru/ tloguru.py:29 :get_logger - ..
-2024-02-26 04:14:28 | INFO     | __main__ tloguru/ tloguru.py:30 :get_logger - ../..
-2024-02-26 04:15:33 | INFO     | __main__ tloguru/ tloguru.py:22 :<module> - logger info
-2024-02-26 04:15:33 | INFO     | __main__ tloguru/ tloguru.py:27 :get_logger - /home/user/code/open-douyin
-2024-02-26 04:15:33 | INFO     | __main__ tloguru/ tloguru.py:28 :get_logger - /home/user/code/open-douyin
-2024-02-26 04:15:33 | INFO     | __main__ tloguru/ tloguru.py:29 :get_logger - .
-2024-02-26 04:15:33 | INFO     | __main__ tloguru/ tloguru.py:30 :get_logger - ../..