Browse Source

打算改为非异步实现

qyl 1 year ago
parent
commit
5f16811f52
3 changed files with 186 additions and 109 deletions
  1. 108 43
      db/base.py
  2. 66 62
      db/docs.py
  3. 12 4
      db/user.py

+ 108 - 43
db/base.py

@@ -1,4 +1,5 @@
 from typing import List, Any,Optional,Callable
+import datetime 
 from sqlmodel import SQLModel
 from sqlalchemy.orm import sessionmaker 
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,6 +9,7 @@ from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.sql import func
+import sqlmodel
 from db.engine import engine
 from config import logger
 
@@ -46,44 +48,104 @@ class BaseRepository:
         async with self.session_factory() as session:  
             result = await session.execute(select(self.model))  
             return result.scalars().all() 
-        
-    async def _process_on_conflict(self, instances: SQLModel | List[SQLModel], 
-                                   update_fields: Optional[List[str]] = None) -> SQLModel | List[SQLModel]:
-        async def exec_one(session, instance, index_elements):
-            instance_dict = instance.model_dump()
-            if update_fields is not None:
-                set_values = {key:instance_dict[key] for key in instance_dict if key not in index_elements}
-                stmt = insert(self.model).values(**instance_dict).on_conflict_do_update(
-                    index_elements=index_elements,
-                    set_=set_values,
-                )
-            else:
-                stmt = insert(self.model).values(**instance_dict).on_conflict_do_nothing(
-                    index_elements=index_elements,
-                )
-            stmt = stmt.returning(self.model.id)
-            res = await session.execute(stmt)
-            instance.id = res.scalar()
-            # logger.debug(f'{self.model} {"on_conflict_do_update" if update_fields else "on_conflict_do_nothing"}')
-
+    
+    '''
+    根据 instance 模型判断哪些是唯一约束,执行 on_conflict_insert 事务
+    '''
+    async def aexec_inset_on_conflict(self, 
+                                    instance:SQLModel, 
+                                    session: Optional[AsyncSession],
+                                    update_field:Any) -> SQLModel:
+        # index_elements 获取 self.model 具有唯一属性的字段
         index_elements = self._get_unique_constraint_fields()
-        async with self.session_factory() as session:
-            if not isinstance(instances, list):
-                await exec_one(session, instances, index_elements)
-            else:
-                for instance in instances:
-                    await exec_one(session, instance, index_elements)
-            await session.commit()
-            return instances
+        instance_dict = instance.model_dump()
+        if isinstance(update_field,self.model):
+            set_values = {key:instance_dict[key] for key in instance_dict if key not in index_elements}
+            stmt = insert(self.model).values(**instance_dict).on_conflict_do_update(
+                index_elements=index_elements,
+                set_=set_values,
+            )
+        elif isinstance(update_field, dict):
+            set_values = {key:update_field[key] for key in update_field if key not in index_elements}
+            logger.debug(f"{index_elements}  {set_values}")
+            stmt = insert(self.model).values(**update_field).on_conflict_do_update(
+                index_elements=index_elements,
+                set_=set_values,
+            )
+        else:
+            stmt = insert(self.model).values(**instance_dict).on_conflict_do_nothing(
+                index_elements=index_elements,
+            )
+        stmt = stmt.returning(self.model.id)
+        res = await session.execute(stmt)
+        # 如果 id 不为 None,说明插入或更新了数据,否则
+        instance.id = res.scalar()
+        # logger.debug(f'{self.model} {"on_conflict_do_update" if update_fields else "on_conflict_do_nothing"}')
+        return instance
+            
+    '''
+    判断是否为列表,自动添加到事务 session.execute(stmt)
+    '''
+    async def aexec_instances_if_list(self, 
+                                instances: SQLModel | List[SQLModel], 
+                                session: Optional[AsyncSession],
+                                update_field:Any,
+                                ) -> SQLModel | List[SQLModel]:
+        if not isinstance(instances, list):
+            await self.aexec_inset_on_conflict(instances, session, update_field)
+        else:
+            for instance in instances:
+                await self.aexec_inset_on_conflict(instance, session, update_field)
+        return instances
+    
+    async def ais_commit(self,
+                        instances: SQLModel | List[SQLModel], 
+                        session: Optional[AsyncSession],
+                        update_field:bool=False,
+                        ) -> SQLModel | List[SQLModel]:
+        if session==None:
+            async with self.session_factory() as session:
+                await session.commit()
+        return instances
+    
+    '''
+    根据 instances 模型判断是否存在唯一约束,存在则不添加数据
+    input:
+      - session: None 说明没有上层事务,自动提交。传入 session 说明使用外部事务来 commit
+    '''
+    async def aon_conflict_do_nothing(self, 
+                                      instances: SQLModel|List[SQLModel], 
+                                      session: Optional[AsyncSession] = None,
+                                      ) -> SQLModel|List[SQLModel]:
+        if session==None:
+            async with self.session_factory() as session:
+                await self.aexec_instances_if_list(instances, session, update_field=None)
+                await session.commit()
+        else:
+            await self.aexec_instances_if_list(instances, session, update_field=None)
+        return instances
+        
 
