docs.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. import asyncio
  2. from datetime import datetime
  3. import re
  4. from typing import Optional, Tuple
  5. from enum import Enum
  6. from typing import List, Any
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  10. from sqlmodel import Field, SQLModel,Session, Integer, Sequence, UniqueConstraint,select
  11. from config import DB_URL,logger
  12. # from db.common import engine
  13. from pydantic import UUID4
  14. import uuid
  15. from db.base import BaseRepository
  16. from db.engine import engine
class Categories(SQLModel, table=True):
    """Per-user document category; a user's category names are unique (see __table_args__)."""

    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True)  # UUID v1 primary key
    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # foreign key to the user table (useroauthtoken.open_id)
    name: str = Field(default="default", index=True)  # category name; indexed to speed up lookups
    update_time: datetime = Field(default_factory=datetime.now)  # creation / update time (naive local time from datetime.now)
    # Composite unique constraint: one category per (open_id, name).
    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
class DocumentCategories(SQLModel, table=True):
    """Link table assigning a category to a document."""

    # Document id; foreign key to documents.id. Because this column is also the
    # primary key, each document can have at most one row here (one category).
    id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True)
    category_id: UUID4 = Field(foreign_key="categories.id",index=True)  # foreign key to categories.id
    # NOTE(review): with `id` already the primary key, this (id, category_id)
    # unique constraint adds nothing. If documents are meant to belong to several
    # categories, the primary key presumably should be (id, category_id) — confirm.
    __table_args__ = (UniqueConstraint('id', 'category_id', ),)
  28. class DocStatus:
  29. UNPROCESSED = 0 # 未处理
  30. COMPLETED = 100 # 已完成
  31. DISABLED = -1 # 禁用
