|
|
@@ -1,4 +1,4 @@
|
|
|
-from typing import List, Any
|
|
|
+from typing import List, Any,Optional,Callable
|
|
|
from sqlmodel import SQLModel
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
@@ -47,49 +47,44 @@ class BaseRepository:
|
|
|
result = await session.execute(select(self.model))
|
|
|
return result.scalars().all()
|
|
|
|
|
|
- async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
|
|
|
- async def exec_one(session, instance,index_elements):
|
|
|
- data = instance.model_dump()
|
|
|
- stmt = insert(self.model).values(**data).on_conflict_do_nothing(
|
|
|
- index_elements=index_elements # 对应联合唯一约束的列
|
|
|
- ).returning(self.model.id)
|
|
|
+ async def _process_on_conflict(self, instances: SQLModel | List[SQLModel],
|
|
|
+ update_fields: Optional[List[str]] = None) -> SQLModel | List[SQLModel]:
|
|
|
+ async def exec_one(session, instance, index_elements):
|
|
|
+ instance_dict = instance.model_dump()
|
|
|
+ if update_fields is not None:
|
|
|
+ set_values = {key:instance_dict[key] for key in instance_dict if key not in index_elements}
|
|
|
+ stmt = insert(self.model).values(**instance_dict).on_conflict_do_update(
|
|
|
+ index_elements=index_elements,
|
|
|
+ set_=set_values,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ stmt = insert(self.model).values(**instance_dict).on_conflict_do_nothing(
|
|
|
+ index_elements=index_elements,
|
|
|
+ )
|
|
|
+ stmt = stmt.returning(self.model.id)
|
|
|
res = await session.execute(stmt)
|
|
|
instance.id = res.scalar()
|
|
|
-
|
|
|
+ # logger.debug(f'{self.model} {"on_conflict_do_update" if update_fields else "on_conflict_do_nothing"}')
|
|
|
+
|
|
|
index_elements = self._get_unique_constraint_fields()
|
|
|
- async with self.session_factory() as session:
|
|
|
- if not isinstance(instances, list):
|
|
|
+ async with self.session_factory() as session:
|
|
|
+ if not isinstance(instances, list):
|
|
|
await exec_one(session, instances, index_elements)
|
|
|
else:
|
|
|
- for instance in instances:
|
|
|
- await exec_one(session, instance, index_elements)
|
|
|
- await session.commit()
|
|
|
+ for instance in instances:
|
|
|
+ await exec_one(session, instance, index_elements)
|
|
|
+ await session.commit()
|
|
|
return instances
|
|
|
-
|
|
|
+
|
|
|
+ async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
|
|
|
+ return await self._process_on_conflict(instances, None)
|
|
|
+
|
|
|
async def aon_conflict_do_update(
|
|
|
self,
|
|
|
instances: SQLModel | List[SQLModel],
|
|
|
- update_fields: List[str], # 需要更新的字段列表
|
|
|
) -> SQLModel | List[SQLModel]:
|
|
|
- async def exec_one(session, instance, index_elements):
|
|
|
- data = instance.model_dump()
|
|
|
- stmt = insert(self.model).values(**data).on_conflict_do_update(
|
|
|
- constraint=index_elements, # 对应联合唯一约束
|
|
|
- set_=dict((k, data[k]) for k in update_fields if k in data), # 更新指定字段
|
|
|
- ).returning(self.model.id)
|
|
|
- res = await session.execute(stmt)
|
|
|
- instance.id = res.scalar()
|
|
|
-
|
|
|
- index_elements = self._get_unique_constraint_fields()
|
|
|
- async with self.session_factory() as session:
|
|
|
- if not isinstance(instances, list):
|
|
|
- await exec_one(session, instances, index_elements)
|
|
|
- else:
|
|
|
- for instance in instances:
|
|
|
- await exec_one(session, instance, index_elements)
|
|
|
- await session.commit()
|
|
|
- return instances
|
|
|
-
|
|
|
+ return await self._process_on_conflict(instances, update_fields=True)
|
|
|
+
|
|
|
def _get_unique_constraint_fields(self) -> List[str]:
|
|
|
constraints = getattr(self.model.__table__, 'constraints', [])
|
|
|
unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
|
|
|
@@ -121,33 +116,38 @@ class DouyinBaseRepository(BaseRepository):
|
|
|
- constraint_name 字段唯一值,如果 data 所含的字段存在于数据库则更新该行
|
|
|
return : res
|
|
|
'''
|
|
|
- async def aadd_or_update(self, data: dict):
|
|
|
+ async def aon_conflict_do_update(self, data: dict):
|
|
|
if type(data) == self.model:
|
|
|
data = data.model_dump()
|
|
|
- try:
|
|
|
- index_elements = self._get_unique_constraint_fields()
|
|
|
- async with self.session_factory() as session:
|
|
|
- # 只获取 self.model 定义的字段
|
|
|
- clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
|
|
|
- # logger.debug(f"clean data:{clean_data} from {self.model}")
|
|
|
- # 如果 self.model 中有 update_time 字段,则设置其为当前时间
|
|
|
- if hasattr(self.model, 'update_time'):
|
|
|
- clean_data['update_time'] = func.now()
|
|
|
- # 构建 SQL 语句,实现插入或更新
|
|
|
- insert_stmt = insert(self.model).values(**clean_data)
|
|
|
- update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
- index_elements=index_elements,
|
|
|
- set_={k: clean_data[k] for k in clean_data if k not in index_elements}
|
|
|
- ).returning(self.model.id)
|
|
|
- result = await session.execute(update_stmt)
|
|
|
- new_id = result.scalar()
|
|
|
- await session.commit()
|
|
|
- return new_id
|
|
|
- except IntegrityError as e:
|
|
|
- logger.exception(f"IntegrityError occurred: {e}")
|
|
|
- # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。
|
|
|
- # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。
|
|
|
- except Exception as e:
|
|
|
- # 捕获其他类型的异常
|
|
|
- logger.exception(f"An unexpected error occurred: {e}")
|
|
|
- raise # 如果需要,可以重新抛出异常
|
|
|
+ if hasattr(self.model, 'update_time'):
|
|
|
+ import datetime
|
|
|
+ data['update_time'] = datetime.datetime.now()
|
|
|
+ clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
|
|
|
+ return await super().aon_conflict_do_update(self.model(**clean_data))
|
|
|
+ # try:
|
|
|
+ # index_elements = self._get_unique_constraint_fields()
|
|
|
+ # async with self.session_factory() as session:
|
|
|
+ # # 只获取 self.model 定义的字段
|
|
|
+ # clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
|
|
|
+ # # logger.debug(f"clean data:{clean_data} from {self.model}")
|
|
|
+ # # 如果 self.model 中有 update_time 字段,则设置其为当前时间
|
|
|
+ # if hasattr(self.model, 'update_time'):
|
|
|
+ # clean_data['update_time'] = func.now()
|
|
|
+ # # 构建 SQL 语句,实现插入或更新
|
|
|
+ # insert_stmt = insert(self.model).values(**clean_data)
|
|
|
+ # update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
+ # index_elements=index_elements,
|
|
|
+ # set_={k: clean_data[k] for k in clean_data if k not in index_elements}
|
|
|
+ # ).returning(self.model.id)
|
|
|
+ # result = await session.execute(update_stmt)
|
|
|
+ # new_id = result.scalar()
|
|
|
+ # await session.commit()
|
|
|
+ # return new_id
|
|
|
+ # except IntegrityError as e:
|
|
|
+ # logger.exception(f"IntegrityError occurred: {e}")
|
|
|
+ # # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。
|
|
|
+ # # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。
|
|
|
+ # except Exception as e:
|
|
|
+ # # 捕获其他类型的异常
|
|
|
+ # logger.exception(f"An unexpected error occurred: {e}")
|
|
|
+ # raise # 如果需要,可以重新抛出异常
|