base.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from typing import List, Any
  2. from sqlmodel import SQLModel
  3. from sqlalchemy.orm import sessionmaker
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy.ext.declarative import DeclarativeMeta
  6. from sqlalchemy.sql.expression import select
  7. from sqlalchemy.dialects.postgresql import insert
  8. from sqlalchemy.exc import IntegrityError
  9. from sqlalchemy import UniqueConstraint
  10. from sqlalchemy.sql import func
  11. from db.engine import engine
  12. from config import logger
  13. class BaseRepository:
  14. def __init__(self, model: SQLModel, engine=engine):
  15. self.model = model
  16. self.engine = engine
  17. self.session_factory = sessionmaker(
  18. bind=engine, class_=AsyncSession, expire_on_commit=False
  19. )
  20. '''
  21. input:
  22. - instances: SQLModel 定义的 class 实例
  23. return : instances
  24. '''
  25. async def aadd(self, instances: List[SQLModel]):
  26. if not isinstance(instances, list):
  27. instances = [instances]
  28. async with self.session_factory() as session:
  29. session.add_all(instances)
  30. await session.commit()
  31. return instances
  32. '''
  33. input:
  34. - instances: SQLModel 定义的字段名
  35. return : instances:SQLModel
  36. '''
  37. async def aget(self, **kwargs):
  38. async with self.session_factory() as session:
  39. result = await session.get(self.model, **kwargs)
  40. return result
  41. async def aget_all(self):
  42. async with self.session_factory() as session:
  43. result = await session.execute(select(self.model))
  44. return result.scalars().all()
  45. async def aon_conflict_do_nothing(self, instances: SQLModel|List[SQLModel]) -> SQLModel|List[SQLModel]:
  46. async def exec_one(session, instance,index_elements):
  47. data = instance.model_dump()
  48. stmt = insert(self.model).values(**data).on_conflict_do_nothing(
  49. index_elements=index_elements # 对应联合唯一约束的列
  50. ).returning(self.model.id)
  51. res = await session.execute(stmt)
  52. instance.id = res.scalar()
  53. index_elements = self._get_unique_constraint_fields()
  54. async with self.session_factory() as session:
  55. if not isinstance(instances, list):
  56. await exec_one(session, instances, index_elements)
  57. else:
  58. for instance in instances:
  59. await exec_one(session, instance, index_elements)
  60. await session.commit()
  61. return instances
  62. async def aon_conflict_do_update(
  63. self,
  64. instances: SQLModel | List[SQLModel],
  65. update_fields: List[str], # 需要更新的字段列表
  66. ) -> SQLModel | List[SQLModel]:
  67. async def exec_one(session, instance, index_elements):
  68. data = instance.model_dump()
  69. stmt = insert(self.model).values(**data).on_conflict_do_update(
  70. constraint=index_elements, # 对应联合唯一约束
  71. set_=dict((k, data[k]) for k in update_fields if k in data), # 更新指定字段
  72. ).returning(self.model.id)
  73. res = await session.execute(stmt)
  74. instance.id = res.scalar()
  75. index_elements = self._get_unique_constraint_fields()
  76. async with self.session_factory() as session:
  77. if not isinstance(instances, list):
  78. await exec_one(session, instances, index_elements)
  79. else:
  80. for instance in instances:
  81. await exec_one(session, instance, index_elements)
  82. await session.commit()
  83. return instances
  84. def _get_unique_constraint_fields(self) -> List[str]:
  85. constraints = getattr(self.model.__table__, 'constraints', [])
  86. unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
  87. index_elements = []
  88. for uc in unique_constraints:
  89. index_elements.extend([column.name for column in uc.columns])
  90. return list(set(index_elements)) # 去除重复字段
  91. class DouyinBaseRepository(BaseRepository):
  92. def __init__(self, model: DeclarativeMeta, engine=engine):
  93. super().__init__(model, engine)
  94. '''
  95. input: open_id:str
  96. return : SQL row
  97. '''
  98. async def get_by_open_id(self, open_id):
  99. async with self.session_factory() as session:
  100. stmt = select(self.model).where(self.model.open_id == open_id)
  101. result = await session.execute(stmt)
  102. user_info = result.scalars().first() # 获取查询结果的第一个记录,如果没有找到则返回 None
  103. return user_info
  104. '''
  105. input:
  106. - data:dict 通常抖音返回数据是json格式,因此这里也用字典传参类型,如果是 SQLmodel 会自动用 data.model_dump() 方法转化成字典
  107. - constraint_name 字段唯一值,如果 data 所含的字段存在于数据库则更新该行
  108. return : res
  109. '''
  110. async def aadd_or_update(self, data: dict):
  111. if type(data) == self.model:
  112. data = data.model_dump()
  113. try:
  114. index_elements = self._get_unique_constraint_fields()
  115. async with self.session_factory() as session:
  116. # 只获取 self.model 定义的字段
  117. clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
  118. # logger.debug(f"clean data:{clean_data} from {self.model}")
  119. # 如果 self.model 中有 update_time 字段,则设置其为当前时间
  120. if hasattr(self.model, 'update_time'):
  121. clean_data['update_time'] = func.now()
  122. # 构建 SQL 语句,实现插入或更新
  123. insert_stmt = insert(self.model).values(**clean_data)
  124. update_stmt = insert_stmt.on_conflict_do_update(
  125. index_elements=index_elements,
  126. set_={k: clean_data[k] for k in clean_data if k not in index_elements}
  127. ).returning(self.model.id)
  128. result = await session.execute(update_stmt)
  129. new_id = result.scalar()
  130. await session.commit()
  131. return new_id
  132. except IntegrityError as e:
  133. logger.exception(f"IntegrityError occurred: {e}")
  134. # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。
  135. # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。
  136. except Exception as e:
  137. # 捕获其他类型的异常
  138. logger.exception(f"An unexpected error occurred: {e}")
  139. raise # 如果需要,可以重新抛出异常