-    async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
-        return await self._process_on_conflict(instances, None)
+    '''
+    根据 instances 模型判断是否存在唯一约束,存在更新数据
+    input:
+      - session: None 说明没有上层事务,自动提交。传入 session 说明使用外部事务来 commit
+    '''
+    async def aon_conflict_do_update(self,
+                                    instances: SQLModel | List[SQLModel],
+                                    session: Optional[AsyncSession] = None,
+                                    update_field:dict=None
+                                    ) -> SQLModel | List[SQLModel]:
+        if not update_field:
+            update_field = instances
+        if session==None:
+            async with self.session_factory() as session:
+                await self.aexec_instances_if_list(instances, session, update_field=update_field)
+                await session.commit()
+        else:
+            await self.aexec_instances_if_list(instances, session, update_field=update_field)
+        return instances
 
-    async def aon_conflict_do_update(
-            self,
-            instances: SQLModel | List[SQLModel],
-        ) -> SQLModel | List[SQLModel]:
-            return await self._process_on_conflict(instances, update_fields=True)
     
     def _get_unique_constraint_fields(self) -> List[str]:
             constraints = getattr(self.model.__table__, 'constraints', [])
@@ -110,20 +172,23 @@ class DouyinBaseRepository(BaseRepository):
             user_info = result.scalars().first()  # 获取查询结果的第一个记录,如果没有找到则返回 None  
             return user_info
     
+    def get_update_time(self):
+        if hasattr(self.model, 'update_time'): 
+            return {'update_time': datetime.datetime.now()}
     '''
     input: 
       - data:dict  通常抖音返回数据是json格式,因此这里也用字典传参类型,如果是 SQLmodel 会自动用 data.model_dump() 方法转化成字典
       - constraint_name 字段唯一值,如果 data 所含的字段存在于数据库则更新该行
     return : res
     '''
-    async def aon_conflict_do_update(self, data: dict):  
-        if type(data) == self.model:
-            data = data.model_dump()
-        if hasattr(self.model, 'update_time'): 
-            import datetime 
-            data['update_time'] = datetime.datetime.now()
-        clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
-        return await super().aon_conflict_do_update(self.model(**clean_data))
+    # async def aon_conflict_do_update(self, data: dict, session: Optional[AsyncSession] = None,) -> SQLModel | List[SQLModel]:  
+    #     if type(data) == self.model:
+    #         data = data.model_dump()
+    #     if hasattr(self.model, 'update_time'): 
+    #         import datetime 
+    #         data['update_time'] = datetime.datetime.now()
+    #     clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
+    #     return await super().aon_conflict_do_update(self.model(**clean_data))
         # try:  
         #     index_elements = self._get_unique_constraint_fields()
         #     async with self.session_factory() as session:  

+ 66 - 62
db/docs.py

@@ -8,71 +8,69 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from sqlmodel import Field, SQLModel,Column, Integer, Sequence, UniqueConstraint  
+from sqlmodel import Field, SQLModel,Column, Integer, Sequence, UniqueConstraint 
 from config import DB_URL,logger
 # from db.common import engine
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.sql.sqltypes import Integer, String, DateTime
 from sqlalchemy.sql.schema import Column
 from sqlalchemy import UniqueConstraint
+from pydantic import UUID4
+import uuid
 from db.base import BaseRepository,DouyinBaseRepository
 from db.engine import engine,create_all
 
   
 
 class Categories(SQLModel,DouyinBaseRepository, table=True):  
