|
@@ -1,218 +1,140 @@
|
|
|
-from typing import List, Any,Optional,Callable
|
|
|
|
|
|
|
+from typing import List, Any,Dict, Optional,Callable
|
|
|
import datetime
|
|
import datetime
|
|
|
-from sqlmodel import SQLModel
|
|
|
|
|
-from sqlalchemy.orm import sessionmaker
|
|
|
|
|
-from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
+from sqlmodel import SQLModel,UniqueConstraint,Session,select,PrimaryKeyConstraint
|
|
|
from sqlalchemy.ext.declarative import DeclarativeMeta
|
|
from sqlalchemy.ext.declarative import DeclarativeMeta
|
|
|
-from sqlalchemy.sql.expression import select
|
|
|
|
|
-from sqlalchemy.dialects.postgresql import insert
|
|
|
|
|
-from sqlalchemy.exc import IntegrityError
|
|
|
|
|
-from sqlalchemy import UniqueConstraint
|
|
|
|
|
-from sqlalchemy.sql import func
|
|
|
|
|
import sqlmodel
|
|
import sqlmodel
|
|
|
|
|
+from typing import Optional
|
|
|
from db.engine import engine
|
|
from db.engine import engine
|
|
|
from config import logger
|
|
from config import logger
|
|
|
|
|
|
|
|
-class BaseRepository:
|
|
|
|
|
- def __init__(self, model: SQLModel, engine=engine):
|
|
|
|
|
- self.model = model
|
|
|
|
|
- self.engine = engine
|
|
|
|
|
- self.session_factory = sessionmaker(
|
|
|
|
|
- bind=engine, class_=AsyncSession, expire_on_commit=False
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- '''
|
|
|
|
|
- input:
|
|
|
|
|
- - instances: SQLModel 定义的 class 实例
|
|
|
|
|
- return : instances
|
|
|
|
|
- '''
|
|
|
|
|
- async def aadd(self, instances: List[SQLModel]):
|
|
|
|
|
- if not isinstance(instances, list):
|
|
|
|
|
- instances = [instances]
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- session.add_all(instances)
|
|
|
|
|
- await session.commit()
|
|
|
|
|
- return instances
|
|
|
|
|
-
|
|
|
|
|
- '''
|
|
|
|
|
- input:
|
|
|
|
|
- - instances: SQLModel 定义的字段名
|
|
|
|
|
- return : instances:SQLModel
|
|
|
|
|
- '''
|
|
|
|
|
- async def aget(self, **kwargs):
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- result = await session.get(self.model, **kwargs)
|
|
|
|
|
- return result
|
|
|
|
|
-
|
|
|
|
|
- async def aget_all(self):
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- result = await session.execute(select(self.model))
|
|
|
|
|
- return result.scalars().all()
|
|
|
|
|
-
|
|
|
|
|
- '''
|
|
|
|
|
- 根据 instance 模型判断哪些是唯一约束,执行 on_conflict_insert 事务
|
|
|
|
|
- '''
|
|
|
|
|
- async def aexec_inset_on_conflict(self,
|
|
|
|
|
- instance:SQLModel,
|
|
|
|
|
- session: Optional[AsyncSession],
|
|
|
|
|
- update_field:Any) -> SQLModel:
|
|
|
|
|
- # index_elements 获取 self.model 具有唯一属性的字段
|
|
|
|
|
- index_elements = self._get_unique_constraint_fields()
|
|
|
|
|
- instance_dict = instance.model_dump()
|
|
|
|
|
- if isinstance(update_field,self.model):
|
|
|
|
|
- set_values = {key:instance_dict[key] for key in instance_dict if key not in index_elements}
|
|
|
|
|
- stmt = insert(self.model).values(**instance_dict).on_conflict_do_update(
|
|
|
|
|
- index_elements=index_elements,
|
|
|
|
|
- set_=set_values,
|
|
|
|
|
- )
|
|
|
|
|
- elif isinstance(update_field, dict):
|
|
|
|
|
- set_values = {key:update_field[key] for key in update_field if key not in index_elements}
|
|
|
|
|
- logger.debug(f"{index_elements} {set_values}")
|
|
|
|
|
- stmt = insert(self.model).values(**update_field).on_conflict_do_update(
|
|
|
|
|
- index_elements=index_elements,
|
|
|
|
|
- set_=set_values,
|
|
|
|
|
- )
|
|
|
|
|
- else:
|
|
|
|
|
- stmt = insert(self.model).values(**instance_dict).on_conflict_do_nothing(
|
|
|
|
|
- index_elements=index_elements,
|
|
|
|
|
- )
|
|
|
|
|
- stmt = stmt.returning(self.model.id)
|
|
|
|
|
- res = await session.execute(stmt)
|
|
|
|
|
- # 如果 id 不为 None,说明插入或更新了数据,否则
|
|
|
|
|
- instance.id = res.scalar()
|
|
|
|
|
- # logger.debug(f'{self.model} {"on_conflict_do_update" if update_fields else "on_conflict_do_nothing"}')
|
|
|
|
|
- return instance
|
|
|
|
|
-
|
|
|
|
|
- '''
|
|
|
|
|
- 判断是否为列表,自动添加到事务 session.execute(stmt)
|
|
|
|
|
- '''
|
|
|
|
|
- async def aexec_instances_if_list(self,
|
|
|
|
|
- instances: SQLModel | List[SQLModel],
|
|
|
|
|
- session: Optional[AsyncSession],
|
|
|
|
|
- update_field:Any,
|
|
|
|
|
- ) -> SQLModel | List[SQLModel]:
|
|
|
|
|
- if not isinstance(instances, list):
|
|
|
|
|
- await self.aexec_inset_on_conflict(instances, session, update_field)
|
|
|
|
|
- else:
|
|
|
|
|
- for instance in instances:
|
|
|
|
|
- await self.aexec_inset_on_conflict(instance, session, update_field)
|
|
|
|
|
- return instances
|
|
|
|
|
-
|
|
|
|
|
- async def ais_commit(self,
|
|
|
|
|
- instances: SQLModel | List[SQLModel],
|
|
|
|
|
- session: Optional[AsyncSession],
|
|
|
|
|
- update_field:bool=False,
|
|
|
|
|
- ) -> SQLModel | List[SQLModel]:
|
|
|
|
|
- if session==None:
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- await session.commit()
|
|
|
|
|
- return instances
|
|
|
|
|
-
|
|
|
|
|
- '''
|
|
|
|
|
- 根据 instances 模型判断是否存在唯一约束,存在则不添加数据
|
|
|
|
|
- input:
|
|
|
|
|
- - session: None 说明没有上层事务,自动提交。传入 session 说明使用外部事务来 commit
|
|
|
|
|
- '''
|
|
|
|
|
- async def aon_conflict_do_nothing(self,
|
|
|
|
|
- instances: SQLModel|List[SQLModel],
|
|
|
|
|
- session: Optional[AsyncSession] = None,
|
|
|
|
|
- ) -> SQLModel|List[SQLModel]:
|
|
|
|
|
- if session==None:
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- await self.aexec_instances_if_list(instances, session, update_field=None)
|
|
|
|
|
- await session.commit()
|
|
|
|
|
- else:
|
|
|
|
|
- await self.aexec_instances_if_list(instances, session, update_field=None)
|
|
|
|
|
- return instances
|
|
|
|
|
|
|
+class BaseRepository:
|
|
|
|
|
+ def __init__(self, model: SQLModel, engine=engine):
|
|
|
|
|
+ self.model = model
|
|
|
|
|
+ self.engine = engine
|
|
|
|
|
+ self.unique_constraint_fields,self.primary_key_fields = self.get_unique_constraint_fields()
|
|
|
|
|
+ self.non_unique_fields = self.get_non_unique_fields()
|
|
|
|
|
+ # logger.debug(f"主键字段:{ self.primary_key_fields}")
|
|
|
|
|
+ # logger.debug(f"唯一约束字段:{self.unique_constraint_fields}", )
|
|
|
|
|
+ # logger.debug(f"非唯一约束字段:{self.non_unique_fields}")
|
|
|
|
|
+
|
|
|
|
|
+ def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[int]:
|
|
|
|
|
+ session = ex_session or Session(bind=self.engine)
|
|
|
|
|
+ session.add(obj_in)
|
|
|
|
|
+ if not ex_session:
|
|
|
|
|
+ session.commit()
|
|
|
|
|
+ return obj_in
|
|
|
|
|
+
|
|
|
|
|
+ def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
|
|
|
|
|
+ session = ex_session or Session(bind=self.engine)
|
|
|
|
|
+ return session.get(self.model, id)
|
|
|
|
|
+
|
|
|
|
|
+ def update(self, id: int, obj_in: SQLModel, ex_session: Optional[Session] = None) -> bool:
|
|
|
|
|
+ def session_exec(session,obj_in: SQLModel):
|
|
|
|
|
+ obj = session.get(self.model, id)
|
|
|
|
|
+ if not obj:
|
|
|
|
|
+ return False
|
|
|
|
|
+ for key, value in obj_in.model_dump().items():
|
|
|
|
|
+ setattr(obj, key, value)
|
|
|
|
|
|
|
|
|
|
+ session = ex_session or Session(bind=self.engine)
|
|
|
|
|
+ session_exec(session,obj_in)
|
|
|
|
|
+ if not ex_session:
|
|
|
|
|
+ session.commit()
|
|
|
|
|
+ return obj_in
|
|
|
|
|
|
|
|
- '''
|
|
|
|
|
- 根据 instances 模型判断是否存在唯一约束,存在更新数据
|
|
|
|
|
- input:
|
|
|
|
|
- - session: None 说明没有上层事务,自动提交。传入 session 说明使用外部事务来 commit
|
|
|
|
|
- '''
|
|
|
|
|
- async def aon_conflict_do_update(self,
|
|
|
|
|
- instances: SQLModel | List[SQLModel],
|
|
|
|
|
- session: Optional[AsyncSession] = None,
|
|
|
|
|
- update_field:dict=None
|
|
|
|
|
- ) -> SQLModel | List[SQLModel]:
|
|
|
|
|
- if not update_field:
|
|
|
|
|
- update_field = instances
|
|
|
|
|
- if session==None:
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- await self.aexec_instances_if_list(instances, session, update_field=update_field)
|
|
|
|
|
- await session.commit()
|
|
|
|
|
- else:
|
|
|
|
|
- await self.aexec_instances_if_list(instances, session, update_field=update_field)
|
|
|
|
|
- return instances
|
|
|
|
|
|
|
+ def delete(self, id: int, ex_session: Optional[Session] = None) -> bool:
|
|
|
|
|
+ def session_exec(session: Optional[Session],obj_in: SQLModel):
|
|
|
|
|
+ obj = session.get(self.model, id)
|
|
|
|
|
+ if not obj:
|
|
|
|
|
+ return False
|
|
|
|
|
+ session.delete(obj)
|
|
|
|
|
+ return obj
|
|
|
|
|
+ session = ex_session or Session(bind=self.engine)
|
|
|
|
|
+ obj = session_exec(session,id)
|
|
|
|
|
+ if not ex_session:
|
|
|
|
|
+ session.commit()
|
|
|
|
|
+ return obj
|
|
|
|
|
+
|
|
|
|
|
+ def model_dump_by_field(self, obj_model:SQLModel=None, fields:List[str]=None):
|
|
|
|
|
+ if not obj_model:
|
|
|
|
|
+ obj_model = self.model
|
|
|
|
|
+ if not fields:
|
|
|
|
|
+ fields = self.non_unique_fields
|
|
|
|
|
+ data = {}
|
|
|
|
|
+ for field in fields:
|
|
|
|
|
+ if hasattr(obj_model, field):
|
|
|
|
|
+ data[field] = getattr(obj_model, field)
|
|
|
|
|
+ return data
|
|
|
|
|
+
|
|
|
|
|
+ def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
|
|
|
|
|
+ for key, value in dict_data.items():
|
|
|
|
|
+ if hasattr(target_obj, key):
|
|
|
|
|
+ setattr(target_obj, key, value)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ def add_or_update_by_unique(self, obj: SQLModel, ex_session: Optional[Session] = None, update_field=Any) -> bool:
|
|
|
|
|
+ def session_exec(session: Optional[Session],obj_in: SQLModel, update_field=Any):
|
|
|
|
|
+ table_unique_keys = self.get_unique_constraint_fields()
|
|
|
|
|
+ existing_obj = self.check_exist(obj_in, check_field=table_unique_keys, ex_session=session)
|
|
|
|
|
+ # logger.debug(f"existing_obj_by_unique_keys: {existing_obj_by_unique_keys} {type(existing_obj_by_unique_keys)}")
|
|
|
|
|
+ if existing_obj:
|
|
|
|
|
+ logger.debug(f"check {obj.__class__} from fields {table_unique_keys} exist {existing_obj} ")
|
|
|
|
|
+ if isinstance(update_field, self.model):
|
|
|
|
|
+ not_unique_keys = [k for k in obj.model_dump().keys() if k not in table_unique_keys]
|
|
|
|
|
+ for key in not_unique_keys:
|
|
|
|
|
+ setattr(existing_obj, key, getattr(obj, key))
|
|
|
|
|
+ logger.debug(f"update non unique filed: {not_unique_keys}")
|
|
|
|
|
+ elif isinstance(update_field, dict):
|
|
|
|
|
+ for key in update_field.keys():
|
|
|
|
|
+ setattr(existing_obj, key, getattr(obj, key))
|
|
|
|
|
+ logger.debug(f"update_field: {update_field.keys()}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ logger.debug(" do nothing")
|
|
|
|
|
+ return existing_obj
|
|
|
|
|
+ else:
|
|
|
|
|
+ logger.debug(f"add {obj}")
|
|
|
|
|
+ session.add(obj)
|
|
|
|
|
+ return obj
|
|
|
|
|
|
|
|
|
|
+ session = ex_session or Session(bind=self.engine)
|
|
|
|
|
+ obj = session_exec(session,obj,update_field)
|
|
|
|
|
+ if not ex_session:
|
|
|
|
|
+ session.commit()
|
|
|
|
|
+ return obj
|
|
|
|
|
|
|
|
- def _get_unique_constraint_fields(self) -> List[str]:
|
|
|
|
|
- constraints = getattr(self.model.__table__, 'constraints', [])
|
|
|
|
|
- unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
|
|
|
|
|
-
|
|
|
|
|
- index_elements = []
|
|
|
|
|
- for uc in unique_constraints:
|
|
|
|
|
- index_elements.extend([column.name for column in uc.columns])
|
|
|
|
|
-
|
|
|
|
|
- return list(set(index_elements)) # 去除重复字段
|
|
|
|
|
|
|
+ def check_exist(self, obj: SQLModel, check_field=None, ex_session=None):
|
|
|
|
|
+ session = ex_session or Session(bind=self.engine)
|
|
|
|
|
+ if not check_field:
|
|
|
|
|
+ check_field = self.get_unique_constraint_fields()
|
|
|
|
|
+ check_field_dict: Dict[str, Any] = {k: getattr(obj, k) for k in check_field}
|
|
|
|
|
+ # logger.debug(f"check if unique constraint: {check_field_dict}")
|
|
|
|
|
+
|
|
|
|
|
+ query = select(self.model).where(*[getattr(self.model, k) == v for k, v in check_field_dict.items()])
|
|
|
|
|
+ existing_obj_by_check_field = session.scalars(query).first()
|
|
|
|
|
+ return existing_obj_by_check_field
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ def get_unique_constraint_fields(self) -> List[str]:
|
|
|
|
|
+ # 获取通过 UniqueConstraint 定义的唯一约束字段
|
|
|
|
|
+ constraints = getattr(self.model.__table__, 'constraints', [])
|
|
|
|
|
+ unique_constraints_fields = [column.name for c in constraints if isinstance(c, UniqueConstraint) for column in c.columns]
|
|
|
|
|
+
|
|
|
|
|
+ # 获取主键字段
|
|
|
|
|
+ primary_key_fields = [column.name for c in constraints if isinstance(c, PrimaryKeyConstraint) for column in c.columns]
|
|
|
|
|
+
|
|
|
|
|
+ # 合并唯一约束字段和主键字段,并去重
|
|
|
|
|
+ # unique_fields = list(set(unique_constraints_fields))
|
|
|
|
|
+
|
|
|
|
|
+ return unique_constraints_fields,primary_key_fields
|
|
|
|
|
|
|
|
-class DouyinBaseRepository(BaseRepository):
|
|
|
|
|
- def __init__(self, model: DeclarativeMeta, engine=engine):
|
|
|
|
|
- super().__init__(model, engine)
|
|
|
|
|
|
|
+ def get_non_unique_fields(self) -> List[str]:
|
|
|
|
|
+ unique_fields,primary_key_fields = self.get_unique_constraint_fields()
|
|
|
|
|
+ unique_fields.extend(primary_key_fields)
|
|
|
|
|
+ all_fields = [k for k,v in self.model.model_fields.items()]
|
|
|
|
|
+ non_unique_fields = [field for field in all_fields if field not in unique_fields]
|
|
|
|
|
+ return non_unique_fields
|
|
|
|
|
|
|
|
- '''
|
|
|
|
|
- input: open_id:str
|
|
|
|
|
- return : SQL row
|
|
|
|
|
- '''
|
|
|
|
|
- async def get_by_open_id(self, open_id):
|
|
|
|
|
- async with self.session_factory() as session:
|
|
|
|
|
- stmt = select(self.model).where(self.model.open_id == open_id)
|
|
|
|
|
- result = await session.execute(stmt)
|
|
|
|
|
- user_info = result.scalars().first() # 获取查询结果的第一个记录,如果没有找到则返回 None
|
|
|
|
|
- return user_info
|
|
|
|
|
-
|
|
|
|
|
- def get_update_time(self):
|
|
|
|
|
|
|
+ def set_update_time(self, obj: SQLModel):
|
|
|
if hasattr(self.model, 'update_time'):
|
|
if hasattr(self.model, 'update_time'):
|
|
|
- return {'update_time': datetime.datetime.now()}
|
|
|
|
|
- '''
|
|
|
|
|
- input:
|
|
|
|
|
- - data:dict 通常抖音返回数据是json格式,因此这里也用字典传参类型,如果是 SQLmodel 会自动用 data.model_dump() 方法转化成字典
|
|
|
|
|
- - constraint_name 字段唯一值,如果 data 所含的字段存在于数据库则更新该行
|
|
|
|
|
- return : res
|
|
|
|
|
- '''
|
|
|
|
|
- # async def aon_conflict_do_update(self, data: dict, session: Optional[AsyncSession] = None,) -> SQLModel | List[SQLModel]:
|
|
|
|
|
- # if type(data) == self.model:
|
|
|
|
|
- # data = data.model_dump()
|
|
|
|
|
- # if hasattr(self.model, 'update_time'):
|
|
|
|
|
- # import datetime
|
|
|
|
|
- # data['update_time'] = datetime.datetime.now()
|
|
|
|
|
- # clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
|
|
|
|
|
- # return await super().aon_conflict_do_update(self.model(**clean_data))
|
|
|
|
|
- # try:
|
|
|
|
|
- # index_elements = self._get_unique_constraint_fields()
|
|
|
|
|
- # async with self.session_factory() as session:
|
|
|
|
|
- # # 只获取 self.model 定义的字段
|
|
|
|
|
- # clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
|
|
|
|
|
- # # logger.debug(f"clean data:{clean_data} from {self.model}")
|
|
|
|
|
- # # 如果 self.model 中有 update_time 字段,则设置其为当前时间
|
|
|
|
|
- # if hasattr(self.model, 'update_time'):
|
|
|
|
|
- # clean_data['update_time'] = func.now()
|
|
|
|
|
- # # 构建 SQL 语句,实现插入或更新
|
|
|
|
|
- # insert_stmt = insert(self.model).values(**clean_data)
|
|
|
|
|
- # update_stmt = insert_stmt.on_conflict_do_update(
|
|
|
|
|
- # index_elements=index_elements,
|
|
|
|
|
- # set_={k: clean_data[k] for k in clean_data if k not in index_elements}
|
|
|
|
|
- # ).returning(self.model.id)
|
|
|
|
|
- # result = await session.execute(update_stmt)
|
|
|
|
|
- # new_id = result.scalar()
|
|
|
|
|
- # await session.commit()
|
|
|
|
|
- # return new_id
|
|
|
|
|
- # except IntegrityError as e:
|
|
|
|
|
- # logger.exception(f"IntegrityError occurred: {e}")
|
|
|
|
|
- # # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。
|
|
|
|
|
- # # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。
|
|
|
|
|
- # except Exception as e:
|
|
|
|
|
- # # 捕获其他类型的异常
|
|
|
|
|
- # logger.exception(f"An unexpected error occurred: {e}")
|
|
|
|
|
- # raise # 如果需要,可以重新抛出异常
|
|
|
|
|
|
|
+ obj.update_time = datetime.datetime.now()
|
|
|
|
|
+ return obj.update_time
|