| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- import asyncio
- from datetime import datetime
- import re
- from typing import Optional
- from enum import Enum
- from typing import List, Any
- import os
- import sys
- sys.path.append(os.path.dirname(os.path.dirname(__file__)))
- from sqlmodel import Field, SQLModel,Session, Integer, Sequence, UniqueConstraint,select
- from config import DB_URL,logger
- # from db.common import engine
- from pydantic import UUID4
- import uuid
- from db.base import BaseRepository
- from db.engine import engine
-
class Categories(SQLModel, table=True):
    """A named document category belonging to one user.

    The pair (open_id, name) is unique, enforced by the ``uq_open_id_ctname`` constraint.
    """
    # NOTE(review): annotated as UUID4 but the default factory is uuid.uuid1 (a time/MAC-based
    # UUID v1); presumably not rejected because table models skip validation — confirm intent.
    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True) # primary key, generated as a UUID v1
    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True) # foreign key to the user table (useroauthtoken.open_id)
    name: str = Field(default="default", index=True) # category name, indexed to speed up lookups
    update_time: datetime = Field(default_factory=datetime.now) # creation / update time (naive local datetime from datetime.now)
    # Composite unique constraint: a user cannot own two categories with the same name.
    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
-
class DocumentCategories(SQLModel, table=True):
    """Link table associating a document (by its id) with a category."""
    # NOTE(review): `id` alone is the primary key, so a document can be linked to at most one
    # category; linking an already-linked document to a second category would violate the
    # primary key. For the same reason the (id, category_id) unique constraint below is
    # redundant. Confirm whether a many-to-many relation (composite primary key) was intended.
    id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True) # foreign key to documents.id, also this table's primary key
    category_id: UUID4 = Field(foreign_key="categories.id",index=True) # foreign key to categories.id
    __table_args__ = (UniqueConstraint('id', 'category_id', ),)
-
class DocStatus:
    """Integer status codes written to ``Documents.status``."""

    DISABLED = -1     # document has been disabled
    UNPROCESSED = 0   # document has not been processed yet
    COMPLETED = 100   # document processing has finished
-
class Documents(SQLModel, table=True):
    """A file registered for a user; (open_id, path) is unique (constraint ``uq_documents``)."""
    # NOTE(review): annotated as UUID4 but generated with uuid.uuid1; primary_key=True presumably
    # already implies an index, so index=True may be redundant — confirm.
    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True,index=True) # primary key, generated as a UUID v1
    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True) # foreign key to the user table (useroauthtoken.open_id)
    path: str = Field(nullable=False, index=True) # relative path (as produced by DocumentsRepository.get_relative_path)
    status: int = Field(nullable=False) # document status; the values used in this module come from DocStatus
    update_time: datetime = Field(default_factory=datetime.now) # creation / update time (naive local datetime from datetime.now)
    # A user cannot register the same relative path twice.
    __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),)
class DocumentBase(BaseRepository):
    """Repository base that adds an "insert or update by unique key" operation.

    Subclasses bind a concrete SQLModel table class and may override the
    ``before_update`` / ``before_create`` hooks or
    ``get_model_dump_field_for_update`` to customise the upsert.
    """

    def __init__(self, model: 'type[SQLModel]', engine=engine):
        # Default to the module-level engine; the previous default was the
        # ``...`` placeholder, which is not a usable engine.
        super().__init__(model, engine)

    def before_update(self, obj_model: 'Documents', exist_obj: 'Documents') -> None:
        """Hook called before an existing row is updated; subclasses may override it.

        :param obj_model: the incoming model carrying the new values.
        :param exist_obj: the persisted row that is about to receive them.
        """
        pass

    def before_create(self, obj_model: 'Documents') -> None:
        """Hook called before a new row is created; subclasses may override it.

        :param obj_model: the model that is about to be inserted.
        """
        pass

    def get_model_dump_field_for_update(self):
        """Return the field names copied onto an existing row during an update.

        Defaults to ``self.non_unique_fields`` (provided by ``BaseRepository``).
        """
        return self.non_unique_fields

    def add_or_update(self, obj_model: SQLModel, ex_session: Optional[Session] = None) -> 'SQLModel | bool':
        """Insert ``obj_model`` or update the row that shares its unique-constraint fields.

        :param obj_model: the model instance to upsert.
        :param ex_session: optional caller-owned session. When given, the change
            is only staged in it and the caller must commit; otherwise a private
            session is opened, committed and closed here.
        :returns: with ``ex_session``, the persisted object (the existing row after
            an update, or ``obj_model`` after a create); without it, ``True`` once
            the change has been committed.
        """
        owns_session = ex_session is None
        session = Session(bind=self.engine) if owns_session else ex_session
        try:
            exist_obj: Documents = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
            logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
            if exist_obj:
                update_fields = self.get_model_dump_field_for_update()
                dict_data = self.model_dump_by_field(obj_model, fields=update_fields)
                if dict_data:
                    logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
                else:
                    logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {update_fields}")
                self.before_update(obj_model, exist_obj)
                self.set_obj_by_dict(exist_obj, dict_data)
                result = exist_obj
            else:
                self.before_create(obj_model)
                # Use the same session as check_exist so lookup and insert share one unit of work.
                self.create(obj_model, session)
                # Log before any commit: afterwards the instance is expired.
                logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
                result = obj_model
            if owns_session:
                session.commit()
                # After commit the objects are expired and the session is closed below,
                # so report success instead of returning a detached, stale instance.
                return True
            return result
        finally:
            # Only close sessions created here; a caller-supplied session stays open.
            if owns_session:
                session.close()