-    id: int = Field(primary_key=True)  # 分类的唯一标识符  
+    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True)  # 使用 UUID v1 作为主键 
     open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
     name: str = Field(default="default", index=True)  # 分类的名称,添加索引以优化查询性能  
     update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
     # 添加联合唯一约束  
-    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_name'),)
+    __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
     
         
     
-class DocumentCategories(SQLModel, table=True):  
-    document_id: int = Field(foreign_key="documents.id", primary_key=True)  # 关联到文档表的外键  
-    category_id: int = Field(foreign_key="categories.id", primary_key=True)  # 关联到分类表的外键  
-
+class DocumentCategories(SQLModel, table=True):
+    id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True)  # 关联到文档表的外键  
+    category_id: UUID4 = Field(foreign_key="categories.id",index=True)  # 关联到分类表的外键  
+    __table_args__ = (UniqueConstraint('id', 'category_id', ),)
+    
 class DocStatus:  
     UNPROCESSED = 0  # 未处理  
     COMPLETED = 100  # 已完成  
     DISABLED = -1    # 禁用  
     
 class Documents(SQLModel, table=True):  
-    id: Optional[int] = Field(primary_key=True)
+    id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True,index=True)  # 使用 UUID v1 作为主键 
     open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
     path: str = Field(nullable=False, index=True) # 相对路径
     status: int = Field(nullable=False) # 文档状态  
     update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
-    __table_args__ = (UniqueConstraint('path'),) 
+    __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),) 
     
 class CategoriesRepository(DouyinBaseRepository):  
     def __init__(self, engine=engine):  
         super().__init__(Categories, engine)  
-  
+    
+    async def aexec_add_or_update_categ(self, open_id, category_name, session):
+        categ_model = Categories(open_id=open_id,name=category_name)
+        await self.aon_conflict_do_nothing(categ_model, session)
+        if categ_model.id:
+            logger.debug(f"{open_id} add new category:{category_name}")
+        else:
+            logger.debug(f"{open_id} already have same name:{category_name}")
+        return categ_model
+
         
 class DocumentCategoriesRepository(DouyinBaseRepository):  
     def __init__(self, engine=engine):  
         super().__init__(DocumentCategories, engine)  
 
-  
-  
+
 class DocumentsRepository(DouyinBaseRepository):  
-    def __init__(self, open_id, file_path, category_name="default", engine=engine):  
-        # file_path = {DATA_DIR}/{open_id}/docs/xxx/example_file.pdf
-        relative_path = DocumentsRepository.get_relative_path(file_path)
-        if relative_path == None:
-            return
-        self.doc_model = Documents(
-                            open_id=open_id,
-                            path=relative_path,
-                            status=DocStatus.UNPROCESSED,
-                            )
-        self.category_model = Categories(
-            open_id=open_id,
-            name=category_name
-        )
+    def __init__(self, engine=engine):  
         super().__init__(Documents, engine)  
 
     def get_relative_path(full_path):
@@ -83,48 +81,54 @@ class DocumentsRepository(DouyinBaseRepository):
         else:
             logger.error(f"Can not get rel path:{full_path}")
     
-    async def add_document_with_categories(self):
-        document_id = await self.aon_conflict_do_nothing(self.doc_model)
-        logger.debug(f"document_id:{document_id}")
-        cr = CategoriesRepository()
-        category_id = await cr.aon_conflict_do_nothing(self.category_model)
-        logger.debug(f"category_id:{category_id}")
-        doc_categ_model = DocumentCategories(document_id, category_id)
-        dr = DocumentCategoriesRepository()
-        dr.aon_conflict_do_nothing(doc_categ_model)
+    # file_path = {DATA_DIR}/{open_id}/docs/xxx/example_file.pdf
+    async def add_document_with_categories(self, open_id, file_path, category_name="default"):
+        async with self.session_factory() as session:
+            doc_model:Documents = await self.aexec_add_or_update_file(open_id, file_path, session)
+            
+            cr = CategoriesRepository()
+            category_model = await cr.aexec_add_or_update_categ(open_id, category_name,session)
+            logger.debug(f"category_id:{category_model}")
+            
+            if doc_model.id is not None and category_model.id is not None:
+                doc_categ_model = DocumentCategories(id=doc_model.id, category_id=category_model.id)
+                dr = DocumentCategoriesRepository()
+                logger.info(doc_categ_model)
+                await dr.aon_conflict_do_nothing(doc_categ_model, session)
+                await session.commit()
+            else:
+                logger.info("DocumentCategories no change.")
         return
