| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- import asyncio
- from datetime import datetime
- import re
- from typing import Optional, Tuple
- from enum import Enum
- from typing import List, Any
- import os
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(__file__)))
- from sqlmodel import Field, SQLModel,Session,Relationship, Integer, Sequence, UniqueConstraint,select
- from config import DB_URL,logger
- # from db.common import engine
- from pydantic import UUID4
- import uuid
- from db.base import BaseRepository
- from db.engine import engine
- from typing import TYPE_CHECKING, Optional
- if TYPE_CHECKING:
- from .user import User
- from db.video_data import VideoDocLink,VideoData
class DocumentsLink(SQLModel, table=True):
    """Association (link) table between ``user`` and ``documents`` rows.

    Both columns together form the composite primary key.
    """

    user_id: Optional[int] = Field(primary_key=True, foreign_key="user.id", default=None)
    doc_id: Optional[int] = Field(primary_key=True, foreign_key="documents.id", default=None)
-
class DocumentCategoriesLink(SQLModel, table=True):
    """Association (link) table between ``documents`` and ``categories`` rows.

    Both columns together form the composite primary key.
    """

    # Optional[int] matches the ``default=None`` and the sibling DocumentsLink;
    # SQLModel still creates primary-key columns as NOT NULL.
    documents_id: Optional[int] = Field(default=None, foreign_key="documents.id", primary_key=True)
    category_id: Optional[int] = Field(default=None, foreign_key="categories.id", primary_key=True)
class Categories(SQLModel, table=True):
    """A named category that documents can be tagged with (many-to-many via DocumentCategoriesLink)."""

    # Integer primary key; left as None so the database can assign it on insert.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Category name, indexed to speed up lookups by name.
    name: str = Field(default="default", index=True)
    # Filled in when the object is constructed (not refreshed automatically on update).
    update_time: datetime = Field(default_factory=datetime.now)
    # Many-to-many through a link table, so this side is a collection of documents.
    docs: List["Documents"] = Relationship(back_populates="categories", link_model=DocumentCategoriesLink)
-
class DocStatus:
    """Integer status codes written to ``Documents.status``."""

    DISABLED = -1     # document has been disabled
    UNPROCESSED = 0   # document has not been processed yet
    COMPLETED = 100   # processing has finished
