| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import asyncio
- from datetime import datetime
- import re
- from typing import Optional, Tuple
- from enum import Enum
- from typing import List, Any
- import os
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(__file__)))
- from sqlmodel import Field, SQLModel,Session,Relationship, Integer, Sequence, UniqueConstraint,select
- from config import DB_URL,logger
- # from db.common import engine
- import uuid
- from db.base import BaseRepository
- from db.engine import engine
- from db.model import *
- from typing import TYPE_CHECKING, Optional
class DocumentBase(BaseRepository):
    """Repository base that adds an insert-or-update operation with hooks.

    Subclasses bind a concrete model and may override ``before_update``,
    ``before_create`` and ``get_model_dump_field_for_update`` to customise
    what happens around ``add_or_update``.
    """

    def __init__(self, model: Documents, engine=engine):
        # Default to the shared module-level engine, like the concrete
        # repositories in this file do, instead of forwarding a placeholder.
        super().__init__(model, engine)

    def before_update(self, obj_model: 'Documents', exist_obj: 'Documents') -> None:
        """Hook called before an existing object is updated.

        Subclasses can override it to implement custom logic.
        """
        pass

    def before_create(self, obj_model: 'Documents') -> None:
        """Hook called before a new object is created.

        Subclasses can override it to implement custom logic.
        """
        pass

    def get_model_dump_field_for_update(self):
        """Return the fields copied onto an existing row during an update."""
        return self.non_unique_fields

    def add_or_update(self, obj_model: Documents, ex_session: Optional[Session] = None) -> SQLModel:
        """Update the row matching ``obj_model``'s unique fields, or create it.

        Args:
            obj_model: Object carrying the values to persist.
            ex_session: Session owned by the caller. When given, nothing is
                committed or closed here and the affected object is returned.
                When omitted, a private session is opened, committed and
                closed.

        Returns:
            ``True`` when this method owned (and committed) the session;
            otherwise the updated existing object or the created
            ``obj_model``.
        """
        owns_session = ex_session is None
        session = Session(bind=self.engine) if owns_session else ex_session
        try:
            exist_obj: Documents = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
            logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
            if exist_obj:
                update_fields = self.get_model_dump_field_for_update()
                dict_data = self.model_dump_by_field(obj_model, fields=update_fields)
                if dict_data:
                    logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
                else:
                    logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {update_fields}")
                self.before_update(obj_model, exist_obj)
                self.set_obj_by_dict(exist_obj, dict_data)
                if owns_session:
                    session.commit()
                    # After the commit the values held by exist_obj are
                    # released, so the object is not handed back.
                    return True
                return exist_obj

            self.before_create(obj_model)
            # NOTE(review): the caller's session (possibly None) is forwarded
            # here rather than the local one; kept as-is because
            # BaseRepository.create is not visible from this file.
            self.create(obj_model, ex_session)
            if owns_session:
                session.commit()
                return True
            logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
            return obj_model
        finally:
            # A session opened here must also be released here; a session
            # supplied by the caller stays under the caller's control.
            if owns_session:
                session.close()
class CategoriesRepository(DocumentBase):
    """Repository bound to the ``Categories`` model."""

    def __init__(self, engine=engine):
        super().__init__(Categories, engine)

    def get_or_create(self, user, category_name, ex_session=None):
        """Return the user's category named ``category_name``.

        The category is looked up through the user/category link table. When
        no row matches, a new ``Categories`` object is built for ``user``; it
        is returned without being added to the session or committed.
        """
        session = ex_session if ex_session else Session(bind=self.engine)
        statement = (
            select(Categories)
            .join(LinkUserCatgegory)
            .filter(LinkUserCatgegory.user_id == user.id)
            .filter(Categories.name == category_name)
        )
        found = session.exec(statement).first()
        if found:
            return found
        return Categories(name=category_name, user=user)