class Documents(SQLModel, table=True):
    """A user's document, unique per (open_id, path)."""

    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True,index=True)  # UUID v1 primary key (also explicitly indexed)
    open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # foreign key to the user table (useroauthtoken.open_id)
    path: str = Field(nullable=False, index=True)  # path relative to the user's "docs" directory
    status: int = Field(nullable=False)  # document status; expected to hold a DocStatus value
    update_time: datetime = Field(default_factory=datetime.now)  # creation / update time (naive local time from datetime.now)
    # Composite unique constraint: one row per (user, relative path).
    __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),)
  39. class DocumentBase(BaseRepository):
  40. def __init__(self, model: Documents, engine=...):
  41. super().__init__(model, engine)
  42. def before_update(self, obj_model: 'Documents', exist_obj: 'Documents') -> None:
  43. """在更新对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""
  44. pass
  45. def before_create(self, obj_model: 'Documents') -> None:
  46. """在创建对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""
  47. pass
  48. def get_model_dump_field_for_update(self):
  49. return self.non_unique_fields
  50. def add_or_update(self, obj_model: Documents, ex_session: Optional[Session] = None) -> SQLModel:
  51. session = ex_session or Session(bind=self.engine)
  52. exist_obj:Documents = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  53. logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
  54. if exist_obj:
  55. dict_data = self.model_dump_by_field(obj_model, fields=self.get_model_dump_field_for_update())
  56. if dict_data:
  57. logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
  58. else:
  59. logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {self.get_model_dump_field_for_update()}")
  60. self.before_update(obj_model, exist_obj)
  61. self.set_obj_by_dict(exist_obj,dict_data)
  62. if not ex_session:
  63. session.commit()
  64. # commit 之后 exist_obj 的值就被释放了,不存在了
  65. return True
  66. return exist_obj
  67. else:
  68. self.before_create(obj_model)
  69. self.create(obj_model,ex_session)
  70. if not ex_session:
  71. session.commit()
  72. return True
  73. logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
  74. return obj_model
  75. class CategoriesRepository(DocumentBase):
  76. def __init__(self, engine=engine):
  77. super().__init__(Categories, engine)
  78. # 分类表不需要更新时间
  79. def get_model_dump_field_for_update(self):
  80. ret:list = self.non_unique_fields
  81. if "update_time" in ret:
  82. ret.remove("update_time")
  83. return ret
  84. class DocumentCategoriesRepository(DocumentBase):
  85. def __init__(self, engine=engine):
  86. super().__init__(DocumentCategories, engine)
  87. class DocumentsRepository(DocumentBase):
  88. def __init__(self, engine=engine):
  89. super().__init__(Documents, engine)
  90. def add_document_with_categories(self, open_id, file_path, category_name="default") -> DocumentCategories:
  91. with Session(bind=self.engine) as session:
  92. doc_model:Documents = self.exec_add_or_update_file(open_id, file_path, session)
  93. cr = CategoriesRepository()
  94. category_model:Categories = cr.add_or_update(Categories(open_id=open_id, name=category_name),session)
  95. dcr = DocumentCategoriesRepository()
  96. doc_categ_model = dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
  97. session.commit()
  98. # 强制刷新,让 model 从数据库总获取最新状态
  99. session.refresh(doc_model)
  100. session.refresh(category_model)
  101. session.refresh(doc_categ_model)
  102. return (doc_model, category_model, doc_categ_model)
  103. def exec_add_or_update_file(self, open_id, file_path, session):
  104. # file_path = {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
  105. relative_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
  106. if relative_path == None:
  107. return
  108. self.instance_model = Documents(
  109. open_id=open_id,
  110. path=relative_path,
  111. status=DocStatus.UNPROCESSED,
  112. )
  113. res = self.add_or_update(self.instance_model, session)
  114. return res
  115. '''
  116. 从绝对路径中提取相对路径
  117. input: {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
  118. output: xxx/example_file.pdf
  119. '''
  120. def get_doc_path_from_full_path(full_path):
  121. pattern = r'docs/(.*?)$'
  122. match = re.search(pattern, full_path)
  123. if match:
  124. return match.group(1)
  125. else:
  126. logger.error(f"Can not get rel path:{full_path}")
  127. '''
  128. 从 doc model 中提取文件相对路径
  129. input: Documents(path=example_file.pdf)
  130. output: {open_id}/docs/example_file.pdf
  131. '''
  132. def get_user_file_relpath_from_docmodel(doc_model:Documents):
  133. return os.path.join(str(doc_model.open_id), "docs", doc_model.path)
  134. '''
  135. return: List[Tuple[file_path, category_id]]
  136. '''
  137. def get_user_files_path(self, open_id: str, category_id: Optional[UUID4] = None, category_name: Optional[str] = None) -> List[Tuple[str, UUID4]]:
  138. with Session(self.engine) as session:
  139. # 基础查询,从 Documents 表中选择 path 和 id
  140. base_statement = select(Documents.path, Documents.id).where(Documents.open_id == open_id)
  141. # 如果提供了 category_id,则通过 DocumentCategories 进行关联查询
  142. if category_id:
  143. base_statement = base_statement.join(DocumentCategories, Documents.id == DocumentCategories.id).where(DocumentCategories.category_id == category_id)
  144. # 如果提供了 category_name,则先找到对应的 category_id 再进行关联查询
  145. elif category_name:
  146. category_subquery = select(Categories.id).where(Categories.name == category_name)
  147. doc_category_subquery = select(DocumentCategories.id).where(DocumentCategories.category_id.in_(category_subquery))
  148. base_statement = (
  149. base_statement.join(DocumentCategories, Documents.id == DocumentCategories.id)
  150. .where(DocumentCategories.id.in_(doc_category_subquery))
  151. )
  152. # 执行查询并返回结果(每个结果为一个元组:(文档路径, 分类ID))
  153. results = session.exec(base_statement)
  154. return [(result.path, result.id) for result in results]
  155. # 示例使用
  156. def main():
  157. from db.user import test_add
  158. open_id = test_add()
  159. # 创建实例
  160. documents_repo = DocumentsRepository()
  161. # model = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
  162. # model = documents_repo.update(model)
  163. # res = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
  164. # documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
  165. doc_model,_,_ = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
  166. rel_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
  167. logger.info(rel_path)
  168. # res = documents_repo.get_user_files_path(open_id)
  169. # 假设调用服务端的代码。注意这里只是假设示例,实际上要自己编写调用的代码逻辑
  170. # 添加分类
  171. # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
  172. # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")
  173. # 实现有关代码
  174. if __name__ == "__main__":
  175. main()