Browse Source

优化 aon_conflict_do_nothing、on_conflict_do_update

qyl 1 year ago
parent
commit
486da35fad
3 changed files with 74 additions and 70 deletions
  1. 62 62
      db/base.py
  2. 7 4
      db/docs.py
  3. 5 4
      db/user.py

+ 62 - 62
db/base.py

@@ -1,4 +1,4 @@
-from typing import List, Any
+from typing import List, Any,Optional,Callable
 from sqlmodel import SQLModel
 from sqlalchemy.orm import sessionmaker 
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -47,49 +47,44 @@ class BaseRepository:
             result = await session.execute(select(self.model))  
             return result.scalars().all() 
         
-    async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
-        async def exec_one(session, instance,index_elements):
-            data = instance.model_dump()
-            stmt = insert(self.model).values(**data).on_conflict_do_nothing(  
-                index_elements=index_elements  # 对应联合唯一约束的列  
-            ).returning(self.model.id)
+    async def _process_on_conflict(self, instances: SQLModel | List[SQLModel], 
+                                   update_fields: Optional[List[str]] = None) -> SQLModel | List[SQLModel]:
+        async def exec_one(session, instance, index_elements):
+            instance_dict = instance.model_dump()
+            if update_fields is not None:
+                set_values = {key:instance_dict[key] for key in instance_dict if key not in index_elements}
+                stmt = insert(self.model).values(**instance_dict).on_conflict_do_update(
+                    index_elements=index_elements,
+                    set_=set_values,
+                )
+            else:
+                stmt = insert(self.model).values(**instance_dict).on_conflict_do_nothing(
+                    index_elements=index_elements,
+                )
+            stmt = stmt.returning(self.model.id)
             res = await session.execute(stmt)
             instance.id = res.scalar()
-        
+            # logger.debug(f'{self.model} {"on_conflict_do_update" if update_fields else "on_conflict_do_nothing"}')
+
         index_elements = self._get_unique_constraint_fields()
-        async with self.session_factory() as session:  
-            if not isinstance(instances, list):  
+        async with self.session_factory() as session:
+            if not isinstance(instances, list):
                 await exec_one(session, instances, index_elements)
             else:
-                for instance in instances:  
-                   await exec_one(session, instance, index_elements)
-            await session.commit()  
+                for instance in instances:
+                    await exec_one(session, instance, index_elements)
+            await session.commit()
             return instances
-        
+
+    async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
+        return await self._process_on_conflict(instances, None)
+
     async def aon_conflict_do_update(
             self,
             instances: SQLModel | List[SQLModel],
-            update_fields: List[str],  # 需要更新的字段列表
         ) -> SQLModel | List[SQLModel]:
-            async def exec_one(session, instance, index_elements):
-                data = instance.model_dump()
-                stmt = insert(self.model).values(**data).on_conflict_do_update(
-                    constraint=index_elements,  # 对应联合唯一约束
-                    set_=dict((k, data[k]) for k in update_fields if k in data),  # 更新指定字段
-                ).returning(self.model.id)
-                res = await session.execute(stmt)
-                instance.id = res.scalar()
-
-            index_elements = self._get_unique_constraint_fields()
-            async with self.session_factory() as session:
-                if not isinstance(instances, list):
-                    await exec_one(session, instances, index_elements)
-                else:
-                    for instance in instances:
-                        await exec_one(session, instance, index_elements)
-                await session.commit()
-                return instances
-            
+            return await self._process_on_conflict(instances, update_fields=True)
+    
     def _get_unique_constraint_fields(self) -> List[str]:
             constraints = getattr(self.model.__table__, 'constraints', [])
             unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