-        # 添加或更新文档
-        await self.add_or_update_document(new_document.model_dump(), "document_id")
-
-        # 获取已存在的分类
-        categories_repo = CategoriesRepository()
-        existing_categories = await categories_repo.get_all_by_ids(category_ids)
-        existing_category_ids = {category.category_id for category in existing_categories}
-
-        # 添加不存在的分类
-        for category_id in set(category_ids) - existing_category_ids:
-            new_category = Categories(open_id=new_document.open_id, category_id=category_id, category_name=f"Category_{category_id}")  # 假设名称由 ID 生成
-            await categories_repo.add([new_category])
-
-        # 创建并添加文档分类关联关系
-        document_categories_to_add = []
-        for category_id in category_ids:
-            doc_cat = DocumentCategories(document_id=new_document.document_id, category_id=category_id)
-            document_categories_to_add.append(doc_cat)
-
-        # 添加文档分类关联关系到数据库
-        document_categories_repo = DocumentCategoriesRepository()
-        await document_categories_repo.add(document_categories_to_add)
-  
+    
+    async def aexec_add_or_update_file(self, open_id, file_path, session):
+        relative_path = DocumentsRepository.get_relative_path(file_path)
+        if relative_path == None:
+            return
+        self.instance_model = Documents(
+                            open_id=open_id,
+                            path=relative_path,
+                            status=DocStatus.UNPROCESSED,
+                            )
+        # 在同一个 open_id 用户层面上,如果 relative_path 相同,则产生冲突,仅仅更新时间。说明 file_path 同路径下覆盖了新文件
+        # 没有产生冲突,说明不同用户或不同路径下新增了文件
+        document_model:Documents = await self.aon_conflict_do_update(self.instance_model, session)
+        res = self.aget(open_id=open_id, file_path=file_path)
+        logger.info(f"get doc row:{res}")
+        if document_model.id:
+            logger.debug(f"{document_model.open_id} add new file:{document_model.path}")
+        else:
+            logger.debug(f"{document_model.open_id} overwrite file:{document_model.path}")
+        return document_model
+        
 # 示例使用  
 async def main():  
     from db.user import test_add
     open_id = await test_add()
     # 创建实例  
-    categories_repo = CategoriesRepository()  
-    documents_repo = DocumentsRepository(open_id,"/home/user/code/open-douyin/open_id/docs/readme2.md")  
-    document_categories_repo = DocumentCategoriesRepository()  
-    await documents_repo.add_document_with_categories()
+    documents_repo = DocumentsRepository()  
+    res = await documents_repo.aget(id=1)
+    logger.info(res)
+    # await documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
     # 添加分类  
     # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
     # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")

+ 12 - 4
db/user.py

@@ -46,7 +46,11 @@ class UserInfoRepository(DouyinBaseRepository):
         self.model:UserInfo
         
     async def create_user_info(self, user_info_data):  
-        return await self.aon_conflict_do_update(user_info_data)
+        if hasattr(self.model, 'update_time'): 
+            import datetime 
+            user_info_data['update_time'] = datetime.datetime.now()
+        clean_data = {k: v for k, v in user_info_data.items() if hasattr(self.model, k)} 
+        return await self.aon_conflict_do_update(self.model(**clean_data))
     
     async def update_user_info(self, user_id, user_info_data):  
         async with self.session_factory() as session:  
@@ -71,7 +75,11 @@ class UserOAuthRepository(DouyinBaseRepository):
         self.model:UserOAuthToken
 
     async def add_token(self, data: dict):  
-        return await self.aon_conflict_do_update(data)
+        if hasattr(self.model, 'update_time'): 
+            import datetime 
+            data['update_time'] = datetime.datetime.now()
+        clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
+        return await self.aon_conflict_do_update(self.model(**clean_data))
 
 
     async def delete_token(self, token_id: int):  
@@ -112,10 +120,10 @@ async def test_add():
   }
     db_manager = UserOAuthRepository()
     res = await db_manager.add_token(user_oauth)
-    logger.debug(res)
+    # logger.debug(res)
     db_user_info = UserInfoRepository()
     res = await db_user_info.create_user_info(user_info)
-    logger.debug(res)
+    # logger.debug(res)
     return user_oauth["open_id"]
 
 if __name__ == "__main__":