| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- import asyncio
- from datetime import datetime
- import re
- from typing import Optional, Tuple
- from enum import Enum
- from typing import List, Any
- import os
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(__file__)))
- from sqlmodel import Field, SQLModel,Session, Integer, Sequence, UniqueConstraint,select
- from config import DB_URL,logger
- # from db.common import engine
- from pydantic import UUID4
- import uuid
- from db.base import BaseRepository
- from db.engine import engine
-
class Categories(SQLModel, table=True):
    """A named document category owned by one user; (open_id, name) is unique."""
    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True)  # UUID v1 primary key
    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # foreign key to the user table (useroauthtoken.open_id)
    name: str = Field(default="default", index=True)  # category name; indexed to speed up lookups
    update_time: datetime = Field(default_factory=datetime.now)  # creation / last-update time (naive local time from datetime.now)
    # Composite unique constraint: a user cannot have two categories with the same name
    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
-
class DocumentCategories(SQLModel, table=True):
    """Link table assigning a document (by its id) to a category."""
    # Foreign key to documents.id. Because it is also the primary key, each
    # document can have at most one row here, i.e. at most one category.
    id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True)
    category_id: UUID4 = Field(foreign_key="categories.id",index=True)  # foreign key to categories.id
    # NOTE(review): with `id` as the primary key this (id, category_id) unique
    # constraint is already implied; confirm whether many-to-many was intended.
    __table_args__ = (UniqueConstraint('id', 'category_id', ),)
-
class DocStatus:
    """Integer status codes stored in ``Documents.status``."""

    DISABLED = -1     # document has been disabled
    UNPROCESSED = 0   # registered but not processed yet
    COMPLETED = 100   # processing finished
-
class Documents(SQLModel, table=True):
    """A user-owned document file; (open_id, path) is unique."""
    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True,index=True)  # UUID v1 primary key
    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # foreign key to the user table (useroauthtoken.open_id)
    path: str = Field(nullable=False, index=True)  # path relative to the user's "docs" directory
    status: int = Field(nullable=False)  # document status, one of the DocStatus codes
    update_time: datetime = Field(default_factory=datetime.now)  # creation / last-update time (naive local time from datetime.now)
    # Composite unique constraint: a user cannot register the same relative path twice
    __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),)
class DocumentBase(BaseRepository):
    """Shared upsert logic for the document-related repositories.

    Rows are matched on ``self.unique_constraint_fields``. On a match, only the
    fields returned by ``get_model_dump_field_for_update`` are copied onto the
    existing row; otherwise the new object is inserted.
    """

    def __init__(self, model: "type[SQLModel]", engine=engine):
        # Default to the module-level engine. The previous default ``...``
        # (Ellipsis) would have been forwarded to BaseRepository as the engine.
        super().__init__(model, engine)

    def before_update(self, obj_model: SQLModel, exist_obj: SQLModel) -> None:
        """Hook called before ``exist_obj`` is updated from ``obj_model``; override in subclasses."""
        pass

    def before_create(self, obj_model: SQLModel) -> None:
        """Hook called before ``obj_model`` is inserted; override in subclasses."""
        pass

    def get_model_dump_field_for_update(self):
        """Return the field names copied onto an existing row during an update.

        A copy is returned so callers and overriding subclasses can modify the
        result without mutating ``self.non_unique_fields``.
        """
        fields = self.non_unique_fields
        return list(fields) if fields is not None else None

    def add_or_update(self, obj_model: SQLModel, ex_session: Optional[Session] = None) -> "SQLModel | bool":
        """Insert ``obj_model``, or update the row matching its unique-constraint fields.

        Args:
            obj_model: The model instance carrying the desired values.
            ex_session: Optional caller-owned session. When given, nothing is
                committed here; the caller is responsible for committing.

        Returns:
            With ``ex_session``: the persisted object (the existing row on
            update, ``obj_model`` on insert). Without it: ``True`` after this
            method commits and closes its own session. Instances are expired
            on commit, so they are not returned in that case.
        """
        owns_session = ex_session is None
        session = ex_session if ex_session is not None else Session(bind=self.engine)
        try:
            exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
            logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
            if exist_obj:
                update_fields = self.get_model_dump_field_for_update()
                dict_data = self.model_dump_by_field(obj_model, fields=update_fields)
                if dict_data:
                    logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
                else:
                    logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {update_fields}")
                self.before_update(obj_model, exist_obj)
                self.set_obj_by_dict(exist_obj, dict_data)
                result = exist_obj
            else:
                self.before_create(obj_model)
                # Use the same session as check_exist so lookup and insert
                # happen in one unit of work.
                self.create(obj_model, session)
                logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
                result = obj_model

            if owns_session:
                session.commit()
                # After commit the instance is expired and the session is closed
                # below, so only report success.
                return True
            return result
        finally:
            # Release the session this method opened, including on errors.
            if owns_session:
                session.close()
class CategoriesRepository(DocumentBase):
    """Repository for ``Categories`` rows."""

    def __init__(self, engine=engine):
        super().__init__(Categories, engine)

    def get_model_dump_field_for_update(self):
        """Return the non-unique fields to copy on update, excluding ``update_time``.

        Re-adding an existing category must not touch its timestamp. A fresh
        list is built, rather than removing from ``self.non_unique_fields`` in
        place, so the repository's own field list is never mutated.
        """
        return [field for field in self.non_unique_fields if field != "update_time"]