@@ -121,33 +116,38 @@ class DouyinBaseRepository(BaseRepository):
       - constraint_name 字段唯一值,如果 data 所含的字段存在于数据库则更新该行
     return : res
     '''
-    async def aadd_or_update(self, data: dict):  
+    async def aon_conflict_do_update(self, data: dict):  
         if type(data) == self.model:
             data = data.model_dump()
-        try:  
-            index_elements = self._get_unique_constraint_fields()
-            async with self.session_factory() as session:  
-                # 只获取 self.model 定义的字段
-                clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
-                # logger.debug(f"clean data:{clean_data} from {self.model}") 
-                # 如果 self.model 中有 update_time 字段,则设置其为当前时间  
-                if hasattr(self.model, 'update_time'):  
-                    clean_data['update_time'] = func.now() 
-                # 构建 SQL 语句,实现插入或更新
-                insert_stmt = insert(self.model).values(**clean_data)
-                update_stmt = insert_stmt.on_conflict_do_update(  
-                    index_elements=index_elements,  
-                    set_={k: clean_data[k] for k in clean_data if k not in index_elements}  
-                ).returning(self.model.id)
-                result = await session.execute(update_stmt)  
-                new_id = result.scalar()
-                await session.commit()  
-                return new_id  
-        except IntegrityError as e:  
-            logger.exception(f"IntegrityError occurred: {e}")  
-            # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。  
-            # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。  
-        except Exception as e:  
-            # 捕获其他类型的异常  
-            logger.exception(f"An unexpected error occurred: {e}")  
-            raise  # 如果需要,可以重新抛出异常
+        if hasattr(self.model, 'update_time'): 
+            import datetime 
+            data['update_time'] = datetime.datetime.now()
+        clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
+        return await super().aon_conflict_do_update(self.model(**clean_data))
+        # try:  
+        #     index_elements = self._get_unique_constraint_fields()
+        #     async with self.session_factory() as session:  
+        #         # 只获取 self.model 定义的字段
+        #         clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
+        #         # logger.debug(f"clean data:{clean_data} from {self.model}") 
+        #         # 如果 self.model 中有 update_time 字段,则设置其为当前时间  
+        #         if hasattr(self.model, 'update_time'):  
+        #             clean_data['update_time'] = func.now() 
+        #         # 构建 SQL 语句,实现插入或更新
+        #         insert_stmt = insert(self.model).values(**clean_data)
+        #         update_stmt = insert_stmt.on_conflict_do_update(  
+        #             index_elements=index_elements,  
+        #             set_={k: clean_data[k] for k in clean_data if k not in index_elements}  
+        #         ).returning(self.model.id)
+        #         result = await session.execute(update_stmt)  
+        #         new_id = result.scalar()
+        #         await session.commit()  
+        #         return new_id  
+        # except IntegrityError as e:  
+        #     logger.exception(f"IntegrityError occurred: {e}")  
+        #     # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。  
+        #     # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。  
+        # except Exception as e:  
+        #     # 捕获其他类型的异常  
+        #     logger.exception(f"An unexpected error occurred: {e}")  
+        #     raise  # 如果需要,可以重新抛出异常

+ 7 - 4
db/docs.py

@@ -84,11 +84,14 @@ class DocumentsRepository(DouyinBaseRepository):
             logger.error(f"Can not get rel path:{full_path}")
     
     async def add_document_with_categories(self):
-        # document_id = await self.aadd_or_update(self.doc_model.model_dump(), constraint_name="path")
-        # logger.debug(f"document_id:{document_id}")
-        c = CategoriesRepository()
-        category_id = await c.aon_conflict_do_nothing(self.category_model, index_elements=["open_id", "name"])
+        document_id = await self.aon_conflict_do_nothing(self.doc_model)
+        logger.debug(f"document_id:{document_id}")
+        cr = CategoriesRepository()
+        category_id = await cr.aon_conflict_do_nothing(self.category_model)
         logger.debug(f"category_id:{category_id}")
+        doc_categ_model = DocumentCategories(document_id, category_id)
+        dr = DocumentCategoriesRepository()
+        dr.aon_conflict_do_nothing(doc_categ_model)
         return
         # 添加或更新文档
         await self.add_or_update_document(new_document.model_dump(), "document_id")

+ 5 - 4
db/user.py

@@ -16,7 +16,7 @@ from db.engine import engine,create_all
 # 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
 # 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
 class UserOAuthToken(SQLModel, table=True):  
-    id: Optional[int] = Field(default=None, primary_key=True)
+    id: Optional[int] = Field(primary_key=True)
     access_token:str
     expires_in: Optional[int] = None
     open_id:str = Field(index=True)
@@ -27,7 +27,7 @@ class UserOAuthToken(SQLModel, table=True):
     __table_args__ = (UniqueConstraint('open_id'),)  
 
 class UserInfo(SQLModel, table=True):  
-    id: Optional[int] = Field(default=None, primary_key=True)  
+    id: Optional[int] = Field(primary_key=True)  
     avatar: str  
     avatar_larger: str  
     client_key: str  
@@ -46,7 +46,7 @@ class UserInfoRepository(DouyinBaseRepository):
         self.model:UserInfo
         
     async def create_user_info(self, user_info_data):  
-        return await self.aadd_or_update(user_info_data)
+        return await self.aon_conflict_do_update(user_info_data)
     
     async def update_user_info(self, user_id, user_info_data):  
         async with self.session_factory() as session:  
@@ -71,7 +71,7 @@ class UserOAuthRepository(DouyinBaseRepository):
         self.model:UserOAuthToken
 
     async def add_token(self, data: dict):  
-        return await self.aadd_or_update(data)
+        return await self.aon_conflict_do_update(data)
 
 
     async def delete_token(self, token_id: int):  
@@ -112,6 +112,7 @@ async def test_add():
   }
     db_manager = UserOAuthRepository()
     res = await db_manager.add_token(user_oauth)
+    logger.debug(res)
     db_user_info = UserInfoRepository()
     res = await db_user_info.create_user_info(user_info)
     logger.debug(res)