-
-
class DocumentsRepository(DocumentBase):
    """Repository bound to the ``Documents`` model and its category links."""

    def __init__(self, engine=engine):
        super().__init__(Documents, engine)

    def get_by_cate_path(self, session, category_id, doc_path) -> Documents:
        """Return the first document with ``doc_path`` linked to ``category_id``.

        Returns ``None`` when no such document exists.
        """
        doc_model = session.exec(
            select(Documents)
            .join(LinkDocumentCategories)
            .filter(LinkDocumentCategories.category_id == category_id)
            .filter(Documents.path == doc_path)
        ).first()
        return doc_model

    def get_if_exist(self, user, doc_path, category_name, ex_session=None):
        """Return ``(category, document)`` for the user's category and path.

        The category comes from ``CategoriesRepository.get_or_create``; the
        document is ``None`` when the category has no document at
        ``doc_path``.
        """
        session = ex_session or Session(bind=self.engine)
        category = CategoriesRepository().get_or_create(user, category_name=category_name, ex_session=session)
        doc_model = None
        if category:
            doc_model = self.get_by_cate_path(session, category.id, doc_path)
        return category, doc_model

    def add_category_doc(self, user, file_path, category_name, ex_session):
        """Return the category, with a document for ``file_path`` attached.

        A new document is generated and appended only when the category does
        not already contain one at ``file_path``. Nothing is added to the
        session or committed here.
        """
        category, doc_model = self.get_if_exist(user, file_path, category_name, ex_session)
        if not category:
            # Defensive branch: get_or_create in this file always hands back
            # a category object, so this is not expected to run.
            doc_model: Documents = self.gen_model(user.open_id, file_path)
            category = Categories(name=category_name, docs=[doc_model])
            category.user = user
        elif not doc_model:
            doc_model: Documents = self.gen_model(user.open_id, file_path)
            category.docs.append(doc_model)
        return category

    def get_or_create(self, user: User, file_path, category_name, summarize=None, ex_session=None):
        """Return the document at ``file_path`` in ``category_name``, or build one.

        A newly built document is linked to the category but is not added to
        the session or committed. ``summarize`` is accepted for interface
        compatibility and is not used.
        """
        session = ex_session or Session(bind=self.engine)
        # NOTE(review): this lookup joins the user link table but does not
        # filter on ``user`` - confirm whether matching another user's
        # document is intended.
        doc_model = session.exec(
            select(Documents)
            .join(LinkUserDocument)
            .join(LinkDocumentCategories)
            .join(Categories).where(Categories.name == category_name)
            .where(Documents.path == file_path)
        ).first()
        if not doc_model:
            # Reuse the session resolved above so the category lookup does
            # not open a second, separate session.
            catgory = CategoriesRepository().get_or_create(user, category_name, ex_session=session)
            doc_model = Documents(path=file_path, status=0, user=catgory.user)
            doc_model.categories.append(catgory)
            logger.info(f"create document {doc_model}")
        return doc_model

    def add_document_with_categories(self, user, file_path, category_name="default"):
        """Persist a document for ``file_path`` under the user's category.

        ``file_path`` may be a full path or a relative path of the form
        ``{MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf``; the caller
        decides which. A relative path indicates a mount point.

        Returns:
            ``(user, document)`` where ``document`` is the category's
            document whose path equals the normalised ``file_path``.
        """
        file_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
        with Session(bind=self.engine) as session:
            category = self.add_category_doc(user, file_path, category_name, session)
            session.add(category)
            session.commit()
            session.refresh(category)
            docs = category.docs
            # Pick the document by path: the last element is not necessarily
            # the requested one when it already existed before this call.
            doc_model = next((doc for doc in docs if doc.path == file_path), docs[-1])
            return category.user, doc_model

    def set_status(self, doc: Documents, status: DocStatus):
        """Store ``status`` on ``doc``, commit it and return the refreshed row."""
        doc.status = status
        with Session(bind=self.engine) as session:
            session.add(doc)
            session.commit()
            session.refresh(doc)
        return doc

    def gen_model(self, open_id, file_path):
        """Build an unsaved, unprocessed ``Documents`` object for ``file_path``."""
        # file_path = {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
        doc_model = Documents(
            open_id=open_id,
            path=file_path,
            status=DocStatus.UNPROCESSED,
        )
        return doc_model

    @staticmethod
    def get_doc_path_from_full_path(full_path):
        """Extract the relative path from an absolute path.

        input:  {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
        output: xxx/example_file.pdf

        Everything after the first ``docs/`` is returned; a path without
        ``docs/`` is returned unchanged.
        """
        pattern = r'docs/(.*?)$'
        match = re.search(pattern, full_path)
        if match:
            return match.group(1)
        return full_path
# Example usage
def main():
    """Manual smoke test: add one document and log its relative path."""
    from db.user_oauth import test_add
    # NOTE(review): add_document_with_categories reads ``.id`` and
    # ``.open_id`` from its first argument - confirm that test_add() returns
    # a user object rather than a bare open_id string.
    open_id = test_add()
    # Create the repository instance
    documents_repo = DocumentsRepository()
    # model = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
    # model = documents_repo.update(model)
    # res = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
    # documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
    # add_document_with_categories returns a (user, document) pair.
    _, doc_model = documents_repo.add_document_with_categories(open_id, "/home/user/code/open-douyin/open_id/docs/readme5.md")
    # NOTE(review): get_user_file_relpath_from_docmodel is not defined in the
    # classes visible in this file - presumably inherited; verify it exists.
    rel_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
    logger.info(rel_path)
    # res = documents_repo.get_user_files_path(open_id)
    # Hypothetical server-side calls. This is only a sketch; the real calling
    # logic still has to be written.
    # Add categories
    # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
    # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")
    # Implement the related code


if __name__ == "__main__":
    main()
|