# base.py (6.5 KB)
import contextlib
import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import sqlmodel
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlmodel import PrimaryKeyConstraint, Session, SQLModel, UniqueConstraint, select

from config import logger
from db.engine import engine
  9. class BaseRepository:
  10. def __init__(self, model: SQLModel, engine=engine):
  11. self.model = model
  12. self.engine = engine
  13. self.unique_constraint_fields,self.primary_key_fields = self.get_unique_constraint_fields()
  14. self.non_unique_fields = self.get_non_unique_fields()
  15. # logger.debug(f"主键字段:{ self.primary_key_fields}")
  16. # logger.debug(f"唯一约束字段:{self.unique_constraint_fields}", )
  17. # logger.debug(f"非唯一约束字段:{self.non_unique_fields}")
  18. def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[int]:
  19. session = ex_session or Session(bind=self.engine)
  20. session.add(obj_in)
  21. if not ex_session:
  22. session.commit()
  23. return obj_in
  24. def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
  25. session = ex_session or Session(bind=self.engine)
  26. return session.get(self.model, id)
  27. def update(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> bool:
  28. def session_exec(id: int=None, obj_in: SQLModel=None, ex_session: Optional[Session] = None):
  29. obj = session.get(self.model, id)
  30. if not obj:
  31. return obj_in
  32. for key, value in obj_in.model_dump().items():
  33. setattr(obj, key, value)
  34. return obj
  35. if not obj_in.id:
  36. return
  37. session = ex_session or Session(bind=self.engine)
  38. obj = session_exec(obj_in.id,obj_in,session)
  39. if not ex_session:
  40. session.commit()
  41. session.refresh(obj)
  42. return obj
  43. def delete(self, id: int, ex_session: Optional[Session] = None) -> bool:
  44. def session_exec(session: Optional[Session],obj_in: SQLModel):
  45. obj = session.get(self.model, id)
  46. if not obj:
  47. return False
  48. session.delete(obj)
  49. return obj
  50. session = ex_session or Session(bind=self.engine)
  51. obj = session_exec(session,id)
  52. if not ex_session:
  53. session.commit()
  54. return obj
  55. def model_dump_by_field(self, obj_model:SQLModel=None, fields:List[str]=None):
  56. if not obj_model:
  57. obj_model = self.model
  58. if not fields:
  59. fields = self.non_unique_fields
  60. data = {}
  61. for field in fields:
  62. if hasattr(obj_model, field):
  63. data[field] = getattr(obj_model, field)
  64. return data
  65. def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
  66. for key, value in dict_data.items():
  67. if hasattr(target_obj, key):
  68. setattr(target_obj, key, value)
  69. def add_or_update_by_unique(self, obj: SQLModel, ex_session: Optional[Session] = None, update_field=Any) -> bool:
  70. def session_exec(session: Optional[Session],obj_in: SQLModel, update_field=Any):
  71. table_unique_keys = self.get_unique_constraint_fields()
  72. existing_obj = self.check_exist(obj_in, check_field=table_unique_keys, ex_session=session)
  73. # logger.debug(f"existing_obj_by_unique_keys: {existing_obj_by_unique_keys} {type(existing_obj_by_unique_keys)}")
  74. if existing_obj:
  75. logger.debug(f"check {obj.__class__} from fields {table_unique_keys} exist {existing_obj} ")
  76. if isinstance(update_field, self.model):
  77. not_unique_keys = [k for k in obj.model_dump().keys() if k not in table_unique_keys]
  78. for key in not_unique_keys:
  79. setattr(existing_obj, key, getattr(obj, key))
  80. logger.debug(f"update non unique filed: {not_unique_keys}")
  81. elif isinstance(update_field, dict):
  82. for key in update_field.keys():
  83. setattr(existing_obj, key, getattr(obj, key))
  84. logger.debug(f"update_field: {update_field.keys()}")
  85. else:
  86. logger.debug(" do nothing")
  87. return existing_obj
  88. else:
  89. logger.debug(f"add {obj}")
  90. session.add(obj)
  91. return obj
  92. session = ex_session or Session(bind=self.engine)
  93. obj = session_exec(session,obj,update_field)
  94. if not ex_session:
  95. session.commit()
  96. return obj
  97. def check_exist(self, obj: SQLModel, check_field:List[str]=None, ex_session=None):
  98. session = ex_session or Session(bind=self.engine)
  99. if not check_field:
  100. check_field = self.get_unique_constraint_fields()
  101. check_field_dict: Dict[str, Any] = {k: getattr(obj, k) for k in check_field}
  102. # logger.debug(f"check if unique constraint: {check_field_dict}")
  103. query = select(self.model).where(*[getattr(self.model, k) == v for k, v in check_field_dict.items()])
  104. existing_obj_by_check_field = session.scalars(query).first()
  105. return existing_obj_by_check_field
  106. def get_unique_constraint_fields(self) -> List[str]:
  107. # 获取通过 UniqueConstraint 定义的唯一约束字段
  108. constraints = getattr(self.model.__table__, 'constraints', [])
  109. unique_constraints_fields = [column.name for c in constraints if isinstance(c, UniqueConstraint) for column in c.columns]
  110. # 获取主键字段
  111. primary_key_fields = [column.name for c in constraints if isinstance(c, PrimaryKeyConstraint) for column in c.columns]
  112. # 合并唯一约束字段和主键字段,并去重
  113. # unique_fields = list(set(unique_constraints_fields))
  114. return unique_constraints_fields,primary_key_fields
  115. def get_non_unique_fields(self) -> List[str]:
  116. unique_fields,primary_key_fields = self.get_unique_constraint_fields()
  117. unique_fields.extend(primary_key_fields)
  118. all_fields = [k for k,v in self.model.model_fields.items()]
  119. non_unique_fields = [field for field in all_fields if field not in unique_fields]
  120. return non_unique_fields
  121. def set_update_time(self, obj: SQLModel):
  122. if hasattr(self.model, 'update_time'):
  123. obj.update_time = datetime.datetime.now()
  124. return obj.update_time