# base.py (6.3 KB)

import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import sqlmodel
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlmodel import PrimaryKeyConstraint, Session, SQLModel, UniqueConstraint, select

from config import logger
from db.engine import engine


  9. class BaseRepository:
  10. def __init__(self, model: SQLModel, engine=engine):
  11. self.model = model
  12. self.engine = engine
  13. self.unique_constraint_fields,self.primary_key_fields = self.get_unique_constraint_fields()
  14. self.non_unique_fields = self.get_non_unique_fields()
  15. # logger.debug(f"主键字段:{ self.primary_key_fields}")
  16. # logger.debug(f"唯一约束字段:{self.unique_constraint_fields}", )
  17. # logger.debug(f"非唯一约束字段:{self.non_unique_fields}")
  18. def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[int]:
  19. session = ex_session or Session(bind=self.engine)
  20. session.add(obj_in)
  21. if not ex_session:
  22. session.commit()
  23. return obj_in
  24. def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
  25. session = ex_session or Session(bind=self.engine)
  26. return session.get(self.model, id)
  27. def update(self, id: int, obj_in: SQLModel, ex_session: Optional[Session] = None) -> bool:
  28. def session_exec(session,obj_in: SQLModel):
  29. obj = session.get(self.model, id)
  30. if not obj:
  31. return False
  32. for key, value in obj_in.model_dump().items():
  33. setattr(obj, key, value)
  34. session = ex_session or Session(bind=self.engine)
  35. session_exec(session,obj_in)
  36. if not ex_session:
  37. session.commit()
  38. return obj_in
  39. def delete(self, id: int, ex_session: Optional[Session] = None) -> bool:
  40. def session_exec(session: Optional[Session],obj_in: SQLModel):
  41. obj = session.get(self.model, id)
  42. if not obj:
  43. return False
  44. session.delete(obj)
  45. return obj
  46. session = ex_session or Session(bind=self.engine)
  47. obj = session_exec(session,id)
  48. if not ex_session:
  49. session.commit()
  50. return obj
  51. def model_dump_by_field(self, obj_model:SQLModel=None, fields:List[str]=None):
  52. if not obj_model:
  53. obj_model = self.model
  54. if not fields:
  55. fields = self.non_unique_fields
  56. data = {}
  57. for field in fields:
  58. if hasattr(obj_model, field):
  59. data[field] = getattr(obj_model, field)
  60. return data
  61. def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
  62. for key, value in dict_data.items():
  63. if hasattr(target_obj, key):
  64. setattr(target_obj, key, value)
  65. def add_or_update_by_unique(self, obj: SQLModel, ex_session: Optional[Session] = None, update_field=Any) -> bool:
  66. def session_exec(session: Optional[Session],obj_in: SQLModel, update_field=Any):
  67. table_unique_keys = self.get_unique_constraint_fields()
  68. existing_obj = self.check_exist(obj_in, check_field=table_unique_keys, ex_session=session)
  69. # logger.debug(f"existing_obj_by_unique_keys: {existing_obj_by_unique_keys} {type(existing_obj_by_unique_keys)}")
  70. if existing_obj:
  71. logger.debug(f"check {obj.__class__} from fields {table_unique_keys} exist {existing_obj} ")
  72. if isinstance(update_field, self.model):
  73. not_unique_keys = [k for k in obj.model_dump().keys() if k not in table_unique_keys]
  74. for key in not_unique_keys:
  75. setattr(existing_obj, key, getattr(obj, key))
  76. logger.debug(f"update non unique filed: {not_unique_keys}")
  77. elif isinstance(update_field, dict):
  78. for key in update_field.keys():
  79. setattr(existing_obj, key, getattr(obj, key))
  80. logger.debug(f"update_field: {update_field.keys()}")
  81. else:
  82. logger.debug(" do nothing")
  83. return existing_obj
  84. else:
  85. logger.debug(f"add {obj}")
  86. session.add(obj)
  87. return obj
  88. session = ex_session or Session(bind=self.engine)
  89. obj = session_exec(session,obj,update_field)
  90. if not ex_session:
  91. session.commit()
  92. return obj
  93. def check_exist(self, obj: SQLModel, check_field=None, ex_session=None):
  94. session = ex_session or Session(bind=self.engine)
  95. if not check_field:
  96. check_field = self.get_unique_constraint_fields()
  97. check_field_dict: Dict[str, Any] = {k: getattr(obj, k) for k in check_field}
  98. # logger.debug(f"check if unique constraint: {check_field_dict}")
  99. query = select(self.model).where(*[getattr(self.model, k) == v for k, v in check_field_dict.items()])
  100. existing_obj_by_check_field = session.scalars(query).first()
  101. return existing_obj_by_check_field
  102. def get_unique_constraint_fields(self) -> List[str]:
  103. # 获取通过 UniqueConstraint 定义的唯一约束字段
  104. constraints = getattr(self.model.__table__, 'constraints', [])
  105. unique_constraints_fields = [column.name for c in constraints if isinstance(c, UniqueConstraint) for column in c.columns]
  106. # 获取主键字段
  107. primary_key_fields = [column.name for c in constraints if isinstance(c, PrimaryKeyConstraint) for column in c.columns]
  108. # 合并唯一约束字段和主键字段,并去重
  109. # unique_fields = list(set(unique_constraints_fields))
  110. return unique_constraints_fields,primary_key_fields
  111. def get_non_unique_fields(self) -> List[str]:
  112. unique_fields,primary_key_fields = self.get_unique_constraint_fields()
  113. unique_fields.extend(primary_key_fields)
  114. all_fields = [k for k,v in self.model.model_fields.items()]
  115. non_unique_fields = [field for field in all_fields if field not in unique_fields]
  116. return non_unique_fields
  117. def set_update_time(self, obj: SQLModel):
  118. if hasattr(self.model, 'update_time'):
  119. obj.update_time = datetime.datetime.now()
  120. return obj.update_time