class DocumentCategoriesRepository(DocumentBase):
    """Repository for the ``DocumentCategories`` document-to-category link table."""

    def __init__(self, engine=engine):
        # Bind the generic upsert machinery to the link-table model.
        super().__init__(model=DocumentCategories, engine=engine)
class DocumentsRepository(DocumentBase):
    """Repository for ``Documents`` rows, plus helpers linking them to categories."""

    def __init__(self, engine=engine):
        super().__init__(Documents, engine)

    def add_document_with_categories(
        self, open_id, file_path, category_name="default"
    ) -> Tuple[Documents, Categories, DocumentCategories]:
        """Register ``file_path`` for ``open_id`` and file it under ``category_name``.

        The document, the category and the link row are upserted in one
        session and committed together.

        Args:
            open_id: The owner's open_id.
            file_path: Absolute path shaped like ``.../{open_id}/docs/<rel path>``.
            category_name: Category to file the document under.

        Returns:
            ``(doc_model, category_model, doc_categ_model)``, refreshed from the
            database after the commit.

        Raises:
            ValueError: if no relative path can be extracted from ``file_path``.
        """
        with Session(bind=self.engine) as session:
            doc_model: Documents = self.exec_add_or_update_file(open_id, file_path, session)
            if doc_model is None:
                # Fail clearly instead of an AttributeError on ``doc_model.id``.
                raise ValueError(f"Can not get rel path:{file_path}")
            category_repo = CategoriesRepository(self.engine)
            category_model: Categories = category_repo.add_or_update(
                Categories(open_id=open_id, name=category_name), session
            )
            link_repo = DocumentCategoriesRepository(self.engine)
            doc_categ_model = link_repo.add_or_update(
                DocumentCategories(id=doc_model.id, category_id=category_model.id), session
            )
            session.commit()
            # Commit expires the instances; reload their current DB state so
            # the returned objects stay usable after the session closes.
            session.refresh(doc_model)
            session.refresh(category_model)
            session.refresh(doc_categ_model)
            return (doc_model, category_model, doc_categ_model)

    def exec_add_or_update_file(self, open_id, file_path, session):
        """Upsert the ``Documents`` row for ``file_path`` inside ``session``.

        ``file_path`` is expected to look like
        ``{MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf``.
        The new model is also kept on ``self.instance_model``.

        Returns:
            The result of ``add_or_update`` (the persisted ``Documents`` when a
            session is given), or ``None`` when no relative path could be
            extracted (an error is logged).

        NOTE(review): for an existing row, the fields returned by
        ``get_model_dump_field_for_update`` are copied over; if that includes
        ``status``, re-adding a file resets it to UNPROCESSED. Confirm intent.
        """
        relative_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
        if relative_path is None:
            return None
        self.instance_model = Documents(
            open_id=open_id,
            path=relative_path,
            status=DocStatus.UNPROCESSED,
        )
        return self.add_or_update(self.instance_model, session)

    @staticmethod
    def get_doc_path_from_full_path(full_path):
        """Extract the path relative to the ``docs`` directory.

        input:  {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
        output: xxx/example_file.pdf

        Only a whole ``docs`` path segment counts, so ``.../abcdocs/docs/x``
        yields ``x``. The first such segment wins, and something must follow
        it. Returns None (and logs an error) otherwise.
        """
        match = re.search(r'(?:^|/)docs/(.+)$', full_path)
        if match:
            return match.group(1)
        logger.error(f"Can not get rel path:{full_path}")
        return None

    @staticmethod
    def get_user_file_relpath_from_docmodel(doc_model: Documents):
        """Build the user-relative file path from a ``Documents`` model.

        input:  Documents(open_id=..., path=example_file.pdf)
        output: {open_id}/docs/example_file.pdf
        """
        return os.path.join(str(doc_model.open_id), "docs", doc_model.path)

    def get_user_files_path(
        self,
        open_id: str,
        category_id: Optional[UUID4] = None,
        category_name: Optional[str] = None,
    ) -> List[Tuple[str, UUID4]]:
        """List ``(path, document_id)`` pairs for the documents of ``open_id``.

        Args:
            open_id: The owner whose documents are listed.
            category_id: If given, only documents linked to this category.
            category_name: If given (and ``category_id`` is not), only
                documents linked to this user's category of that name.

        Returns:
            A list of ``(Documents.path, Documents.id)`` tuples. The second
            item is the document id, not a category id.
        """
        with Session(self.engine) as session:
            statement = select(Documents.path, Documents.id).where(Documents.open_id == open_id)

            if category_id:
                statement = (
                    statement.join(DocumentCategories, Documents.id == DocumentCategories.id)
                    .where(DocumentCategories.category_id == category_id)
                )
            elif category_name:
                # Categories are per user, so match the name within this
                # user's categories only.
                statement = (
                    statement.join(DocumentCategories, Documents.id == DocumentCategories.id)
                    .join(Categories, DocumentCategories.category_id == Categories.id)
                    .where(Categories.open_id == open_id, Categories.name == category_name)
                )

            rows = session.exec(statement)
            return [(row.path, row.id) for row in rows]
# Example usage
def main():
    """Manual smoke test: register a sample document and log its user-relative path."""
    from db.user import test_add

    user_open_id = test_add()
    repo = DocumentsRepository()
    sample_file = "/home/user/code/open-douyin/open_id/docs/readme5.md"
    document, _category, _link = repo.add_document_with_categories(user_open_id, sample_file)
    logger.info(DocumentsRepository.get_user_file_relpath_from_docmodel(document))
-
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|