docs.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import asyncio
  2. from datetime import datetime
  3. import re
  4. from typing import Optional, Tuple
  5. from enum import Enum
  6. from typing import List, Any
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  10. from sqlmodel import Field, SQLModel,Session,Relationship, Integer, Sequence, UniqueConstraint,select
  11. from config import DB_URL,logger
  12. # from db.common import engine
  13. from pydantic import UUID4
  14. import uuid
  15. from db.base import BaseRepository
  16. from db.engine import engine
  17. from typing import TYPE_CHECKING, Optional
  18. if TYPE_CHECKING:
  19. from .user import User
  20. from db.video_data import VideoDocLink,VideoData
  21. class DocumentsLink(SQLModel, table=True):
  22. user_id: Optional[int] = Field(
  23. default=None, foreign_key="user.id", primary_key=True
  24. )
  25. doc_id: Optional[int] = Field(
  26. default=None, foreign_key="documents.id", primary_key=True
  27. )
  28. class DocumentCategoriesLink(SQLModel, table=True):
  29. documents_id: int = Field(default=None, foreign_key="documents.id", primary_key=True)
  30. category_id: int = Field(default=None, foreign_key="categories.id", primary_key=True)
  31. class Categories(SQLModel, table=True):
  32. id: int = Field(default=None, primary_key=True) # 使用 UUID v1 作为主键
  33. name: str = Field(default="default", index=True) # 分类的名称,添加索引以优化查询性能
  34. update_time: datetime = Field(default_factory=datetime.now) # 创建时间、更新时间
  35. docs: "Documents" = Relationship(back_populates="categories", link_model=DocumentCategoriesLink)
  36. class DocStatus:
  37. UNPROCESSED = 0 # 未处理
  38. COMPLETED = 100 # 已完成
  39. DISABLED = -1 # 禁用
  40. class Documents(SQLModel, table=True):
  41. id: int = Field(default=None, primary_key=True,index=True) # 使用 UUID v1 作为主键
  42. # open_id: str = Field(index=True) # 关联到用户表的外键
  43. path: str = Field(nullable=False) # 相对路径
  44. status: int = Field(nullable=False) # 文档状态
  45. update_time: datetime = Field(default_factory=datetime.now) # 创建时间、更新时间
  46. user:"User" = Relationship(back_populates="docs", link_model=DocumentsLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
  47. categories: List[Categories] = Relationship(back_populates="docs", link_model=DocumentCategoriesLink)
  48. video_data:VideoData = Relationship(back_populates="doc", link_model=VideoDocLink, sa_relationship_kwargs={"cascade": "all, delete-orphan", "single_parent":True})
  49. class DocumentBase(BaseRepository):
  50. def __init__(self, model: Documents, engine=...):
  51. super().__init__(model, engine)
  52. def before_update(self, obj_model: 'Documents', exist_obj: 'Documents') -> None:
  53. """在更新对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""
  54. pass
  55. def before_create(self, obj_model: 'Documents') -> None:
  56. """在创建对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""
  57. pass
  58. def get_model_dump_field_for_update(self):
  59. return self.non_unique_fields
  60. def add_or_update(self, obj_model: Documents, ex_session: Optional[Session] = None) -> SQLModel:
  61. session = ex_session or Session(bind=self.engine)
  62. exist_obj:Documents = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  63. logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
  64. if exist_obj:
  65. dict_data = self.model_dump_by_field(obj_model, fields=self.get_model_dump_field_for_update())
  66. if dict_data:
  67. logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
  68. else:
  69. logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {self.get_model_dump_field_for_update()}")
  70. self.before_update(obj_model, exist_obj)
  71. self.set_obj_by_dict(exist_obj,dict_data)
  72. if not ex_session:
  73. session.commit()
  74. # commit 之后 exist_obj 的值就被释放了,不存在了
  75. return True
  76. return exist_obj
  77. else:
  78. self.before_create(obj_model)
  79. self.create(obj_model,ex_session)
  80. if not ex_session:
  81. session.commit()
  82. return True
  83. logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
  84. return obj_model
  85. class CategoriesRepository(DocumentBase):
  86. def __init__(self, engine=engine):
  87. super().__init__(Categories, engine)
  88. # 分类表不需要更新时间
  89. def get_model_dump_field_for_update(self):
  90. ret:list = self.non_unique_fields
  91. if "update_time" in ret:
  92. ret.remove("update_time")
  93. return ret
  94. class DocumentsRepository(DocumentBase):
  95. def __init__(self, engine=engine):
  96. super().__init__(Documents, engine)
  97. def add_document_with_categories(self, user, file_path, category_name="default"):
  98. with Session(bind=self.engine) as session:
  99. doc_model:Documents = self.exec_add_or_update_file(user.open_id, file_path, session)
  100. doc_model.categories.append(category_name)
  101. user.docs.append(doc_model)
  102. session.add(user)
  103. session.commit()
  104. session.refresh(user)
  105. return user
  106. # def update_document(self, user:User, ):
  107. # def add_document_with_categories(self, open_id, file_path, category_name="default") -> DocumentCategories:
  108. # with Session(bind=self.engine) as session:
  109. # doc_model:Documents = self.exec_add_or_update_file(open_id, file_path, session)
  110. # cr = CategoriesRepository()
  111. # category_model:Categories = cr.add_or_update(Categories(open_id=open_id, name=category_name),session)
  112. # dcr = DocumentCategoriesRepository()
  113. # doc_categ_model = dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
  114. # session.commit()
  115. # # 强制刷新,让 model 从数据库总获取最新状态
  116. # session.refresh(doc_model)
  117. # session.refresh(category_model)
  118. # session.refresh(doc_categ_model)
  119. # return (doc_model, category_model, doc_categ_model)
  120. def exec_add_or_update_file(self, open_id, file_path, session):
  121. # file_path = {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
  122. relative_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
  123. if relative_path == None:
  124. return
  125. self.instance_model = Documents(
  126. open_id=open_id,
  127. path=relative_path,
  128. status=DocStatus.UNPROCESSED,
  129. )
  130. res = self.add_or_update(self.instance_model, session)
  131. return res
  132. '''
  133. 从绝对路径中提取相对路径
  134. input: {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
  135. output: xxx/example_file.pdf
  136. '''
  137. def get_doc_path_from_full_path(full_path):
  138. pattern = r'docs/(.*?)$'
  139. match = re.search(pattern, full_path)
  140. if match:
  141. return match.group(1)
  142. else:
  143. logger.error(f"Can not get rel path:{full_path}")
  144. '''
  145. 从 doc model 中提取文件相对路径
  146. input: Documents(path=example_file.pdf)
  147. output: {open_id}/docs/example_file.pdf
  148. '''
  149. def get_user_file_relpath_from_docmodel(doc_model:Documents):
  150. return os.path.join(str(doc_model.open_id), "docs", doc_model.path)
  151. '''
  152. return: List[Tuple[file_path, category_id]]
  153. '''
  154. def get_user_files_path(self, open_id: str, category_id: Optional[UUID4] = None, category_name: Optional[str] = None) -> List[Tuple[str, UUID4]]:
  155. with Session(self.engine) as session:
  156. # 基础查询,从 Documents 表中选择 path 和 id
  157. base_statement = select(Documents.path, Documents.id).where(Documents.open_id == open_id)
  158. # 如果提供了 category_id,则通过 DocumentCategories 进行关联查询
  159. if category_id:
  160. base_statement = base_statement.join(DocumentCategories, Documents.id == DocumentCategories.id).where(DocumentCategories.category_id == category_id)
  161. # 如果提供了 category_name,则先找到对应的 category_id 再进行关联查询
  162. elif category_name:
  163. category_subquery = select(Categories.id).where(Categories.name == category_name)
  164. doc_category_subquery = select(DocumentCategories.id).where(DocumentCategories.category_id.in_(category_subquery))
  165. base_statement = (
  166. base_statement.join(DocumentCategories, Documents.id == DocumentCategories.id)
  167. .where(DocumentCategories.id.in_(doc_category_subquery))
  168. )
  169. # 执行查询并返回结果(每个结果为一个元组:(文档路径, 分类ID))
  170. results = session.exec(base_statement)
  171. return [(result.path, result.id) for result in results]
  172. # 示例使用
  173. def main():
  174. from db.user_oauth import test_add
  175. open_id = test_add()
  176. # 创建实例
  177. documents_repo = DocumentsRepository()
  178. # model = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
  179. # model = documents_repo.update(model)
  180. # res = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
  181. # documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
  182. doc_model,_,_ = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
  183. rel_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
  184. logger.info(rel_path)
  185. # res = documents_repo.get_user_files_path(open_id)
  186. # 假设调用服务端的代码。注意这里只是假设示例,实际上要自己编写调用的代码逻辑
  187. # 添加分类
  188. # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
  189. # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")
  190. # 实现有关代码
  191. if __name__ == "__main__":
  192. main()