import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import sqlmodel
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlmodel import SQLModel, UniqueConstraint, Session, select, PrimaryKeyConstraint

from db.engine import engine
from config import logger


class BaseRepository:
    """Generic CRUD helpers for a single SQLModel table class.

    Every data-access method accepts an optional ``ex_session``. When one is
    given, the work runs inside that caller-owned session and is NOT
    committed -- the caller controls the transaction. When it is omitted, a
    private ``Session`` bound to ``self.engine`` is opened, committed (for
    write operations) and closed before the method returns, so objects
    returned from such calls are detached from any session: their loaded
    column values are readable, but lazy relationships cannot be loaded.
    """

    def __init__(self, model: Type[SQLModel], engine=engine):
        """Bind the repository to ``model`` and cache its key metadata.

        Args:
            model: the SQLModel table class (not an instance) to operate on.
            engine: engine used for privately opened sessions; defaults to the
                ``engine`` imported from ``db.engine``.
        """
        self.model = model
        self.engine = engine
        # Column names covered by UniqueConstraints, and primary-key column names.
        self.unique_constraint_fields, self.primary_key_fields = self.get_unique_constraint_fields()
        # Model fields that are neither unique-constrained nor primary key.
        self.non_unique_fields = self.get_non_unique_fields()

    def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> SQLModel:
        """Add ``obj_in`` to the database and return it.

        With ``ex_session`` the object is only added to that session (no
        commit). Otherwise it is committed in a private session and refreshed
        before that session closes, so DB-generated values (e.g. an
        autoincrement ``id``) are readable on the returned object.
        """
        if ex_session is not None:
            ex_session.add(obj_in)
            return obj_in
        with Session(bind=self.engine) as session:
            session.add(obj_in)
            session.commit()
            # commit() expires the instance (default expire_on_commit); reload
            # it while the session is still open so it is usable after close.
            session.refresh(obj_in)
        return obj_in

    def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Return the row whose primary key is ``id``, or ``None`` if absent."""
        if ex_session is not None:
            return ex_session.get(self.model, id)
        with Session(bind=self.engine) as session:
            return session.get(self.model, id)

    def update(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Copy every field of ``obj_in`` onto the stored row with the same ``id``.

        All fields from ``obj_in.model_dump()`` are copied, including ones left
        at their defaults.

        Returns:
            ``None`` if ``obj_in.id`` is falsy; ``obj_in`` itself (nothing
            written) if no row with that id exists; otherwise the updated
            persistent object -- committed and refreshed when no
            ``ex_session`` is given.
        """
        if not obj_in.id:
            return None

        def _apply(session: Session) -> Optional[SQLModel]:
            # Load the stored row and overwrite it field by field.
            stored = session.get(self.model, obj_in.id)
            if stored is None:
                return None
            for key, value in obj_in.model_dump().items():
                setattr(stored, key, value)
            return stored

        if ex_session is not None:
            stored = _apply(ex_session)
            return obj_in if stored is None else stored

        with Session(bind=self.engine) as session:
            stored = _apply(session)
            if stored is None:
                # Nothing to write. Refreshing obj_in here would raise because
                # it is not attached to this session.
                return obj_in
            session.commit()
            session.refresh(stored)
            return stored

    def delete(self, id: int, ex_session: Optional[Session] = None) -> Union[SQLModel, bool]:
        """Delete the row with primary key ``id``.

        Returns the deleted object, or ``False`` if no such row exists.
        Without ``ex_session`` the deletion is committed in a private session.
        """
        def _apply(session: Session) -> Union[SQLModel, bool]:
            target = session.get(self.model, id)
            if target is None:
                return False
            session.delete(target)
            return target

        if ex_session is not None:
            return _apply(ex_session)
        with Session(bind=self.engine) as session:
            result = _apply(session)
            session.commit()
            return result

    def model_dump_by_field(self, obj_model: Optional[SQLModel] = None, fields: Optional[List[str]] = None) -> Dict[str, Any]:
        """Return ``{field: getattr(obj_model, field)}`` for the requested fields.

        Args:
            obj_model: object to read from; when falsy, ``self.model`` (the
                class itself) is used, so values are class attributes.
            fields: names to read; when falsy, ``self.non_unique_fields``.
                Names ``obj_model`` does not have are skipped.
        """
        if not obj_model:
            obj_model = self.model
        if not fields:
            fields = self.non_unique_fields
        return {field: getattr(obj_model, field) for field in fields if hasattr(obj_model, field)}

    def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
        """Set each ``dict_data`` item on ``target_obj``, skipping keys it lacks."""
        for key, value in dict_data.items():
            if hasattr(target_obj, key):
                setattr(target_obj, key, value)

    def add_or_update_by_unique(self, obj: SQLModel, ex_session: Optional[Session] = None, update_field=Any) -> SQLModel:
        """Insert ``obj``, or update the existing row sharing its unique-constraint values.

        The existing row is looked up with :meth:`check_exist` on the
        UniqueConstraint columns. When a match is found, ``update_field``
        selects what is copied from ``obj`` onto it:

        * an instance of ``self.model`` -- every field of ``obj`` except the
          unique-constraint and primary-key columns;
        * a ``dict`` -- only the fields named by its keys (values unused);
        * anything else (the default) -- nothing; the row is left unchanged.

        Returns the matched row if one exists, else ``obj`` (newly added).
        Without ``ex_session`` the change is committed in a private session and
        the returned object is refreshed before that session closes.
        """
        unique_keys, primary_keys = self.get_unique_constraint_fields()

        def _apply(session: Session) -> SQLModel:
            existing_obj = self.check_exist(obj, check_field=unique_keys, ex_session=session)
            if existing_obj is None:
                logger.debug(f"add {obj}")
                session.add(obj)
                return obj

            logger.debug(f"check {obj.__class__} from fields {unique_keys} exist {existing_obj} ")
            if isinstance(update_field, self.model):
                # Never overwrite the matched row's identity: its unique columns
                # already agree with obj, and copying obj's primary key (often
                # None on a fresh object) would try to change the row's id.
                protected = set(unique_keys) | set(primary_keys)
                not_unique_keys = [k for k in obj.model_dump() if k not in protected]
                for key in not_unique_keys:
                    setattr(existing_obj, key, getattr(obj, key))
                logger.debug(f"update non unique filed: {not_unique_keys}")
            elif isinstance(update_field, dict):
                for key in update_field:
                    setattr(existing_obj, key, getattr(obj, key))
                logger.debug(f"update_field: {update_field.keys()}")
            else:
                logger.debug(" do nothing")
            return existing_obj

        if ex_session is not None:
            return _apply(ex_session)
        with Session(bind=self.engine) as session:
            result = _apply(session)
            session.commit()
            session.refresh(result)
            return result

    def check_exist(self, obj: SQLModel, check_field: Optional[List[str]] = None, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Return the first stored row whose ``check_field`` columns equal ``obj``'s, else ``None``.

        ``check_field`` defaults to the model's UniqueConstraint columns. If
        there are no columns to compare, ``None`` is returned rather than
        running an unfiltered query that would match an arbitrary row.

        NOTE(review): columns from several separate UniqueConstraints are
        compared together (AND), so a row that conflicts on only one of those
        constraints is not found by the default lookup.
        """
        if not check_field:
            check_field, _ = self.get_unique_constraint_fields()
        if not check_field:
            return None
        conditions = [getattr(self.model, k) == getattr(obj, k) for k in check_field]
        query = select(self.model).where(*conditions)
        if ex_session is not None:
            return ex_session.scalars(query).first()
        with Session(bind=self.engine) as session:
            return session.scalars(query).first()

    def get_unique_constraint_fields(self) -> Tuple[List[str], List[str]]:
        """Return ``(unique_constraint_columns, primary_key_columns)`` for ``self.model``.

        Column names are read from ``self.model.__table__.constraints``. The
        columns of all UniqueConstraints are merged into one list; duplicates
        are removed, keeping first-seen order.
        """
        constraints = getattr(self.model.__table__, 'constraints', [])

        def _columns_of(kind) -> List[str]:
            names = [column.name for c in constraints if isinstance(c, kind) for column in c.columns]
            return list(dict.fromkeys(names))

        return _columns_of(UniqueConstraint), _columns_of(PrimaryKeyConstraint)

    def get_non_unique_fields(self) -> List[str]:
        """Return model field names that are neither unique-constrained nor primary key."""
        unique_fields, primary_key_fields = self.get_unique_constraint_fields()
        excluded = set(unique_fields) | set(primary_key_fields)
        return [name for name in self.model.model_fields if name not in excluded]

    def set_update_time(self, obj: SQLModel):
        """Stamp ``obj.update_time`` with the current naive local time if the model has that field.

        Returns ``obj.update_time``, or ``None`` when ``obj`` has no such
        attribute (previously this raised ``AttributeError``).
        """
        if hasattr(self.model, 'update_time'):
            obj.update_time = datetime.datetime.now()
        return getattr(obj, 'update_time', None)