base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import re
  2. from typing import List, Any,Dict, Optional,Callable,Union
  3. import datetime
  4. import sqlalchemy
  5. from sqlmodel import SQLModel,UniqueConstraint,Session,select,PrimaryKeyConstraint,Field,Column
  6. from sqlalchemy.ext.declarative import DeclarativeMeta
  7. import sqlmodel
  8. from typing import Optional
  9. from db.engine import engine
  10. from config import logger
  11. from sqlalchemy.sql.elements import BinaryExpression
  12. from sqlmodel.orm.session import ScalarResult
  13. def update_from_model(target: SQLModel, source: SQLModel, exclude: list[str] = ["id"]):
  14. for key, value in source.model_dump().items():
  15. if key not in exclude:
  16. setattr(target, key, value)
  17. class BaseRepository:
  18. def __init__(self, model: SQLModel, engine=engine):
  19. self.model = model
  20. self.engine = engine
  21. self.primary_key_fields=self.get_primary_key_columns()
  22. self.unique_constraint_fields = self.get_unique_constraint_columns()
  23. self.non_unique_fields = self.get_non_unique_columns()
  24. # logger.debug(f"主键字段:{ self.primary_key_fields}")
  25. # logger.debug(f"唯一约束字段:{self.unique_constraint_fields}", )
  26. # logger.debug(f"非唯一约束字段:{self.non_unique_fields}")
  27. # Usage: BaseRepository().select(Hero.age > 45, Hero.id==5)
  28. def select(self, *where:BinaryExpression, ex_session: Optional[Session] = None) -> ScalarResult:
  29. session = ex_session or Session(bind=self.engine)
  30. statement = select(self.model).where(*where)
  31. logger.debug(f"{statement}")
  32. res = session.exec(statement).unique()
  33. return res
  34. def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[int]:
  35. session = ex_session or Session(bind=self.engine)
  36. session.add(obj_in)
  37. if not ex_session:
  38. session.commit()
  39. return obj_in
  40. def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
  41. session = ex_session or Session(bind=self.engine)
  42. return session.get(self.model, id)
  43. def update(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> bool:
  44. def session_exec(id: int=None, obj_in: SQLModel=None, ex_session: Optional[Session] = None):
  45. obj = session.get(self.model, id)
  46. if not obj:
  47. return obj_in
  48. for key, value in obj_in.model_dump().items():
  49. setattr(obj, key, value)
  50. return obj
  51. if not obj_in.id:
  52. return
  53. session = ex_session or Session(bind=self.engine)
  54. obj = session_exec(obj_in.id,obj_in,session)
  55. if not ex_session:
  56. session.commit()
  57. session.refresh(obj)
  58. return obj
  59. def delete(self, *where:BinaryExpression, ex_session: Optional[Session] = None):
  60. session = ex_session or Session(bind=self.engine)
  61. exec_res = self.select(*where, ex_session=session).all()
  62. logger.info(f"exec_res: {exec_res}")
  63. for obj in exec_res:
  64. logger.info(f"del {obj}")
  65. session.delete(obj)
  66. if not ex_session:
  67. session.commit()
  68. def model_dump_by_field(self, obj_model:SQLModel=None, fields:List[str]=None):
  69. if not obj_model:
  70. obj_model = self.model
  71. if not fields:
  72. fields = self.non_unique_fields
  73. data = {}
  74. for field in fields:
  75. if hasattr(obj_model, field):
  76. data[field] = getattr(obj_model, field)
  77. return data
  78. def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
  79. for key, value in dict_data.items():
  80. if hasattr(target_obj, key):
  81. setattr(target_obj, key, value)
  82. '''
  83. add or update 只能更新一行
  84. 如果多行需要自己传入session组包,因为 update 通常意味着搜索唯一字段来更新,绝对不可能搜索age这种字段进行update,那将会产生灾难后果
  85. 不需要指定更新哪些字段,一般model里面非唯一字段就全部更新,如果指定更新的字段,只能追加唯一字段。(如ip地址,URL地址,路径,镜像名称,端口号,等等)
  86. 剩下搜索条件 where condition,只能搜唯一字段和主键,不传参默认是 uni 不然你更新没有意义(就像搜索“班级”进行更新一样)。
  87. input:
  88. - obj_model: 新的数据 model
  89. - not_update_fields: 如果 except 报错 UNIQUE constraint failed 说明 model 唯一字段冲突,则更新除了 not_update_fields 的值
  90. '''
  91. def add_or_update(self, obj_model:SQLModel, not_update_fields:List[Union[str, Column]]=[]) -> SQLModel:
  92. with Session(bind=self.engine) as session:
  93. session.add(obj_model)
  94. try:
  95. logger.debug(f"add {obj_model}")
  96. session.commit()
  97. session.refresh(obj_model)
  98. logger.debug(f"refresh {obj_model}")
  99. return obj_model
  100. except sqlalchemy.exc.IntegrityError as e:
  101. conflict_field_name = self.extract_conflict_field(str(e))
  102. if not conflict_field_name:
  103. logger.exception(f"Is not UNIQUE constraint error:{e}")
  104. return
  105. if not not_update_fields:
  106. # 如果没有定义不更新的字段,默认不更新唯一字段
  107. not_update_fields = self.unique_constraint_fields
  108. else:
  109. # 如果自定义了不更新的字段,则在这些字段的基础上,添加目前 except 报错冲突的字段
  110. not_update_fields.append(conflict_field_name)
  111. logger.debug(f"conflict_field_name: {conflict_field_name}")
  112. logger.debug(f"not_update_fields: {not_update_fields}")
  113. session.rollback()
  114. statement = select(self.model).where(getattr(self.model, conflict_field_name) == getattr(obj_model, conflict_field_name))
  115. existing_obj:SQLModel = session.exec(statement).one()
  116. logger.debug(f"old: {obj_model}")
  117. for attr in obj_model.model_fields:
  118. if attr not in not_update_fields and getattr(obj_model, attr) is not None:
  119. setattr(existing_obj, attr, getattr(obj_model, attr))
  120. session.add(existing_obj)
  121. logger.debug(f"update: {existing_obj}")
  122. session.commit()
  123. session.refresh(existing_obj)
  124. return existing_obj
  125. except Exception as e:
  126. logger.exception(f"other error: {e}")
  127. def extract_conflict_field(self, error_message):
  128. '''error_message:
  129. 使用 SQLite 报错如下:
  130. sqlalchemy.exc.IntegrityError: (sqlite3.IntegrityError) UNIQUE constraint failed: natmodel.pid
  131. [SQL: INSERT INTO natmodel ...
  132. 或者使用 psycopg2 驱动 PostgreSQL 时报错如下:
  133. (psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint "comment_from_user_id_key"
  134. DETAIL: Key (from_user_id)=() already exists.
  135. '''
  136. # 示例正则表达式,需要根据实际错误消息格式调整
  137. match = re.search(r"unique constraint: (\w+)\.(\w+)", error_message, re.I)
  138. if match:
  139. return match.group(2)
  140. match_pg = re.search(r'Key \(([^)]+)\)=.*already exists', error_message, re.I)
  141. if match_pg:
  142. return match_pg.group(1) # 返回 PostgreSQL 的字段名
  143. return None
  144. def check_exist(self, obj: SQLModel, check_field:List[str]=None, ex_session=None):
  145. session = ex_session or Session(bind=self.engine)
  146. if not check_field:
  147. check_field = self.get_unique_constraint_fields()
  148. check_field_dict: Dict[str, Any] = {k: getattr(obj, k) for k in check_field}
  149. # logger.debug(f"check if unique constraint: {check_field_dict}")
  150. query = select(self.model).where(*[getattr(self.model, k) == v for k, v in check_field_dict.items()])
  151. existing_obj_by_check_field = session.scalars(query).first()
  152. return existing_obj_by_check_field
  153. def _get_column_names(self, column_list: list) -> list:
  154. return [col.name for col in column_list]
  155. def get_unique_constraint_columns(self) -> list:
  156. constraints = getattr(self.model.__table__, 'constraints', [])
  157. unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
  158. return self._get_column_names([column for constraint in unique_constraints for column in constraint.columns])
  159. def get_primary_key_columns(self) -> list:
  160. primary_key_constraint = next((c for c in self.model.__table__.constraints if isinstance(c, PrimaryKeyConstraint)), None)
  161. return self._get_column_names(primary_key_constraint.columns) if primary_key_constraint else []
  162. def get_non_unique_columns(self) -> list:
  163. unique_and_primary = self.unique_constraint_fields + self.get_primary_key_columns()
  164. all_fields = [field_name for field_name in self.model.model_fields.keys()]
  165. return [field for field in all_fields if field not in unique_and_primary]
  166. def set_update_time(self, obj: SQLModel):
  167. if hasattr(self.model, 'update_time'):
  168. obj.update_time = datetime.datetime.now()
  169. return obj.update_time
  170. class DouyinBaseRepository(BaseRepository):
  171. def __init__(self, model: SQLModel, engine=engine):
  172. super().__init__(model, engine)
  173. def dict_to_model(self, dict_data: dict, model=None) -> SQLModel:
  174. if not model:
  175. model = self.model
  176. clean_data = {k: v for k, v in dict_data.items() if hasattr(model, k)}
  177. obj_model = model(**clean_data)
  178. return obj_model
  179. def get_by_open_id(self, open_id):
  180. with Session(self.engine) as session:
  181. logger.debug(f"get {open_id}")
  182. base_statement = select(self.model).where(self.model.open_id == open_id)
  183. results = session.exec(base_statement)
  184. return results.first()
  185. def main():
  186. class Hero(SQLModel, table=True):
  187. id: Optional[int] = Field(default=None, primary_key=True)
  188. name: str
  189. secret_name: Optional[str]
  190. age: Optional[int] = None
  191. class_num: Optional[str]
  192. student_id:Optional[int]
  193. __table_args__ = (UniqueConstraint('student_id', name='uq_open_id_ctname'),)
  194. h= Hero(name="123", student_id=555)
  195. self = BaseRepository(Hero, engine)
  196. logger.info(f"get_non_unique_fields:{self.get_primary_key_columns()}")
  197. logger.info(f"get_primary_key_fields.columns:{self.get_non_unique_columns()}")
  198. logger.info(f"get_unique_constraint_fields:{self.get_unique_constraint_columns()}")
  199. # logger.info(b.get_primary_key_fields())
  200. # logger.info(b.get_non_unique_fields())
  201. # logger.info(b.get_unique_constraint_fields())
  202. # logger.info(f"b.model.model_fields:{b.model.model_fields} {type(b.model.model_fields.get('id'))}")
  203. # logger.info(f"b.model.model_fields:{Hero.model_validate(h)}")
  204. if __name__ == "__main__":
  205. main()