| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- import re
- from typing import List, Any,Dict, Optional,Callable,Union
- import datetime
- import sqlalchemy
- from sqlmodel import SQLModel,UniqueConstraint,Session,select,PrimaryKeyConstraint,Field,Column
- from sqlalchemy.ext.declarative import DeclarativeMeta
- import sqlmodel
- from typing import Optional
- from db.engine import engine
- from config import logger
- from sqlalchemy.sql.elements import BinaryExpression
- from sqlmodel.orm.session import ScalarResult
def update_from_model(target: SQLModel, source: SQLModel, exclude: Optional[list[str]] = None):
    """Copy field values from ``source`` onto ``target`` in place.

    Every key/value pair produced by ``source.model_dump()`` is assigned to
    ``target`` with ``setattr``, except for keys listed in ``exclude``.

    Args:
        target: Object that receives the values (mutated in place).
        source: Object whose ``model_dump()`` supplies the values.
        exclude: Field names that must not be copied. ``None`` (the default)
            means ``["id"]``; pass an empty list to copy every field.
    """
    # A None sentinel replaces the former mutable list default; the effective
    # default is unchanged: "id" is skipped unless the caller says otherwise.
    skipped = ["id"] if exclude is None else exclude
    for key, value in source.model_dump().items():
        if key not in skipped:
            setattr(target, key, value)
