| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- from typing import List, Any
- from sqlmodel import SQLModel
- from sqlalchemy.orm import sessionmaker
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.ext.declarative import DeclarativeMeta
- from sqlalchemy.sql.expression import select
- from sqlalchemy.dialects.postgresql import insert
- from sqlalchemy.exc import IntegrityError
- from sqlalchemy import UniqueConstraint
- from sqlalchemy.sql import func
- from db.engine import engine
- from config import logger
class BaseRepository:
    """Async data-access helpers for a single SQLModel table class.

    Every method opens its own session from ``self.session_factory``; write
    methods commit before returning. The ``aon_conflict_*`` methods build
    PostgreSQL ``INSERT ... ON CONFLICT`` statements, so they need a
    PostgreSQL engine, and they assume the model has an ``id`` column.
    """

    def __init__(self, model: type[SQLModel], engine=engine):
        """Bind the repository to ``model`` and build an async session factory.

        Args:
            model: The SQLModel table class managed by this repository.
            engine: Engine the sessions are bound to; defaults to the engine
                imported from ``db.engine``.
        """
        self.model = model
        self.engine = engine
        # expire_on_commit=False keeps attributes of returned instances
        # readable after the session has committed and been closed.
        self.session_factory = sessionmaker(
            bind=engine, class_=AsyncSession, expire_on_commit=False
        )

    async def aadd(self, instances: List[SQLModel]) -> List[SQLModel]:
        """Add instances through the ORM and commit.

        Args:
            instances: A list of SQLModel instances; a single instance is
                also accepted and wrapped in a list.

        Returns:
            The list of added instances.
        """
        if not isinstance(instances, list):
            instances = [instances]
        async with self.session_factory() as session:
            session.add_all(instances)
            await session.commit()
        return instances

    async def aget(self, **kwargs):
        """Fetch a single row of ``self.model``.

        Two calling styles are supported:

        * ``aget(ident=pk, ...)`` - forwarded unchanged to
          ``AsyncSession.get`` (primary-key lookup; extra keywords such as
          ``options`` are passed through).
        * ``aget(field=value, ...)`` - filters on the given model fields and
          returns the first matching row.

        Returns:
            The matching model instance, or ``None`` when nothing matches.

        Raises:
            TypeError: If called without any keyword arguments.
        """
        if not kwargs:
            raise TypeError("aget() requires 'ident' or at least one field filter")
        async with self.session_factory() as session:
            if "ident" in kwargs:
                return await session.get(self.model, **kwargs)
            # Field-name lookup: AsyncSession.get only accepts a primary key.
            result = await session.execute(select(self.model).filter_by(**kwargs))
            return result.scalars().first()

    async def aget_all(self):
        """Return every row of ``self.model`` as model instances."""
        async with self.session_factory() as session:
            result = await session.execute(select(self.model))
            return result.scalars().all()

    async def aon_conflict_do_nothing(
        self, instances: SQLModel | List[SQLModel]
    ) -> SQLModel | List[SQLModel]:
        """Insert each instance, silently skipping rows that hit a conflict.

        The conflict target is the column set returned by
        ``_get_unique_constraint_fields``. If the model declares no unique
        constraint, a bare ``ON CONFLICT DO NOTHING`` is used instead.

        Args:
            instances: One SQLModel instance or a list of them.

        Returns:
            ``instances`` unchanged in shape. Each instance's ``id`` is set to
            the inserted row's id, or ``None`` when the insert was skipped
            (PostgreSQL returns no row for a skipped insert).
        """
        index_elements = self._get_unique_constraint_fields()

        def build(values):
            return (
                insert(self.model)
                .values(**values)
                # An empty list would render "ON CONFLICT ()", which is a
                # syntax error; None renders a target-less ON CONFLICT.
                .on_conflict_do_nothing(index_elements=index_elements or None)
                .returning(self.model.id)
            )

        return await self._execute_returning_ids(instances, build)

    async def aon_conflict_do_update(
        self,
        instances: SQLModel | List[SQLModel],
        update_fields: List[str],  # fields to overwrite when a conflict occurs
    ) -> SQLModel | List[SQLModel]:
        """Insert each instance, updating ``update_fields`` on conflict.

        The conflict target is the column set returned by
        ``_get_unique_constraint_fields``. SQLAlchemy rejects DO UPDATE
        without a target, so the model must declare a unique constraint.

        Args:
            instances: One SQLModel instance or a list of them.
            update_fields: Names of the fields whose new values overwrite
                the existing row; names absent from the instance are ignored.

        Returns:
            ``instances`` unchanged in shape, with each ``id`` set to the
            inserted or updated row's id.
        """
        index_elements = self._get_unique_constraint_fields()

        def build(values):
            return (
                insert(self.model)
                .values(**values)
                .on_conflict_do_update(
                    # Column names go to index_elements; constraint= expects
                    # a constraint name or Constraint object, not a list.
                    index_elements=index_elements,
                    set_={k: values[k] for k in update_fields if k in values},
                )
                .returning(self.model.id)
            )

        return await self._execute_returning_ids(instances, build)

    async def _execute_returning_ids(self, instances, build_stmt):
        """Run ``build_stmt(values)`` per instance in one transaction and store each returned id."""
        targets = instances if isinstance(instances, list) else [instances]
        async with self.session_factory() as session:
            for instance in targets:
                result = await session.execute(build_stmt(self._insert_values(instance)))
                instance.id = result.scalar()
            await session.commit()
        return instances

    def _insert_values(self, instance: SQLModel) -> dict[str, Any]:
        """Dump ``instance`` to a dict, dropping primary-key fields that are ``None``.

        An unset key is left out so the database can assign it instead of
        being sent an explicit NULL.
        """
        data = instance.model_dump()
        for column in self.model.__table__.primary_key.columns:
            if column.key in data and data[column.key] is None:
                del data[column.key]
        return data

    def _get_unique_constraint_fields(self) -> List[str]:
        """Return the column names of all UniqueConstraints on the table, de-duplicated.

        NOTE(review): columns of several unique constraints are merged into a
        single list; PostgreSQL's ON CONFLICT inference only succeeds if one
        unique index covers exactly that column set - confirm models using
        the upsert helpers declare a single unique constraint.
        """
        constraints = getattr(self.model.__table__, "constraints", [])
        names = (
            column.name
            for constraint in constraints
            if isinstance(constraint, UniqueConstraint)
            for column in constraint.columns
        )
        # dict.fromkeys removes duplicates while keeping column order within a
        # constraint (set() ordering is arbitrary).
        return list(dict.fromkeys(names))