class CategoriesRepository(DocumentBase):
    """Repository for ``Categories`` rows."""

    def __init__(self, engine=engine):
        super().__init__(Categories, engine)

    def get_model_dump_field_for_update(self):
        """Return the fields updated on an existing category, excluding ``update_time``.

        The categories table does not refresh its timestamp when an existing
        row is upserted again.
        """
        # Build a new list: removing from self.non_unique_fields in place would
        # permanently mutate the repository's shared field list.
        return [field for field in self.non_unique_fields if field != "update_time"]
class DocumentCategoriesRepository(DocumentBase):
    """Repository for ``DocumentCategories`` link rows (document id -> category id)."""

    def __init__(self, engine=engine):
        super().__init__(model=DocumentCategories, engine=engine)
class DocumentsRepository(DocumentBase):
    """Repository for ``Documents`` rows and their category links."""

    def __init__(self, engine=engine):
        super().__init__(Documents, engine)

    def add_document_with_categories(self, open_id, file_path, category_name="default"):
        """Upsert a document, its category and the link between them in one transaction.

        :param open_id: owner of the document and of the category.
        :param file_path: full path of the file; it must contain a ``docs`` directory.
        :param category_name: name of the category to file the document under.
        :returns: ``True`` after committing; ``False`` when no relative path can be
            derived from ``file_path`` (nothing is written in that case).
        """
        with Session(bind=self.engine) as session:
            doc_model: Documents = self.exec_add_or_update_file(open_id, file_path, session)
            if doc_model is None:
                # get_relative_path already logged why; without a document there is nothing to link.
                return False
            # Build sibling repositories on the same engine as this one.
            cr = CategoriesRepository(engine=self.engine)
            category_model: Categories = cr.add_or_update(Categories(open_id=open_id, name=category_name), session)
            dcr = DocumentCategoriesRepository(engine=self.engine)
            dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
            session.commit()
        return True

    def exec_add_or_update_file(self, open_id, file_path, session):
        """Upsert the ``Documents`` row for ``file_path`` inside ``session`` (not committed here).

        Expected layout: ``{DATA_DIR}/{open_id}/docs/xxx/example_file.pdf``; only the
        part after the ``docs`` directory is stored as ``Documents.path``.

        :returns: the result of ``add_or_update`` with the given session, or ``None``
            when ``file_path`` has no ``docs`` directory.
        """
        relative_path = self.get_relative_path(file_path)
        if relative_path is None:
            return None
        # Kept on the instance for backward compatibility with existing readers of instance_model.
        self.instance_model = Documents(
            open_id=open_id,
            path=relative_path,
            status=DocStatus.UNPROCESSED,
        )
        return self.add_or_update(self.instance_model, session)

    @staticmethod
    def get_relative_path(full_path):
        """Return the part of ``full_path`` after its first ``docs`` directory.

        Example: ``/data/<open_id>/docs/a/b.pdf`` -> ``/a/b.pdf``. ``docs`` must be a
        whole path component, so a directory such as ``mydocs`` is not mistaken for it.

        :returns: the relative path (starting with ``/``), or ``None`` after logging an
            error when no ``docs`` directory is present.
        """
        match = re.search(r'(?:^|/)docs(/.*)$', full_path)
        if match is None:
            logger.error(f"Can not get rel path:{full_path}")
            return None
        return match.group(1)

    def get_user_files_path(self, open_id: str, category_id: Optional[UUID4] = None, category_name: Optional[str] = None) -> List[str]:
        """Return the stored relative paths of ``open_id``'s documents.

        :param open_id: owner whose documents are listed.
        :param category_id: if given, only documents linked to this category.
        :param category_name: used only when ``category_id`` is not given; only
            documents linked to this user's category of that name.
        """
        with Session(self.engine) as session:
            statement = select(Documents.path).where(Documents.open_id == open_id)

            if category_id:
                linked_doc_ids = select(DocumentCategories.id).where(DocumentCategories.category_id == category_id)
                statement = statement.where(Documents.id.in_(linked_doc_ids))
            elif category_name:
                # Category names are only unique per user, so restrict to this user's categories.
                user_category_ids = select(Categories.id).where(
                    Categories.open_id == open_id,
                    Categories.name == category_name,
                )
                linked_doc_ids = select(DocumentCategories.id).where(DocumentCategories.category_id.in_(user_category_ids))
                statement = statement.where(Documents.id.in_(linked_doc_ids))

            return session.exec(statement).all()
-
# Example usage
def main():
    """Demo: register three sample documents for the open_id from db.user.test_add and log their paths."""
    from db.user import test_add

    user_open_id = test_add()
    repo = DocumentsRepository()
    sample_paths = (
        "/home/user/code/open-douyin/open_id/docs/readme.md",
        "/home/user/code/open-douyin/open_id/docs/99readme3.md",
        "/home/user/code/open-douyin/open_id/docs/readme5.md",
    )
    for sample_path in sample_paths:
        repo.add_document_with_categories(user_open_id, sample_path)
    logger.info(repo.get_user_files_path(user_open_id))


if __name__ == "__main__":
    main()
|