class BaseRepository:
    """Generic CRUD helper bound to one SQLModel table class.

    Most methods accept an optional ``ex_session``. When it is given, the
    method works inside that session and leaves committing and closing to the
    caller. When it is omitted, the method opens its own session on
    ``self.engine``, commits where needed, and closes it before returning
    (``select`` is the one exception, see its docstring).

    Usage: ``BaseRepository(Hero).select(Hero.age > 45, Hero.id == 5)``
    """

    # SQLite: "UNIQUE constraint failed: natmodel.pid" (optionally several
    # "table.column" entries separated by commas). "failed" is optional so the
    # previously accepted spelling "unique constraint: table.column" still matches.
    _SQLITE_UNIQUE_RE = re.compile(
        r"unique constraint(?: failed)?:\s*(\w+\.\w+(?:\s*,\s*\w+\.\w+)*)", re.I
    )
    # PostgreSQL (psycopg2): 'DETAIL: Key (from_user_id)=(...) already exists.'
    _PG_UNIQUE_RE = re.compile(r'Key \(([^)]+)\)=.*already exists', re.I)

    def __init__(self, model: type[SQLModel], engine=engine):
        """Bind the repository to ``model`` and cache its column groups.

        Args:
            model: The SQLModel table class this repository operates on.
            engine: Engine used for sessions the repository opens itself.
        """
        self.model = model
        self.engine = engine
        self.primary_key_fields = self.get_primary_key_columns()
        self.unique_constraint_fields = self.get_unique_constraint_columns()
        self.non_unique_fields = self.get_non_unique_columns()

    def _acquire_session(self, ex_session: Optional[Session]):
        """Return ``(session, owned)``; ``owned`` is True when a new session was opened here."""
        if ex_session is not None:
            return ex_session, False
        return Session(bind=self.engine), True

    def select(self, *where: BinaryExpression, ex_session: Optional[Session] = None) -> ScalarResult:
        """Select rows of ``self.model`` matching every expression in ``where``.

        Args:
            *where: Filter expressions, e.g. ``Hero.age > 45``. With none given
                the statement has no filter.
            ex_session: Session to run in; a new one is created when omitted.

        Returns:
            The de-duplicated (``.unique()``) result of the statement.

        Note:
            A session created here is deliberately not closed, because the
            result is handed back to the caller unconsumed. Pass ``ex_session``
            when deterministic session cleanup is required.
        """
        session, _ = self._acquire_session(ex_session)
        statement = select(self.model).where(*where)
        logger.debug(f"{statement}")
        return session.exec(statement).unique()

    def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> SQLModel:
        """Add ``obj_in`` to the session and return it.

        With an external session the object is only added; the caller commits.
        With an internal session the insert is committed, the object is
        refreshed so its attributes stay loaded, and the session is closed.
        """
        session, owned = self._acquire_session(ex_session)
        try:
            session.add(obj_in)
            if owned:
                session.commit()
                # Commit expires attributes; reload them before the session goes away.
                session.refresh(obj_in)
            return obj_in
        finally:
            if owned:
                session.close()

    def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Return the row whose primary key is ``id``, or ``None`` if there is none."""
        session, owned = self._acquire_session(ex_session)
        try:
            return session.get(self.model, id)
        finally:
            if owned:
                session.close()

    def update(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
        """Overwrite the stored row identified by ``obj_in.id`` with ``obj_in``'s values.

        Every key from ``obj_in.model_dump()`` is copied onto the stored row.

        Returns:
            The refreshed stored object, or ``None`` when ``obj_in.id`` is
            falsy or no row with that id exists.
        """
        if not obj_in.id:
            return None
        session, owned = self._acquire_session(ex_session)
        try:
            stored = session.get(self.model, obj_in.id)
            if stored is None:
                # Nothing to update. Refreshing the unsaved ``obj_in`` (as was
                # done before) raises because it is not persistent.
                return None
            for key, value in obj_in.model_dump().items():
                setattr(stored, key, value)
            if owned:
                session.commit()
            session.refresh(stored)
            return stored
        finally:
            if owned:
                session.close()

    def delete(self, *where: BinaryExpression, ex_session: Optional[Session] = None):
        """Delete every row matching ``where``.

        Rows are loaded through ``select`` and deleted one by one. With no
        ``where`` expressions the statement is unfiltered, so all rows match.
        The deletion is committed only when the session was opened here.
        """
        session, owned = self._acquire_session(ex_session)
        try:
            exec_res = self.select(*where, ex_session=session).all()
            logger.info(f"exec_res: {exec_res}")
            for obj in exec_res:
                logger.info(f"del {obj}")
                session.delete(obj)
            if owned:
                session.commit()
        finally:
            if owned:
                session.close()

    def model_dump_by_field(self, obj_model: SQLModel = None, fields: List[str] = None):
        """Return ``{field: value}`` for the given ``fields`` present on ``obj_model``.

        Args:
            obj_model: Object to read from; falls back to ``self.model``.
            fields: Attribute names to read; falls back to ``self.non_unique_fields``.
                Names the object does not have are skipped.
        """
        if not obj_model:
            obj_model = self.model
        if not fields:
            fields = self.non_unique_fields
        return {
            field: getattr(obj_model, field)
            for field in fields
            if hasattr(obj_model, field)
        }

    def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
        """Assign each ``dict_data`` item to ``target_obj`` when the attribute already exists."""
        for key, value in dict_data.items():
            if hasattr(target_obj, key):
                setattr(target_obj, key, value)

    def add_or_update(
        self,
        obj_model: SQLModel,
        not_update_fields: Optional[List[Union[str, Column]]] = None,
    ) -> Optional[SQLModel]:
        """Insert ``obj_model``; on a unique-constraint conflict update the existing row.

        Add-or-update only ever touches a single row. For multi-row work,
        batch it yourself with your own session: an update normally means
        locating a row by a unique field, and locating rows by a non-unique
        field (such as an age) to update them would be disastrous.

        You do not choose which fields are updated: normally every non-unique
        field of the model is written. ``not_update_fields`` can only extend
        the set of fields that are left alone (e.g. IP address, URL, path,
        image name, port number). The existing row is located by the
        column(s) named in the database's unique-violation error.

        Args:
            obj_model: The new data.
            not_update_fields: Fields to leave untouched on conflict, as
                attribute names or ``Column`` objects. When empty/omitted, the
                model's unique-constraint columns are left untouched. When
                given, the conflicting column(s) are added to it. The caller's
                list is never modified.

        Returns:
            The refreshed inserted or updated object, or ``None`` when the
            failure was not a recognisable unique-constraint conflict (the
            error is logged).
        """
        # Normalise to plain names on a private copy: keeps the caller's list
        # intact and lets Column entries compare equal to attribute names.
        skip_fields = [
            field if isinstance(field, str) else field.name
            for field in (not_update_fields or [])
        ]
        with Session(bind=self.engine) as session:
            session.add(obj_model)
            try:
                logger.debug(f"add {obj_model}")
                session.commit()
                session.refresh(obj_model)
                logger.debug(f"refresh {obj_model}")
                return obj_model
            except sqlalchemy.exc.IntegrityError as e:
                # The failed flush leaves the transaction unusable until rolled back.
                session.rollback()
                conflict_field_name = self.extract_conflict_field(str(e))
                if not conflict_field_name:
                    logger.exception(f"Is not UNIQUE constraint error:{e}")
                    return None
                # A multi-column constraint is reported as "a, b".
                conflict_fields = [
                    name.strip().strip('"')
                    for name in conflict_field_name.split(",")
                    if name.strip()
                ]
                unknown = [name for name in conflict_fields if not hasattr(self.model, name)]
                if unknown:
                    logger.error(f"conflict column(s) {unknown} are not fields of {self.model}: {e}")
                    return None
                if not skip_fields:
                    # No custom list: by default the unique fields are not updated.
                    skip_fields = list(self.unique_constraint_fields)
                else:
                    # Custom list: additionally protect the column(s) that conflicted.
                    skip_fields.extend(n for n in conflict_fields if n not in skip_fields)
                logger.debug(f"conflict_field_name: {conflict_field_name}")
                logger.debug(f"not_update_fields: {skip_fields}")
                statement = select(self.model).where(
                    *[getattr(self.model, name) == getattr(obj_model, name) for name in conflict_fields]
                )
                existing_obj: SQLModel = session.exec(statement).one()
                logger.debug(f"old: {obj_model}")
                for attr in type(obj_model).model_fields:
                    new_value = getattr(obj_model, attr)
                    # None means "not provided": never overwrite stored data with it.
                    if attr not in skip_fields and new_value is not None:
                        setattr(existing_obj, attr, new_value)
                session.add(existing_obj)
                logger.debug(f"update: {existing_obj}")
                session.commit()
                session.refresh(existing_obj)
                return existing_obj
            except Exception as e:
                logger.exception(f"other error: {e}")
                return None

    def extract_conflict_field(self, error_message):
        """Extract the conflicting column name(s) from a unique-violation message.

        Recognised formats:

        * SQLite::

              sqlalchemy.exc.IntegrityError: (sqlite3.IntegrityError) UNIQUE constraint failed: natmodel.pid
              [SQL: INSERT INTO natmodel ...

        * PostgreSQL with the psycopg2 driver::

              (psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint "comment_from_user_id_key"
              DETAIL: Key (from_user_id)=() already exists.

        Returns:
            The column name (``"pid"``), several names joined as ``"a, b"`` for
            a multi-column constraint, or ``None`` if neither format matches.
        """
        sqlite_match = self._SQLITE_UNIQUE_RE.search(error_message)
        if sqlite_match:
            # "natmodel.pid, natmodel.name" -> "pid, name"
            columns = [
                qualified.strip().rsplit(".", 1)[-1]
                for qualified in sqlite_match.group(1).split(",")
            ]
            return ", ".join(columns)
        pg_match = self._PG_UNIQUE_RE.search(error_message)
        if pg_match:
            return pg_match.group(1)  # PostgreSQL column name(s)
        return None

    def check_exist(self, obj: SQLModel, check_field: List[str] = None, ex_session=None):
        """Return the first stored row whose ``check_field`` values equal those of ``obj``.

        Args:
            obj: Object supplying the values to look for.
            check_field: Attribute names to compare; defaults to the model's
                unique-constraint columns.
            ex_session: Session to run in; a new one is created when omitted.

        Returns:
            The first matching row, or ``None`` when nothing matches or there
            are no fields to compare.
        """
        if not check_field:
            check_field = self.get_unique_constraint_columns()
        if not check_field:
            # Without any condition the query would return an arbitrary row.
            return None
        check_field_dict: Dict[str, Any] = {k: getattr(obj, k) for k in check_field}
        query = select(self.model).where(
            *[getattr(self.model, k) == v for k, v in check_field_dict.items()]
        )
        session, owned = self._acquire_session(ex_session)
        try:
            return session.scalars(query).first()
        finally:
            if owned:
                session.close()

    def _get_column_names(self, column_list: list) -> list:
        """Return the ``name`` of every column in ``column_list``."""
        return [col.name for col in column_list]

    def get_unique_constraint_columns(self) -> list:
        """Return the names of all columns covered by a ``UniqueConstraint`` on the table."""
        constraints = getattr(self.model.__table__, 'constraints', [])
        unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
        return self._get_column_names(
            [column for constraint in unique_constraints for column in constraint.columns]
        )

    def get_primary_key_columns(self) -> list:
        """Return the names of the table's primary-key columns (empty if there is no PK constraint)."""
        primary_key_constraint = next(
            (c for c in self.model.__table__.constraints if isinstance(c, PrimaryKeyConstraint)),
            None,
        )
        return self._get_column_names(primary_key_constraint.columns) if primary_key_constraint else []

    def get_non_unique_columns(self) -> list:
        """Return model field names that are neither unique-constrained nor primary key."""
        # Recomputed rather than read from the cached attributes so the method
        # does not depend on __init__ having assigned them in a particular order.
        unique_and_primary = set(self.get_unique_constraint_columns()) | set(self.get_primary_key_columns())
        return [field for field in self.model.model_fields.keys() if field not in unique_and_primary]

    def set_update_time(self, obj: SQLModel):
        """Stamp ``obj.update_time`` with the current local time if the model has that field.

        Returns:
            The new timestamp, or ``None`` when the model has no ``update_time``.
        """
        if hasattr(self.model, 'update_time'):
            obj.update_time = datetime.datetime.now()
            return obj.update_time
        return None
class DouyinBaseRepository(BaseRepository):
    """``BaseRepository`` extended with dict-to-model conversion and lookup by ``open_id``."""

    def __init__(self, model: SQLModel, engine=engine):
        """Forward ``model`` and ``engine`` unchanged to ``BaseRepository``."""
        super().__init__(model, engine)

    def dict_to_model(self, dict_data: dict, model=None) -> SQLModel:
        """Build a model instance from ``dict_data``, ignoring unknown keys.

        Args:
            dict_data: Raw key/value data.
            model: Class to instantiate; falls back to ``self.model``.

        Returns:
            An instance built only from the keys that exist as attributes on
            the chosen class.
        """
        target_cls = model if model else self.model
        accepted = {}
        for key, value in dict_data.items():
            # Keep only keys the class actually exposes.
            if hasattr(target_cls, key):
                accepted[key] = value
        return target_cls(**accepted)

    def get_by_open_id(self, open_id):
        """Return the first row whose ``open_id`` equals the argument, or ``None``."""
        with Session(self.engine) as db:
            logger.debug(f"get {open_id}")
            query = select(self.model).where(self.model.open_id == open_id)
            return db.exec(query).first()
def main():
    """Demo: build a repository for a sample ``Hero`` table and log its column groups."""

    class Hero(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)
        name: str
        secret_name: Optional[str]
        age: Optional[int] = None
        class_num: Optional[str]
        student_id: Optional[int]
        __table_args__ = (UniqueConstraint('student_id', name='uq_open_id_ctname'),)

    repo = BaseRepository(Hero, engine)
    # Each label now sits next to the value it names (primary-key and
    # non-unique labels were swapped before).
    logger.info(f"get_primary_key_fields.columns:{repo.get_primary_key_columns()}")
    logger.info(f"get_non_unique_fields:{repo.get_non_unique_columns()}")
    logger.info(f"get_unique_constraint_fields:{repo.get_unique_constraint_columns()}")
-
# Run the demo only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|