class DouyinBaseRepository(BaseRepository):
    """Repository for tables holding Douyin data, looked up by ``open_id``."""

    def __init__(self, model: DeclarativeMeta, engine=engine):
        """Bind to ``model`` and ``engine``; see ``BaseRepository.__init__``."""
        super().__init__(model, engine)

    async def get_by_open_id(self, open_id):
        """Return the first row whose ``open_id`` equals ``open_id``.

        Args:
            open_id: The Douyin ``open_id`` to look up.

        Returns:
            The matching model instance, or ``None`` if no row matches.
        """
        async with self.session_factory() as session:
            stmt = select(self.model).where(self.model.open_id == open_id)
            result = await session.execute(stmt)
            return result.scalars().first()

    async def aadd_or_update(self, data: dict | SQLModel):
        """Insert ``data`` or, on a unique-constraint conflict, update that row.

        Args:
            data: Row values as a dict (Douyin responses are JSON, hence the
                dict form). A SQLModel instance is also accepted and converted
                with ``model_dump()``. Keys that are not columns of
                ``self.model`` are ignored.

        Conflict detection uses the columns from
        ``_get_unique_constraint_fields``. On update, every supplied column
        except those conflict columns and the primary key is overwritten.
        If the model has an ``update_time`` column, it is set to ``now()``.

        Returns:
            The ``id`` of the inserted or updated row. Returns ``None`` if an
            ``IntegrityError`` occurred; that error is logged and not re-raised.

        Raises:
            Exception: Any other error is logged and re-raised.
        """
        if isinstance(data, SQLModel):
            data = data.model_dump()
        try:
            index_elements = self._get_unique_constraint_fields()
            table = self.model.__table__
            column_keys = set(table.columns.keys())
            pk_keys = {column.key for column in table.primary_key.columns}

            # Keep only real columns: hasattr() would also accept model
            # methods/relationships that cannot be inserted.
            clean_data = {k: v for k, v in data.items() if k in column_keys}
            # Leave an unset primary key out so the database assigns it.
            for key in pk_keys:
                if key in clean_data and clean_data[key] is None:
                    del clean_data[key]
            if "update_time" in column_keys:
                clean_data["update_time"] = func.now()

            # Never rewrite the conflict-target columns or the primary key.
            protected = set(index_elements) | pk_keys
            upsert_stmt = (
                insert(self.model)
                .values(**clean_data)
                .on_conflict_do_update(
                    index_elements=index_elements,
                    set_={k: v for k, v in clean_data.items() if k not in protected},
                )
                .returning(self.model.id)
            )
            async with self.session_factory() as session:
                result = await session.execute(upsert_stmt)
                new_id = result.scalar()
                await session.commit()
                return new_id
        except IntegrityError as e:
            logger.exception(f"IntegrityError occurred: {e}")
            # Leaving the `async with` block closed the session, discarding the
            # uncommitted transaction; report failure to the caller as None.
            return None
        except Exception as e:
            logger.exception(f"An unexpected error occurred: {e}")
            raise
|