
Finish gRPC file-info sending so the server can convert documents to vectors

qyl committed 1 year ago
parent commit: 3884412051
17 changed files with 398 additions and 126 deletions
api/jwt.py (+1 −1)
api/login.py (+7 −2)
api/upload.py (+24 −17)
config.py (+5 −4)
db/base.py (+10 −7)
db/docs.py (+55 −30)
db/engine.py (+2 −2)
db/user.py (+12 −3)
douyin/access_token.py (+24 −27)
grpc_m/client.py (+44 −0)
grpc_m/send_data_to_vector.py (+44 −0)
grpc_m/vector_service.proto (+27 −0)
grpc_m/vector_service_pb2.py (+32 −0)
grpc_m/vector_service_pb2.pyi (+29 −0)
grpc_m/vector_service_pb2_grpc.py (+66 −0)
main.py (+16 −12)
test/config.py (+0 −21)

+ 1 - 1
api/jwt.py

@@ -58,4 +58,4 @@ def get_uer_oauth_and_verify(open_id: str = Depends(verify_jwt_token)):
     # No stored user credentials; QR-code login is required
     if not db_oauth:  
         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="need login")  
-    return db_oauth  
+    return db_oauth

+ 7 - 2
api/login.py

@@ -30,10 +30,13 @@ class User(BaseModel):
 # Login endpoint
 @login_router.post("/login")
 async def login(data: ScanCode):
+    logger.info(data)
     data = await get_access_token(data.code)
     if data.get("error_code") != 0:
-        return data, status.HTTP_400_BAD_REQUEST
-    
+        raise HTTPException(status_code=400, detail=data)
+    db_manager = UserOAuthRepository()
+    db_manager.add_or_update(data)
+
     # Compute the expiry timestamp (based on Beijing time)
     expires_in = data.get("expires_in", 1296000)
     # expires_in = 15
@@ -54,6 +57,8 @@ async def login(data: ScanCode):
 @login_router.get("/user_info")
 async def user_info(db_user_oauth: UserOAuthToken = Depends(get_uer_oauth_and_verify)):
     info = await get_user_info(db_user_oauth.open_id, db_user_oauth.access_token)
+    if info.get("error_code") != 0:
+        raise HTTPException(status_code=400, detail=info)  
     return info
 
 # Example of a protected resource
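
For context, a hypothetical client-side call to the reworked /login endpoint (the base URL is an assumption; the code value comes from the Douyin QR scan and is single-use):

    # Hypothetical caller of POST /login -- illustrative only, not part of this commit.
    import httpx

    resp = httpx.post("http://localhost:8600/login", json={"code": "<scan code from Douyin>"})
    resp.raise_for_status()  # a non-zero error_code now surfaces as HTTP 400
    print(resp.json())       # expected to carry the issued JWT and its expiry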

+ 24 - 17
api/updload.py → api/upload.py

@@ -1,3 +1,4 @@
+import asyncio
 import hashlib
 import os
 import pathlib
@@ -12,18 +13,22 @@ import aiofiles
 from starlette.status import HTTP_401_UNAUTHORIZED 
 from api.jwt import verify_jwt_token,get_current_user
 from config import *
-from db.docs import DocumentsRepository
+from db.docs import DocumentsRepository,DocumentCategories
+from grpc_m.send_data_to_vector import send_to_grpc_vetcor
 
 upload_router = APIRouter()  
 security = HTTPBearer()
 @upload_router.post('/upload')
 async def upload(open_id=Depends(verify_jwt_token),file: UploadFile = File(...)  ):  
-    
-    res = await save_to_user_dir(open_id, file)
-    if res:
-        # Save the file locally
-        DocumentsRepository()
-        return {"message": "upload success"}
+    path = await save_to_user_dir(open_id, file)
+    if path:
+        res = DocumentsRepository().add_document_with_categories(open_id, path)  
+        if res:
+            doc_model, category_model, doc_categ_model = res
+            send_to_grpc_vetcor(category_model.id, doc_model)
+            return {"message": "upload success"}
+        else:  
+            raise HTTPException(status_code=500, detail="Failed to add document to database")
     else:  
         # Saving the file failed; return 400 Bad Request (or another appropriate status code)
         raise HTTPException(status_code=400, detail="upload fail")
@@ -37,28 +42,30 @@ def is_valid_filename(s):
 def get_user_dir(open_id):  
     if is_valid_filename(open_id):  
         # If open_id is a valid filename, use it directly as the directory name
-        user_dir = os.path.join(DATA_DIR, open_id)  
+        user_dir = os.path.join(MNT_DOUYIN_DATA, open_id)  
     else:  
         # Otherwise, hash it and use the digest as the directory name
         hash_object = hashlib.md5(open_id.encode())  
-        hex_dig = hash_object.hexdigest()[:8]  # keep only the first 8 characters
-        user_dir = os.path.join(DATA_DIR, "hash8_" + hex_dig)  
-      
+        hex_dig = hash_object.hexdigest()
+        user_dir = os.path.join(MNT_DOUYIN_DATA, "hash8_" + hex_dig)  
     # Create the directory if it does not exist
     if not os.path.exists(user_dir):  
         os.makedirs(user_dir)  
-      
     return user_dir  
 
-    
-async def save_to_user_dir(open_id, file:UploadFile):
+def get_user_docs_dir(open_id):
     user_dir = get_user_dir(open_id)
-    file_path = os.path.join(user_dir,"docs", file.filename)
-      
+    user_docs_dir = os.path.join(user_dir,"docs")
+    if not os.path.exists(user_docs_dir):  
+        os.makedirs(user_docs_dir)
+    return user_docs_dir
+
+async def save_to_user_dir(open_id, file:UploadFile):
+    file_path = os.path.join(get_user_docs_dir(open_id), file.filename)
     async with aiofiles.open(file_path, "wb") as buffer:  
         chunk = await file.read(8192)  
         while chunk:  
             await buffer.write(chunk)  
             chunk = await file.read(8192)  
         logger.info(f"{open_id} save to {file_path}")
-        return True
+        return file_path
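
A hypothetical end-to-end call against the updated /upload endpoint (base URL and JWT are assumptions; the token would come from /login). On success the server saves the file, records it in the database, and fires the gRPC vectorization task in the background:

    # Hypothetical client for POST /upload -- a sketch, not part of this commit.
    import httpx

    token = "<jwt from /login>"  # placeholder
    with open("example_file.pdf", "rb") as f:
        resp = httpx.post(
            "http://localhost:8600/upload",
            headers={"Authorization": f"Bearer {token}"},
            files={"file": ("example_file.pdf", f)},
        )
    print(resp.status_code, resp.json())  # {"message": "upload success"} on success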

+ 5 - 4
config.py

@@ -13,8 +13,8 @@ FORMAT = '<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level>
 LOG_LEVEL = "DEBUG"
 logger.remove()
 # logger.add(sys.stderr, format='<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>')
-logger.add(sys.stderr, format=FORMAT)
-logger.add(LOG_FILE, format=FORMAT)
+logger.add(sys.stderr, format=FORMAT, level="INFO")
+logger.add(LOG_FILE, format=FORMAT, level="DEBUG")
 logger.info(f"load config:{ __file__}")
 
 # openssl rand -hex 32
@@ -30,14 +30,15 @@ DOUYIN_OPEN_API="https://open.douyin.com"
 HOST = '::'
 
 PORT = 8601 if os.environ.get("USER")=="mrh" else 8600
+GRPC_VECTOR="192.168.2.32:18600"
 SECRET_KEY = os.environ.get("SECRET_KEY")
 
 DB_URL=os.environ.get("DB_URL")
 # Never hard-code the data path in production; it would make later load balancing, capacity scaling, and migration much harder.
 if not PRODUCE_ENV:
-    DATA_DIR="/home/user/code/open-douyin/test/data"
+    MNT_DOUYIN_DATA="/home/user/code/open-douyin/test/data"
 else:
-    DATA_DIR = os.environ["DATA_DIR"]
+    MNT_DOUYIN_DATA = os.environ["MNT_DOUYIN_DATA"]
 logger.info(f"API URL:{HOST}:{PORT}")
 
 

+ 10 - 7
db/base.py

@@ -28,19 +28,22 @@ class BaseRepository:
         session = ex_session or Session(bind=self.engine)
         return session.get(self.model, id)
 
-    def update(self, id: int, obj_in: SQLModel, ex_session: Optional[Session] = None) -> bool:
-        def session_exec(session,obj_in: SQLModel):
+    def update(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
+        def session_exec(id, obj_in: SQLModel, session: Session):
             obj = session.get(self.model, id)
             if not obj:
-                return False
+                return None
-            for key, value in obj_in.model_dump().items():
+            # exclude_unset=True: only fields explicitly set on obj_in are
+            # copied, so a partial model does not null out the other columns
+            for key, value in obj_in.model_dump(exclude_unset=True).items():
                 setattr(obj, key, value)
-        
+            return obj
+        if not obj_in.id:
+            return None
         session = ex_session or Session(bind=self.engine)
-        session_exec(session,obj_in)
+        obj = session_exec(obj_in.id, obj_in, session)
+        if obj is None:
+            return None
         if not ex_session:
             session.commit()
+            session.refresh(obj)
+        return obj
 
     def delete(self, id: int, ex_session: Optional[Session] = None) -> bool:
         def session_exec(session: Optional[Session],obj_in: SQLModel):
@@ -102,7 +105,7 @@ class BaseRepository:
             session.commit()
         return obj
     
-    def check_exist(self, obj: SQLModel, check_field=None, ex_session=None):
+    def check_exist(self, obj: SQLModel, check_field:List[str]=None, ex_session=None):
         session = ex_session or Session(bind=self.engine)
         if not check_field:
             check_field = self.get_unique_constraint_fields()
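
A minimal usage sketch of the reworked update(); it mirrors the commented-out example in db/docs.py, and the UUID is a placeholder:

    # Hypothetical partial update through the new signature: the primary key
    # must be set on the instance, since update() now reads it from obj_in.
    import uuid
    from db.docs import Documents, DocumentsRepository

    doc = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
    updated = DocumentsRepository().update(doc)  # None if the id is unset or no row matches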

+ 55 - 30
db/docs.py

@@ -1,7 +1,7 @@
 import asyncio
 from datetime import datetime
 import re
-from typing import Optional
+from typing import Optional, Tuple
 from enum import Enum
 from typing import List, Any
 import os
@@ -108,7 +108,7 @@ class DocumentsRepository(DocumentBase):
     def __init__(self, engine=engine):  
         super().__init__(Documents, engine)  
 
-    def add_document_with_categories(self, open_id, file_path, category_name="default"):
+    def add_document_with_categories(self, open_id, file_path, category_name="default") -> Tuple[Documents, Categories, DocumentCategories]:
         with Session(bind=self.engine) as session:
             doc_model:Documents = self.exec_add_or_update_file(open_id, file_path, session)
             cr = CategoriesRepository()
@@ -116,11 +116,15 @@ class DocumentsRepository(DocumentBase):
             dcr = DocumentCategoriesRepository()
             doc_categ_model = dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
             session.commit()
-            return True
+            # Force a refresh so the models reflect the latest state from the database
+            session.refresh(doc_model)
+            session.refresh(category_model)
+            session.refresh(doc_categ_model)
+            return (doc_model, category_model, doc_categ_model)
     
     def exec_add_or_update_file(self, open_id, file_path, session):
-        # file_path = {DATA_DIR}/{open_id}/docs/xxx/example_file.pdf
-        relative_path = DocumentsRepository.get_relative_path(file_path)
+        # file_path = {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
+        relative_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
         if relative_path == None:
             return
         self.instance_model = Documents(
@@ -131,44 +135,65 @@ class DocumentsRepository(DocumentBase):
         res = self.add_or_update(self.instance_model, session)
         return res
 
-    def get_relative_path(full_path):
-        pattern = r'docs(/.*?)$'  
+    @staticmethod
+    def get_doc_path_from_full_path(full_path):
+        '''
+        Extract the document-relative path from an absolute path.
+        input:  {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
+        output: xxx/example_file.pdf
+        '''
+        pattern = r'docs/(.*?)$'
         match = re.search(pattern, full_path)  
         if match:
             return match.group(1)
         else:
             logger.error(f"Can not get rel path:{full_path}")
     
-    
-    def get_user_files_path(self, open_id: str, category_id: Optional[UUID4] = None, category_name: Optional[str] = None) -> List[str]:  
-        with Session(self.engine) as session:  
-            # Base query: select path from the Documents table
-            base_statement = select(Documents.path).where(Documents.open_id == open_id)  
-              
-            # Filter by category_id or category_name
-            if category_id:  
-                base_statement = base_statement.where(Documents.id.in_(  
-                    select(DocumentCategories.id).where(DocumentCategories.category_id == category_id)  
-                ))  
-            elif category_name:  
-                category_subquery = select(Categories.id).where(Categories.name == category_name)  
-                doc_category_subquery = select(DocumentCategories.id).where(DocumentCategories.category_id.in_(category_subquery))  
-                base_statement = base_statement.where(Documents.id.in_(doc_category_subquery))  
-              
-            # Execute the query and return the results
+    @staticmethod
+    def get_user_file_relpath_from_docmodel(doc_model: Documents):
+        '''
+        Build a user-relative file path from a doc model.
+        input:  Documents(path=example_file.pdf)
+        output: {open_id}/docs/example_file.pdf
+        '''
+        return os.path.join(str(doc_model.open_id), "docs", doc_model.path)
+    def get_user_files_path(self, open_id: str, category_id: Optional[UUID4] = None, category_name: Optional[str] = None) -> List[Tuple[str, UUID4]]:
+        '''
+        return: List[Tuple[file_path, document_id]]
+        '''
+        with Session(self.engine) as session:
+            # Base query: select path and id from the Documents table
+            base_statement = select(Documents.path, Documents.id).where(Documents.open_id == open_id)
+            
+            # If category_id is given, join through DocumentCategories
+            if category_id:
+                base_statement = base_statement.join(DocumentCategories, Documents.id == DocumentCategories.id).where(DocumentCategories.category_id == category_id)
+            # If category_name is given, resolve the category_id first, then join
+            elif category_name:
+                category_subquery = select(Categories.id).where(Categories.name == category_name)
+                doc_category_subquery = select(DocumentCategories.id).where(DocumentCategories.category_id.in_(category_subquery))
+                base_statement = (
+                    base_statement.join(DocumentCategories, Documents.id == DocumentCategories.id)
+                    .where(DocumentCategories.id.in_(doc_category_subquery))
+                )
+                
+            # Execute the query and return the results (each row is a tuple: (document path, document id))
             results = session.exec(base_statement)
-            return results.all()
-        
+            return [(result.path, result.id) for result in results]
+
 # Example usage
 def main():  
     from db.user import test_add
     open_id = test_add()
     # Create an instance
     documents_repo = DocumentsRepository()  
-    documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
-    documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
-    documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
-    logger.info(documents_repo.get_user_files_path(open_id))
+    # model = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
+    # model = documents_repo.update(model)
+    # res = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
+    # documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
+    doc_model,_,_ = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
+    rel_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
+    logger.info(rel_path)
+    # res = documents_repo.get_user_files_path(open_id)
+    # Hypothetical call into the server side; illustrative only, the real call logic has to be written separately
     # Add categories
     # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
     # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")

+ 2 - 2
db/engine.py

@@ -1,6 +1,6 @@
-import asyncio
 from sqlmodel import Field, SQLModel, create_engine
 from config import DB_URL,logger
-import logging
 
 engine = create_engine(DB_URL)  # replace with your DB_URL
+def create_all_table():
+    SQLModel.metadata.create_all(engine)

+ 12 - 3
db/user.py

@@ -59,9 +59,16 @@ class DouyinBaseRepository(BaseRepository):
                 session.commit()
                 return exist_obj
             else:
-                session.commit()
+                self.create(obj_model)
+                logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
                 return obj_model
 
+    def get_by_open_id(self, open_id):
+        with Session(self.engine) as session:
+            logger.debug(f"get {open_id}")
+            base_statement = select(self.model).where(self.model.open_id == open_id)  
+            results = session.exec(base_statement)  
+            return results.first()
 
 
 class UserInfoRepository(DouyinBaseRepository):  
@@ -125,7 +132,7 @@ class UserOAuthRepository(DouyinBaseRepository):
 def test_add(open_id=None):  
     SQLModel.metadata.create_all(engine)
     
-    user_oauth = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
+    user_oauth = {'access_token': 'act.3.m3kiZmfxxIH95i1bHZ7Bq3Wkv_Xm5TtD3kpLGjtCr3G96WIINBKEvzlsaObrGcH4GaxTQeLZA13jkzoZhpAwPwMRqFxlVuIcxpge_-BpdFib1xHqkcFa4B-LX4zpd2YK3kDFTFfMcJXN_fZ2eByg6oqqa1OieUWcvlaVgw==', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '202402171353461C33F969CEFB511B216F', 'open_id': '_000LiV_o0FGKMwaZgqGMCvcVSf-UAnNU_kk', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.e4b3da8bd3ef01d880d827b11e826391OEGHiRrTLcp5zsGYP1dh6F9Bo7fg', 'scope': 'user_info,trial.whitelist'}
     user_info = {
     "avatar": "https://p26.douyinpic.com/aweme/100x100/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
     "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
@@ -149,12 +156,14 @@ def test_add(open_id=None):
         user_oauth["open_id"] = open_id
         user_info["open_id"] = open_id
         user_info["nickname"] = "user" + open_id[:5]
+    else:
+        open_id = user_oauth["open_id"]
     db_manager = UserOAuthRepository()
     res = db_manager.add_or_update(user_oauth)
     # logger.debug(res)
     db_user_info = UserInfoRepository()
     res = db_user_info.add_or_update(user_info)
-    # logger.debug(res)
+    logger.debug(db_manager.get_by_open_id(open_id))
     return user_oauth["open_id"]
 
 if __name__ == "__main__":

+ 24 - 27
douyin/access_token.py

@@ -18,30 +18,9 @@ async def check_access_token_response(response_json: dict):
         raise Exception(f"Failed to get access token, error code: {model_data.error_code}")
     return model_data.data
 
-async def get_access_token(code):
-    if not PRODUCE_ENV:
-        # For test environments: each get_access_token code is single-use and expires immediately, so return a mocked response to avoid repeated QR scans
-        return {'access_token': 'act.3.UCzqnMwbL7uUTH0PkWbvDvIHcpy417HnfMqymbvBSpo9b1MJ3jOdwCxw-UPstOOjsGDWIdNwTGev4oEp8eUR-vHbU24XU5K4BkhPeOKJW1CLrEUS3XFxpG6SHqoQtvL6qhEgINcvt4V3KQX6C2qTeTkgQ-KwPO6jWi5uoin3YXo5DqwuGk3bbQ9dZoY=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '2024012915260549B5ED1A675515CD573C', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
-    
-    client_key = os.environ.get("CLIENT_KEY")  # read client_key from the environment
-    client_secret = os.environ.get("CLIENT_SECRET")  # read client_secret from the environment
-    async with httpx.AsyncClient() as client:  
-        response = await client.post(  
-            "https://open.douyin.com/oauth/access_token/",  
-            headers={"Content-Type": "application/json"},  
-            json={  
-                "grant_type": "authorization_code",  
-                "client_key": client_key,  
-                "client_secret": client_secret,  
-                "code": code,  
-            },  
-        )
-        res_json = response.json()
-        if res_json("data").get("error_code") != 0:
-            db_manager = UserOAuthRepository()
-            db_manager.add_token(res_json("data"))
 
-    ''' response success:
+''' 
+response success:
     {
     "data": {
         "access_token": "act.f7094fbffab2ecbfc45e9af9c32bc241oYdckvBKe82BPx8T******",
@@ -70,10 +49,28 @@ async def get_access_token(code):
             "now": 1594015876138
         }
     }
-    '''
-    logger.debug(res_json) 
-    return res_json.get("data")
-
+'''
+async def get_access_token(code):
+    # if not PRODUCE_ENV:
+    #     # For test environments: each get_access_token code is single-use and expires immediately, so return a mocked response to avoid repeated QR scans
+    #     return {'access_token': 'act.3.UCzqnMwbL7uUTH0PkWbvDvIHcpy417HnfMqymbvBSpo9b1MJ3jOdwCxw-UPstOOjsGDWIdNwTGev4oEp8eUR-vHbU24XU5K4BkhPeOKJW1CLrEUS3XFxpG6SHqoQtvL6qhEgINcvt4V3KQX6C2qTeTkgQ-KwPO6jWi5uoin3YXo5DqwuGk3bbQ9dZoY=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '2024012915260549B5ED1A675515CD573C', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
+    
+    client_key = os.environ.get("CLIENT_KEY")  # read client_key from the environment
+    client_secret = os.environ.get("CLIENT_SECRET")  # read client_secret from the environment
+    async with httpx.AsyncClient() as client:  
+        response = await client.post(  
+            "https://open.douyin.com/oauth/access_token/",  
+            headers={"Content-Type": "application/json"},  
+            json={  
+                "grant_type": "authorization_code",  
+                "client_key": client_key,  
+                "client_secret": client_secret,  
+                "code": code,  
+            },  
+        )
+        res_json = response.json()
+        logger.debug(res_json.get("data"))
+        return res_json.get("data")
 # Unit test
 def main():
     # Visit https://swl-8l9.pages.dev/ and click "Try Now"

+ 44 - 0
grpc_m/client.py

@@ -0,0 +1,44 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+import grpc  
+from grpc_m import vector_service_pb2, vector_service_pb2_grpc  
+from db.docs import DocumentsRepository  
+from db.user import test_add  # assumed import for obtaining an open_id
+from config import logger,GRPC_VECTOR
+
+# NOTE: LoadUserDocRequest and Document are not generated from vector_service.proto
+# in this commit; the calls below appear to target an older service definition.
+async def load_user_docs_async(open_id: str, file_path: list[str]):
+    async with grpc.aio.insecure_channel(GRPC_VECTOR) as channel:
+        stub = vector_service_pb2_grpc.VectorServiceStub(channel)  
+  
+        request = vector_service_pb2.LoadUserDocRequest(  
+            open_id=open_id,  
+            data=[vector_service_pb2.Document(path=file_path, id=str(open_id))]  
+        )  
+        response = await stub.LoadUserDoc(request)  
+  
+        # Log the response
+        logger.info("Response: success = {}".format(response.success))  
+        return response  
+
+def run():  
+    # Create the gRPC channel and stub
+    with grpc.insecure_channel('localhost:18600') as channel:  
+        stub = vector_service_pb2_grpc.VectorServiceStub(channel)  
+          
+        # Get an open_id (example only; real implementations will differ)
+        open_id = "_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy"
+          
+        # Create a DocumentsRepository and fetch the user's file path list
+        documents_repo = DocumentsRepository()  # assuming the imports and instantiation here are correct
+        res = documents_repo.get_user_files_path(open_id)  
+        logger.info(f"get docs {res}")
+        # Build and send the gRPC request
+        request = vector_service_pb2.LoadUserDocRequest(open_id=open_id, data=[vector_service_pb2.Document(path=path, id=str(uuid)) for path, uuid in res])  
+        response = stub.LoadUserDoc(request)  
+          
+        # Log the response
+        logger.info("Response: success = {}".format(response.success))  
+  
+if __name__ == '__main__':  
+    run()

+ 44 - 0
grpc_m/send_data_to_vector.py

@@ -0,0 +1,44 @@
+import asyncio
+import grpc
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+from config import logger,GRPC_VECTOR
+from grpc_m import vector_service_pb2, vector_service_pb2_grpc  
+from db.docs import DocumentsRepository,DocumentCategories,Documents,DocStatus
+
+def send_to_grpc_vetcor(category_id,doc_model:Documents):
+    loop = asyncio.get_running_loop()
+    loop.create_task(load_user_docs_async(category_id,doc_model))
+    
+async def load_user_docs_async(category_id: str, doc_model:Documents):  
+    user_doc_relative_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
+    try:
+        async with grpc.aio.insecure_channel(GRPC_VECTOR) as channel:  
+            stub = vector_service_pb2_grpc.VectorServiceStub(channel)  
+            logger.info(f"send to vector: category_id={category_id} user_doc_relative_path={user_doc_relative_path}")
+            request = vector_service_pb2.SaveDocToVectorRequest(  
+                user_doc_relative_path=user_doc_relative_path,
+                category_id=str(category_id)  # convert the UUID to a string to match the proto definition
+            )
+            response:vector_service_pb2.SaveDocToVectorResponse = await stub.SaveDocToVector(request)  
+    
+            # Handle the response
+            if response.status == vector_service_pb2.ErrorCode.SUCCESS:
+                doc_model.status = DocStatus.COMPLETED
+                DocumentsRepository().update(doc_model)
+                logger.info(f"Document conver to vector sucess. category_id:{category_id}  doc_id:{doc_model.id} path:{doc_model.path}")
+            else:
+                logger.error(f"Response: status = {response.status}, {doc_model}")
+            return response.status
+    except Exception:
+        # response may be unbound if the channel or call itself failed
+        logger.exception(f"vector server {GRPC_VECTOR} error")
+
+async def main():
+    dm = Documents(open_id="user123", path="思维链-文档说明.md")
+    send_to_grpc_vetcor("123", dm)
+    # keep the event loop alive so the fire-and-forget task can complete
+    while True:
+        await asyncio.sleep(1)
+if __name__ == "__main__":
+    asyncio.run(main())
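
One caveat with the fire-and-forget send_to_grpc_vetcor above: asyncio keeps only a weak reference to tasks created with create_task, so an unreferenced task can be garbage-collected before it finishes. A defensive variant, sketched under the same imports as this file:

    # Sketch, not part of this commit: hold a strong reference to each in-flight
    # task until it completes, as recommended by the asyncio documentation.
    _background_tasks: set = set()

    def send_to_grpc_vetcor_safe(category_id, doc_model: Documents):
        task = asyncio.get_running_loop().create_task(
            load_user_docs_async(category_id, doc_model))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)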

+ 27 - 0
grpc_m/vector_service.proto

@@ -0,0 +1,27 @@
+// python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpc_python_out=. grpc_m/vector_service.proto
+// * `-I.`: directory searched for imported `.proto` files.
+// * `--python_out=.`: output directory for the generated protobuf message code.
+// * `--grpc_python_out=.`: output directory for the generated gRPC service code (stub and servicer classes). Both outputs are needed to get the messages as well as the service code.
+// * `grpc_m/vector_service.proto`: path of the `.proto` file to compile.
+syntax = "proto3";  
+  
+package grpc_m;  
+// Error codes returned by the service
+enum ErrorCode {
+    SUCCESS = 0;
+    DOC_CONVERT_ERROR = 1;
+    VECTOR_SERVER_ERROR = 2;
+}
+
+service VectorService {  
+    rpc SaveDocToVector (SaveDocToVectorRequest) returns (SaveDocToVectorResponse) {}  
+}  
+  
+message SaveDocToVectorRequest {  
+    string category_id = 1;  
+    string user_doc_relative_path = 2; 
+}  
+
+message SaveDocToVectorResponse {  
+    ErrorCode status = 1;  
+}
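
The server side of this contract is not part of this repo; a minimal grpc.aio sketch of a conforming VectorService, where only the message, enum, and service names come from the proto above and the rest (including the vectorization step) is assumed:

    # Hypothetical server skeleton for the VectorService defined above.
    import asyncio
    import grpc
    from grpc_m import vector_service_pb2, vector_service_pb2_grpc

    class VectorServicer(vector_service_pb2_grpc.VectorServiceServicer):
        async def SaveDocToVector(self, request, context):
            # request.category_id and request.user_doc_relative_path arrive as strings
            # ... load the document and write its embeddings to the vector store here ...
            return vector_service_pb2.SaveDocToVectorResponse(
                status=vector_service_pb2.ErrorCode.SUCCESS)

    async def serve():
        server = grpc.aio.server()
        vector_service_pb2_grpc.add_VectorServiceServicer_to_server(VectorServicer(), server)
        server.add_insecure_port("[::]:18600")  # matches the port the clients use
        await server.start()
        await server.wait_for_termination()

    if __name__ == "__main__":
        asyncio.run(serve())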

+ 32 - 0
grpc_m/vector_service_pb2.py

@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: grpc_m/vector_service.proto
+# Protobuf Python Version: 4.25.0
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bgrpc_m/vector_service.proto\x12\x06grpc_m\"M\n\x16SaveDocToVectorRequest\x12\x13\n\x0b\x63\x61tegory_id\x18\x01 \x01(\t\x12\x1e\n\x16user_doc_relative_path\x18\x02 \x01(\t\"<\n\x17SaveDocToVectorResponse\x12!\n\x06status\x18\x01 \x01(\x0e\x32\x11.grpc_m.ErrorCode*H\n\tErrorCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x15\n\x11\x44OC_CONVERT_ERROR\x10\x01\x12\x17\n\x13VECTOR_SERVER_ERROR\x10\x02\x32\x65\n\rVectorService\x12T\n\x0fSaveDocToVector\x12\x1e.grpc_m.SaveDocToVectorRequest\x1a\x1f.grpc_m.SaveDocToVectorResponse\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_m.vector_service_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+  DESCRIPTOR._options = None
+  _globals['_ERRORCODE']._serialized_start=180
+  _globals['_ERRORCODE']._serialized_end=252
+  _globals['_SAVEDOCTOVECTORREQUEST']._serialized_start=39
+  _globals['_SAVEDOCTOVECTORREQUEST']._serialized_end=116
+  _globals['_SAVEDOCTOVECTORRESPONSE']._serialized_start=118
+  _globals['_SAVEDOCTOVECTORRESPONSE']._serialized_end=178
+  _globals['_VECTORSERVICE']._serialized_start=254
+  _globals['_VECTORSERVICE']._serialized_end=355
+# @@protoc_insertion_point(module_scope)

+ 29 - 0
grpc_m/vector_service_pb2.pyi

@@ -0,0 +1,29 @@
+from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class ErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
+    __slots__ = ()
+    SUCCESS: _ClassVar[ErrorCode]
+    DOC_CONVERT_ERROR: _ClassVar[ErrorCode]
+    VECTOR_SERVER_ERROR: _ClassVar[ErrorCode]
+SUCCESS: ErrorCode
+DOC_CONVERT_ERROR: ErrorCode
+VECTOR_SERVER_ERROR: ErrorCode
+
+class SaveDocToVectorRequest(_message.Message):
+    __slots__ = ("category_id", "user_doc_relative_path")
+    CATEGORY_ID_FIELD_NUMBER: _ClassVar[int]
+    USER_DOC_RELATIVE_PATH_FIELD_NUMBER: _ClassVar[int]
+    category_id: str
+    user_doc_relative_path: str
+    def __init__(self, category_id: _Optional[str] = ..., user_doc_relative_path: _Optional[str] = ...) -> None: ...
+
+class SaveDocToVectorResponse(_message.Message):
+    __slots__ = ("status",)
+    STATUS_FIELD_NUMBER: _ClassVar[int]
+    status: ErrorCode
+    def __init__(self, status: _Optional[_Union[ErrorCode, str]] = ...) -> None: ...

+ 66 - 0
grpc_m/vector_service_pb2_grpc.py

@@ -0,0 +1,66 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+from grpc_m import vector_service_pb2 as grpc__m_dot_vector__service__pb2
+
+
+class VectorServiceStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.SaveDocToVector = channel.unary_unary(
+                '/grpc_m.VectorService/SaveDocToVector',
+                request_serializer=grpc__m_dot_vector__service__pb2.SaveDocToVectorRequest.SerializeToString,
+                response_deserializer=grpc__m_dot_vector__service__pb2.SaveDocToVectorResponse.FromString,
+                )
+
+
+class VectorServiceServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def SaveDocToVector(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_VectorServiceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'SaveDocToVector': grpc.unary_unary_rpc_method_handler(
+                    servicer.SaveDocToVector,
+                    request_deserializer=grpc__m_dot_vector__service__pb2.SaveDocToVectorRequest.FromString,
+                    response_serializer=grpc__m_dot_vector__service__pb2.SaveDocToVectorResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'grpc_m.VectorService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class VectorService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def SaveDocToVector(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/grpc_m.VectorService/SaveDocToVector',
+            grpc__m_dot_vector__service__pb2.SaveDocToVectorRequest.SerializeToString,
+            grpc__m_dot_vector__service__pb2.SaveDocToVectorResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

+ 16 - 12
main.py

@@ -1,3 +1,4 @@
+import asyncio
 import socket
 import time
 from fastapi import FastAPI, Request  
@@ -10,32 +11,35 @@ from fastapi.responses import FileResponse,JSONResponse
 from fastapi import FastAPI, Depends, HTTPException, Form  
 import httpx
 import os
-# from db.user import UserOAuthToken
 from config import *
+from db.engine import create_all_table
 from fastapi.middleware.cors import CORSMiddleware 
 from api.login import login_router
-from api.updload import upload_router
+from api.upload import upload_router
 from contextlib import asynccontextmanager
+from sqlmodel import SQLModel
 
-app = FastAPI()  
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Code run before the application starts serving
+    create_all_table()
+    loop = asyncio.get_event_loop()
+    yield
+    # Code run after the application shuts down
+    pass
+
+app = FastAPI(lifespan=lifespan)  
 app.add_middleware(  
     CORSMiddleware,  
     allow_origins=["*"],  
     allow_credentials=True,  
     allow_methods=["*"],  
     allow_headers=["*"], 
-     
+
 ) 
 app.include_router(login_router)  
 app.include_router(upload_router)
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    # 在应用启动前运行的代码
-    pass
-    yield
-    # 在应用关闭后运行的代码
-    pass
+    
 
 def get_session(request: Request):  
     return request.session  

+ 0 - 21
test/config.py

@@ -1,21 +0,0 @@
-
-import os
-import sys
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from sqlalchemy import UniqueConstraint
-from loguru import logger
-from sqlmodel import Field, SQLModel
-
-# Database model definition. Using __tablename__ = '"UserInfo"' for mixed-case table names is discouraged:
-# it can break compatibility with other database systems, and queries then have to mind table-name casing.
-class UserOAuthToken(SQLModel):  
-    access_token:str = Field(default=None, primary_key=True)
-    refresh_expires_in:str = Field(default=None, )
-    refresh_token:str = Field(default=1, )
-    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_name'),)
-    
-user = UserOAuthToken()
-HOST = '::'
-con:UniqueConstraint = [col.name for col in UserOAuthToken.__table__.constraints['uq_open_id_name'].columns] 
-logger.info(con)
-# print(HOST)