| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Optional

import sqlmodel
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlmodel import SQLModel, UniqueConstraint, Session, select, PrimaryKeyConstraint

from config import logger
from db.engine import engine
class BaseRepository:
    """Generic CRUD helper around a single SQLModel table model.

    Every public data-access method accepts an optional ``ex_session``:

    * when a session is passed in, the method works inside it and leaves
      committing / closing to the caller;
    * when it is omitted, the method opens its own ``Session`` on
      ``self.engine``, commits where a write happened, and always closes that
      session before returning so its pooled connection is released.
    """

    def __init__(self, model: SQLModel, engine=engine):
        """Bind the repository to ``model`` and cache its field classification.

        :param model: the SQLModel table class managed by this repository.
        :param engine: engine used for internally created sessions; defaults to
            the project-level ``engine`` imported from ``db.engine``.
        """
        self.model = model
        self.engine = engine
        # Column names from UniqueConstraint objects, and from the primary key.
        self.unique_constraint_fields, self.primary_key_fields = self.get_unique_constraint_fields()
        # Model fields that are neither unique-constrained nor primary key.
        self.non_unique_fields = self.get_non_unique_fields()

    def _open_session(self, ex_session: Optional[Session]) -> Tuple[Session, bool]:
        """Return ``(session, owned)``; ``owned`` is True when the session was created here.

        Owned sessions must be committed (for writes) and closed by the caller
        of this helper; external sessions are left to whoever passed them in.
        """
        if ex_session is not None:
            return ex_session, False
        return Session(bind=self.engine), True

    def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> SQLModel:
        """Add ``obj_in`` to the database and return it.

        With an external session the object is only added (no flush/commit).
        Otherwise it is committed, refreshed so database-generated values such
        as the primary key are loaded, and the internal session is closed.
        """
        session, owned = self._open_session(ex_session)
        try:
            session.add(obj_in)
            if owned:
                session.commit()
                # With the default expire_on_commit=True the commit expires all
                # attributes; reload them while the session is still open so
                # the returned (soon detached) object remains readable.
                session.refresh(obj_in)
            return obj_in
        finally:
            if owned:
                session.close()

    def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Return the row of ``self.model`` with primary key ``id``, or None."""
        session, owned = self._open_session(ex_session)
        try:
            return session.get(self.model, id)
        finally:
            if owned:
                session.close()

    def update(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Copy every field of ``obj_in`` onto the stored row with the same ``id``.

        :returns: None if ``obj_in.id`` is falsy; ``obj_in`` itself (untouched,
            not persisted) if no row with that id exists; otherwise the updated
            row, committed and refreshed when no external session was given.
        """
        if not obj_in.id:
            return None
        session, owned = self._open_session(ex_session)
        try:
            db_obj = session.get(self.model, obj_in.id)
            if db_obj is None:
                # Nothing to update. obj_in is not attached to this session,
                # so it must not be refreshed (refresh would raise).
                return obj_in
            for key, value in obj_in.model_dump().items():
                setattr(db_obj, key, value)
            if owned:
                session.commit()
                session.refresh(db_obj)
            return db_obj
        finally:
            if owned:
                session.close()

    def delete(self, id: int, ex_session: Optional[Session] = None) -> Union[SQLModel, bool]:
        """Delete the row with primary key ``id``.

        :returns: False when no such row exists, otherwise the deleted object.
            The deletion is committed only when no external session was given.
        """
        session, owned = self._open_session(ex_session)
        try:
            db_obj = session.get(self.model, id)
            if db_obj is None:
                return False
            session.delete(db_obj)
            if owned:
                session.commit()
            return db_obj
        finally:
            if owned:
                session.close()

    def model_dump_by_field(self, obj_model: Optional[SQLModel] = None, fields: Optional[List[str]] = None) -> Dict[str, Any]:
        """Return ``{field: getattr(obj_model, field)}`` for the requested fields.

        :param obj_model: object to read from; falls back to ``self.model``
            (the class itself, so values are class attributes - NOTE(review):
            confirm callers really want that fallback).
        :param fields: names to read; defaults to ``self.non_unique_fields``.
            Names that ``obj_model`` lacks are skipped silently.
        """
        if not obj_model:
            obj_model = self.model
        if not fields:
            fields = self.non_unique_fields
        return {field: getattr(obj_model, field) for field in fields if hasattr(obj_model, field)}

    def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
        """Set attributes of ``target_obj`` from ``dict_data``, ignoring unknown keys."""
        for key, value in dict_data.items():
            if hasattr(target_obj, key):
                setattr(target_obj, key, value)

    def add_or_update_by_unique(self, obj: SQLModel, ex_session: Optional[Session] = None, update_field=Any) -> SQLModel:
        """Insert ``obj`` or update the row that matches it on the unique fields.

        Existence is checked with :meth:`check_exist` on the model's
        unique-constraint columns (falling back to the primary key when the
        model has none).

        :param update_field: chooses what is copied onto an existing row:
            an instance of ``self.model`` copies every field of ``obj`` except
            unique-constraint and primary-key fields; a ``dict`` copies from
            ``obj`` only the fields named by its keys (the dict's values are
            not used); anything else - including the default sentinel
            ``typing.Any`` - leaves the existing row unchanged.
        :returns: the (possibly updated) existing row, or ``obj`` after it was
            added. Committed and refreshed when no external session was given.
        """
        unique_keys, primary_keys = self.get_unique_constraint_fields()
        # Never overwrite the identity of the matched row.
        protected_keys = set(unique_keys) | set(primary_keys)
        session, owned = self._open_session(ex_session)
        try:
            existing_obj = self.check_exist(obj, check_field=unique_keys, ex_session=session)
            if existing_obj is None:
                logger.debug(f"add {obj}")
                session.add(obj)
                result = obj
            else:
                logger.debug(f"check {obj.__class__} from fields {unique_keys} exist {existing_obj} ")
                if isinstance(update_field, self.model):
                    not_unique_keys = [k for k in obj.model_dump().keys() if k not in protected_keys]
                    for key in not_unique_keys:
                        setattr(existing_obj, key, getattr(obj, key))
                    logger.debug(f"update non unique filed: {not_unique_keys}")
                elif isinstance(update_field, dict):
                    for key in update_field.keys():
                        setattr(existing_obj, key, getattr(obj, key))
                    logger.debug(f"update_field: {update_field.keys()}")
                else:
                    logger.debug(" do nothing")
                result = existing_obj
            if owned:
                session.commit()
                session.refresh(result)
            return result
        finally:
            if owned:
                session.close()

    def check_exist(self, obj: SQLModel, check_field: Optional[List[str]] = None, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Return the first stored row whose ``check_field`` columns equal ``obj``'s values.

        :param check_field: attribute names compared with AND. Defaults to the
            unique-constraint columns, or the primary-key columns when the
            model has no unique constraint.
        :returns: the matching row, or None (also None when there is nothing
            to compare on, instead of matching an arbitrary row).

        NOTE(review): names come from ``Column.name`` and are used as model
        attribute names - assumes they coincide. Columns of several separate
        UniqueConstraints are all ANDed together, so a row clashing on only
        one of those constraints is not detected.
        """
        if not check_field:
            unique_keys, primary_keys = self.get_unique_constraint_fields()
            check_field = unique_keys or primary_keys
        if not check_field:
            # An empty WHERE clause would return the first row of the table.
            return None
        criteria = [getattr(self.model, k) == getattr(obj, k) for k in check_field]
        session, owned = self._open_session(ex_session)
        try:
            return session.scalars(select(self.model).where(*criteria)).first()
        finally:
            if owned:
                session.close()

    def get_unique_constraint_fields(self) -> Tuple[List[str], List[str]]:
        """Return ``(unique_constraint_fields, primary_key_fields)`` of ``self.model``.

        Both lists hold column names read from the table's constraint
        objects: columns of every ``UniqueConstraint`` (flattened) and columns
        of the ``PrimaryKeyConstraint``.
        """
        constraints = getattr(self.model.__table__, 'constraints', [])
        # Columns covered by UniqueConstraint objects.
        unique_constraints_fields = [column.name for c in constraints if isinstance(c, UniqueConstraint) for column in c.columns]
        # Primary-key columns.
        primary_key_fields = [column.name for c in constraints if isinstance(c, PrimaryKeyConstraint) for column in c.columns]
        return unique_constraints_fields, primary_key_fields

    def get_non_unique_fields(self) -> List[str]:
        """Return model field names that are neither unique-constrained nor primary key."""
        unique_fields, primary_key_fields = self.get_unique_constraint_fields()
        excluded = set(unique_fields) | set(primary_key_fields)
        return [field for field in self.model.model_fields if field not in excluded]

    def set_update_time(self, obj: SQLModel) -> Optional[datetime.datetime]:
        """Stamp ``obj.update_time`` with the current local (naive) time.

        :returns: the new timestamp, or None when the model has no
            ``update_time`` field (previously this raised AttributeError).
        """
        if not hasattr(self.model, 'update_time'):
            return None
        obj.update_time = datetime.datetime.now()
        return obj.update_time
|