upload.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import asyncio
  2. import hashlib
  3. import os
  4. import pathlib
  5. import re
  6. import sys
  7. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  8. import jwt
  9. from fastapi import FastAPI,APIRouter, File, HTTPException, Depends, Request,Header, UploadFile
  10. from fastapi.responses import JSONResponse
  11. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  12. import aiofiles
  13. from starlette.status import HTTP_401_UNAUTHORIZED
  14. from api.swl_jwt import verify_jwt_token,get_current_user
  15. from config import *
  16. from db.docs import DocumentsRepository
  17. from grpc_m.send_data_to_vector import send_to_grpc_vetcor
# Router on which this module registers its upload endpoint(s).
upload_router = APIRouter()
# HTTP bearer-token security scheme instance.
# NOTE(review): not referenced by any code visible in this file — confirm
# whether it is still needed.
security = HTTPBearer()
  20. @upload_router.post('/upload')
  21. async def upload(open_id=Depends(verify_jwt_token),file: UploadFile = File(...) ):
  22. path = await save_to_user_dir(open_id, file)
  23. if path:
  24. res = DocumentsRepository().add_document_with_categories(open_id, path)
  25. if res:
  26. doc_model, category_model, doc_categ_model = res
  27. send_to_grpc_vetcor(category_model.id, doc_model)
  28. return {"message": "upload success"}
  29. else:
  30. raise HTTPException(status_code=500, detail="Failed to add document to database")
  31. else:
  32. # 保存文件失败,返回 400 Bad Request 或其他适当的错误状态码
  33. raise HTTPException(status_code=400, detail="upload fail")
  34. def is_valid_filename(s):
  35. """检查字符串是否只包含合法的文件名字符"""
  36. # 这里的正则表达式只允许字母、数字、下划线、点、破折号和空格
  37. # 你可以根据需要调整这个正则表达式
  38. return re.match(r"^[a-zA-Z0-9_.\- ]+$", s) is not None
  39. def get_user_dir(open_id):
  40. if is_valid_filename(open_id):
  41. # 如果open_id合法,直接使用它作为目录名
  42. user_dir = os.path.join(MNT_DOUYIN_DATA, open_id)
  43. else:
  44. # 否则,计算它的哈希值并用作目录名
  45. hash_object = hashlib.md5(open_id.encode())
  46. hex_dig = hash_object.hexdigest()
  47. user_dir = os.path.join(MNT_DOUYIN_DATA, "hash8_" + hex_dig)
  48. # 如果目录不存在,创建它
  49. if not os.path.exists(user_dir):
  50. os.makedirs(user_dir)
  51. return user_dir
  52. def get_user_docs_dir(open_id):
  53. user_dir = get_user_dir(open_id)
  54. user_docs_dir = os.path.join(user_dir,"docs")
  55. if not os.path.exists(user_docs_dir):
  56. os.makedirs(user_docs_dir)
  57. return user_docs_dir
  58. async def save_to_user_dir(open_id, file:UploadFile):
  59. file_path = os.path.join(get_user_docs_dir(open_id), file.filename)
  60. async with aiofiles.open(file_path, "wb") as buffer:
  61. chunk = await file.read(8192)
  62. while chunk:
  63. await buffer.write(chunk)
  64. chunk = await file.read(8192)
  65. logger.info(f"{open_id} save to {file_path}")
  66. return file_path