docs.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import asyncio
  2. from datetime import datetime
  3. import re
  4. from typing import Optional, Tuple
  5. from enum import Enum
  6. from typing import List, Any
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  10. from sqlmodel import Field, SQLModel,Session,Relationship, Integer, Sequence, UniqueConstraint,select
  11. from config import DB_URL,logger
  12. # from db.common import engine
  13. import uuid
  14. from db.base import BaseRepository
  15. from db.engine import engine
  16. from db.model import *
  17. from typing import TYPE_CHECKING, Optional
  18. class DocumentBase(BaseRepository):
  19. def __init__(self, model: Documents, engine=...):
  20. super().__init__(model, engine)
  21. def before_update(self, obj_model: 'Documents', exist_obj: 'Documents') -> None:
  22. """在更新对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""
  23. pass
  24. def before_create(self, obj_model: 'Documents') -> None:
  25. """在创建对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""
  26. pass
  27. def get_model_dump_field_for_update(self):
  28. return self.non_unique_fields
  29. def add_or_update(self, obj_model: Documents, ex_session: Optional[Session] = None) -> SQLModel:
  30. session = ex_session or Session(bind=self.engine)
  31. exist_obj:Documents = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
  32. logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
  33. if exist_obj:
  34. dict_data = self.model_dump_by_field(obj_model, fields=self.get_model_dump_field_for_update())
  35. if dict_data:
  36. logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
  37. else:
  38. logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {self.get_model_dump_field_for_update()}")
  39. self.before_update(obj_model, exist_obj)
  40. self.set_obj_by_dict(exist_obj,dict_data)
  41. if not ex_session:
  42. session.commit()
  43. # commit 之后 exist_obj 的值就被释放了,不存在了
  44. return True
  45. return exist_obj
  46. else:
  47. self.before_create(obj_model)
  48. self.create(obj_model,ex_session)
  49. if not ex_session:
  50. session.commit()
  51. return True
  52. logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
  53. return obj_model
  54. class CategoriesRepository(DocumentBase):
  55. def __init__(self, engine=engine):
  56. super().__init__(Categories, engine)
  57. def get_or_create(self, user, category_name, ex_session=None):
  58. session = ex_session or Session(bind=self.engine)
  59. category = session.exec(
  60. select(Categories)
  61. .join(LinkUserCatgegory)
  62. .filter(LinkUserCatgegory.user_id == user.id)
  63. .filter(Categories.name == category_name)
  64. ).first()
  65. if not category:
  66. category = Categories(name=category_name, user=user)
  67. return category
  68. class DocumentsRepository(DocumentBase):
  69. def __init__(self, engine=engine):
  70. super().__init__(Documents, engine)
  71. def get_by_cate_path(self, session, category_id, doc_path)->Documents:
  72. doc_model = session.exec(
  73. select(Documents)
  74. .join(LinkDocumentCategories)
  75. .filter(LinkDocumentCategories.category_id == category_id)
  76. .filter(Documents.path == doc_path)
  77. ).first()
  78. return doc_model
  79. def get_if_exist(self, user, doc_path, category_name, ex_session=None):
  80. session = ex_session or Session(bind=self.engine)
  81. category = CategoriesRepository().get_or_create(user, category_name=category_name, ex_session=session)
  82. doc_model = None
  83. if category:
  84. doc_model = session.exec(
  85. select(Documents)
  86. .join(LinkDocumentCategories)
  87. .filter(LinkDocumentCategories.category_id == category.id)
  88. .filter(Documents.path == doc_path)
  89. ).first()
  90. return category, doc_model
  91. def add_category_doc(self, user, file_path, category_name, ex_session):
  92. category, doc_model = self.get_if_exist(user, file_path, category_name, ex_session)
  93. if not category:
  94. doc_model:Documents = self.gen_model(user.open_id, file_path)
  95. category = Categories(name=category_name, docs=[doc_model])
  96. category.user = user
  97. else:
  98. if not doc_model:
  99. doc_model:Documents = self.gen_model(user.open_id, file_path)
  100. category.docs.append(doc_model)
  101. return category
  102. def get_or_create(self, user:User, file_path, category_name, summarize=None, ex_session=None):
  103. session = ex_session or Session(bind=self.engine)
  104. doc_model = session.exec(
  105. select(Documents)
  106. .join(LinkUserDocument)
  107. .join(LinkDocumentCategories)
  108. .join(Categories).where(Categories.name == category_name)
  109. .where(Documents.path == file_path)
  110. ).first()
  111. if not doc_model:
  112. catgory = CategoriesRepository().get_or_create(user, category_name, ex_session=ex_session)
  113. doc_model = Documents(path=file_path,status=0,user=catgory.user)
  114. doc_model.categories.append(catgory)
  115. logger.info(f"create document {doc_model}")
  116. return doc_model
  117. # file_path 可以是全路径,也可以是 {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf 形式的相对路径,由上层决定
  118. # - 如果是相对路径,则说明它是一个挂载点。
  119. def add_document_with_categories(self, user, file_path, category_name="default"):
  120. file_path = DocumentsRepository.get_doc_path_from_full_path(file_path)
  121. with Session(bind=self.engine) as session:
  122. category = self.add_category_doc(user, file_path, category_name, session)
  123. session.add(category)
  124. session.commit()
  125. session.refresh(category)
  126. return category.user,category.docs[-1]
  127. def set_status(self, doc:Documents, status:DocStatus):
  128. with Session(bind=self.engine) as session:
  129. session.add(doc)
  130. session.commit()
  131. session.refresh(doc)
  132. return doc
  133. def gen_model(self, open_id, file_path):
  134. # file_path = {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
  135. doc_model = Documents(
  136. open_id=open_id,
  137. path=file_path,
  138. status=DocStatus.UNPROCESSED,
  139. )
  140. return doc_model
  141. '''
  142. 从绝对路径中提取相对路径
  143. input: {MNT_DOUYIN_DATA}/{open_id}/docs/xxx/example_file.pdf
  144. output: xxx/example_file.pdf
  145. '''
  146. def get_doc_path_from_full_path(full_path):
  147. pattern = r'docs/(.*?)$'
  148. match = re.search(pattern, full_path)
  149. if match:
  150. return match.group(1)
  151. else:
  152. return full_path
  153. # 示例使用
  154. def main():
  155. from db.user_oauth import test_add
  156. open_id = test_add()
  157. # 创建实例
  158. documents_repo = DocumentsRepository()
  159. # model = Documents(id=uuid.UUID("f7069528-ccb9-11ee-933a-00155db00104"), status=5)
  160. # model = documents_repo.update(model)
  161. # res = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
  162. # documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
  163. doc_model,_,_ = documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
  164. rel_path = DocumentsRepository.get_user_file_relpath_from_docmodel(doc_model)
  165. logger.info(rel_path)
  166. # res = documents_repo.get_user_files_path(open_id)
  167. # 假设调用服务端的代码。注意这里只是假设示例,实际上要自己编写调用的代码逻辑
  168. # 添加分类
  169. # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
  170. # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")
  171. # 实现有关代码
  172. if __name__ == "__main__":
  173. main()