-
class Documents(SQLModel, table=True):
    """A document tracked by its path relative to a user's ``docs/`` directory."""

    # Integer primary key; left as None so the database can assign it on insert.
    id: Optional[int] = Field(default=None, primary_key=True, index=True)
    # The former direct ``open_id`` column is disabled; the owning user is reached
    # through the ``user`` relationship (DocumentsLink) instead.
    path: str = Field(nullable=False)  # path relative to the user's docs/ directory
    status: int = Field(nullable=False)  # processing state, one of the DocStatus codes
    update_time: datetime = Field(default_factory=datetime.now)  # set when the object is constructed
    # The relationships below go through link tables (many-to-many), so each one is a
    # collection at runtime even where the attribute name is singular.
    # NOTE(review): per SQLAlchemy cascade semantics, "all, delete-orphan" here means
    # deleting a Documents row also deletes the linked User / VideoData rows, and
    # single_parent=True lets each User be attached to only one Documents through this
    # attribute. Confirm this is intended (the cascade may belong on User.docs instead).
    user: List["User"] = Relationship(back_populates="docs", link_model=DocumentsLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent": True})
    categories: List[Categories] = Relationship(back_populates="docs", link_model=DocumentCategoriesLink)
    video_data: List[VideoData] = Relationship(back_populates="doc", link_model=VideoDocLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent": True})
-
class DocumentBase(BaseRepository):
    """Repository base that adds an insert-or-update operation (``add_or_update``).

    Subclasses pass their SQLModel table class to ``__init__`` and may override
    ``before_update`` / ``before_create`` or ``get_model_dump_field_for_update``
    to customise how an existing row is updated.
    """

    def __init__(self, model: "type[SQLModel]", engine=engine):
        # Default to the shared module engine, like the concrete repositories do;
        # the previous default (``...``) was not a usable engine.
        super().__init__(model, engine)

    def before_update(self, obj_model: "Documents", exist_obj: "Documents") -> None:
        """Hook called right before ``exist_obj`` receives ``obj_model``'s fields; no-op by default."""

    def before_create(self, obj_model: "Documents") -> None:
        """Hook called right before ``obj_model`` is created; no-op by default."""

    def get_model_dump_field_for_update(self):
        """Return the names of the fields copied onto an existing row during an update."""
        return self.non_unique_fields

    def add_or_update(self, obj_model: SQLModel, ex_session: Optional[Session] = None) -> "SQLModel | bool":
        """Create ``obj_model``, or copy its fields onto the existing row that matches it.

        The existing row is looked up with ``check_exist`` on
        ``self.unique_constraint_fields``. When found, the fields returned by
        ``get_model_dump_field_for_update()`` are copied onto it.

        Args:
            obj_model: model instance carrying the desired values.
            ex_session: caller-owned session. When given, nothing is committed
                here and the caller is responsible for committing.

        Returns:
            ``True`` when no ``ex_session`` was given (this method opened,
            committed and closed its own session). Otherwise, the updated
            existing row or the newly created ``obj_model``.
        """
        owns_session = ex_session is None
        session = Session(bind=self.engine) if owns_session else ex_session
        table = self.model.__tablename__
        try:
            exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
            logger.debug(f"check table '{table}' where {self.unique_constraint_fields}")
            if exist_obj:
                update_fields = self.get_model_dump_field_for_update()
                dict_data = self.model_dump_by_field(obj_model, fields=update_fields)
                if dict_data:
                    logger.debug(f"modify table '{table}' id '{exist_obj.id}' from {dict_data}")
                else:
                    logger.debug(f"table '{table}' do nothing. exist '{exist_obj}'. update field {update_fields}")
                self.before_update(obj_model, exist_obj)
                self.set_obj_by_dict(exist_obj, dict_data)
                if owns_session:
                    session.commit()
                    # commit expires exist_obj and our session is closed below, so
                    # report success instead of handing back a stale object.
                    return True
                return exist_obj

            self.before_create(obj_model)
            # ex_session (None when we own the session) is forwarded exactly as before,
            # leaving session handling for creation to create() itself.
            self.create(obj_model, ex_session)
            logger.debug(f"on table '{table}' create {obj_model}")
            if owns_session:
                session.commit()
                return True
            return obj_model
        finally:
            # Only close sessions opened here; a caller-owned session stays open.
            # close() also rolls back anything left uncommitted after an exception.
            if owns_session:
                session.close()
class CategoriesRepository(DocumentBase):
    """Repository for ``Categories`` rows."""

    def __init__(self, engine=engine):
        super().__init__(Categories, engine)

    def get_model_dump_field_for_update(self):
        """Return the non-unique fields to copy on update, excluding ``update_time``.

        Categories keep their original ``update_time`` when saved again. A new
        list is built so ``self.non_unique_fields`` itself is never mutated
        (the previous in-place ``remove`` altered it for every later use).
        """
        return [field for field in self.non_unique_fields if field != "update_time"]
class DocumentsRepository(DocumentBase):
    """Repository for ``Documents`` rows and their user / category links."""

    def __init__(self, engine=engine):
        super().__init__(Documents, engine)

    def add_document_with_categories(self, user, file_path, category_name="default"):
        """Upsert the document at ``file_path``, tag it with a category and link it to ``user``.

        Args:
            user: ORM user object; it must expose ``open_id`` and a ``docs`` collection.
            file_path: absolute path containing a ``docs/`` segment.
            category_name: name of the category to attach; created if no row has that name.

        Returns:
            ``user`` after the commit and a refresh.

        Raises:
            ValueError: if no ``docs/``-relative path can be derived from ``file_path``.
        """
        with Session(bind=self.engine) as session:
            # Attach the user first so its attributes and lazy ``docs`` collection
            # can be loaded through this session.
            session.add(user)
            doc_model = self.exec_add_or_update_file(user.open_id, file_path, session)
            if doc_model is None:
                raise ValueError(f"cannot derive a docs/ relative path from {file_path!r}")

            # The relationship holds Categories rows, not names: reuse or create one.
            category = session.exec(
                select(Categories).where(Categories.name == category_name)
            ).first()
            if category is None:
                category = Categories(name=category_name)

            # Link tables use composite primary keys, so never add an existing link twice.
            if not any(existing is category for existing in doc_model.categories):
                doc_model.categories.append(category)
            if not any(existing is doc_model for existing in user.docs):
                user.docs.append(doc_model)

            session.commit()
            session.refresh(user)
            return user

    def exec_add_or_update_file(self, open_id, file_path, session):
        """Upsert a ``Documents`` row for ``file_path`` inside ``session``.

        ``file_path`` looks like ``{MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf``;
        only the part after ``docs/`` is stored in ``Documents.path``.

        Returns:
            The value of ``add_or_update`` (the existing or created ``Documents``
            instance, since a session is passed), or ``None`` when no relative
            path can be extracted.
        """
        relative_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
        if relative_path is None:
            return None
        # Documents has no open_id column (it is disabled on the model), so open_id is
        # not stored on the row; SQLModel silently ignored it when it was passed here.
        self.instance_model = Documents(
            path=relative_path,
            status=DocStatus.UNPROCESSED,
        )
        return self.add_or_update(self.instance_model, session)

    @staticmethod
    def get_doc_path_from_full_path(full_path):
        """Extract the path relative to the ``docs/`` directory.

        input:  {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
        output: xxx/example_file.pdf

        Returns ``None`` (and logs an error) when the path has no ``docs/`` segment.
        """
        # Anchor on a whole path segment so e.g. ".../mydocs/..." is not taken for "docs/".
        match = re.search(r'(?:^|/)docs/(.*)$', full_path)
        if match is None:
            logger.error(f"Can not get rel path:{full_path}")
            return None
        return match.group(1)

    @staticmethod
    def get_user_file_relpath_from_docmodel(doc_model: Documents, open_id: Optional[str] = None):
        """Build the user-relative file path for a document.

        input:  Documents(path=example_file.pdf)
        output: {open_id}/docs/example_file.pdf

        Args:
            doc_model: the document row.
            open_id: owner's open_id. When omitted, ``doc_model.open_id`` is used as before.
        """
        if open_id is None:
            # NOTE(review): Documents currently has no open_id field, so this fallback
            # raises AttributeError unless the attribute is set elsewhere; prefer
            # passing open_id explicitly.
            open_id = doc_model.open_id
        return os.path.join(str(open_id), "docs", doc_model.path)

    def get_user_files_path(self, open_id: str, category_id: Optional[int] = None, category_name: Optional[str] = None) -> List[Tuple[str, int]]:
        """List the documents linked to the user with ``open_id``.

        Optionally restricted to one category, by ``category_id`` or (if no id is
        given) by ``category_name``.

        Returns:
            A list of ``(document relative path, document id)`` tuples.
        """
        # Resolve the user model from the relationship instead of importing it here.
        # Assumes the user table exposes an ``open_id`` column (add_document_with_categories
        # reads ``user.open_id``) — TODO confirm.
        user_model = Documents.user.property.mapper.class_
        statement = (
            select(Documents.path, Documents.id)
            .join(Documents.user)  # goes through DocumentsLink
            .where(user_model.open_id == open_id)
        )
        if category_id is not None:
            statement = statement.join(Documents.categories).where(Categories.id == category_id)
        elif category_name:
            statement = statement.join(Documents.categories).where(Categories.name == category_name)

        with Session(self.engine) as session:
            return [(row.path, row.id) for row in session.exec(statement)]
# Example usage
def main():
    """Manual demo run of DocumentsRepository using the default engine.

    NOTE(review): this looks out of date with the repository API in this file.
    ``add_document_with_categories`` takes a user object (it reads ``user.open_id``
    and ``user.docs``) and returns that user. Here it is given ``open_id`` and its
    result is unpacked into three values. Also, ``get_user_file_relpath_from_docmodel``
    reads ``doc_model.open_id`` by default, while ``Documents`` has no ``open_id``
    field. Confirm before relying on this entry point.
    """
    # presumably creates a test user and returns its open_id — TODO confirm
    from db.user_oauth import test_add
    open_id = test_add()
    # Create the repository instance
    documents_repo = DocumentsRepository()
    # model = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
    # model = documents_repo.update(model)
    # res = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
    # documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
    doc_model,_,_ = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
    rel_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
    logger.info(rel_path)
    # res = documents_repo.get_user_files_path(open_id)
    # Hypothetical caller (server-side) code. This is only an illustrative sketch; the real calling logic still has to be written.
    # Add categories
    # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
    # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")
    # Implement the related code


# Run the demo only when this file is executed directly.
if __name__ == "__main__":
    main()
|