Quellcode durchsuchen

同步方式:完成用户表和文档表的增删改查

qyl vor 2 Jahren
Ursprung
Commit
221a65d5be
5 geänderte Dateien mit 305 neuen und 319 gelöschten Zeilen
  1. 1 1
      config.py
  2. 127 205
      db/base.py
  3. 104 65
      db/docs.py
  4. 3 8
      db/engine.py
  5. 70 40
      db/user.py

+ 1 - 1
config.py

@@ -19,7 +19,7 @@ logger.info(f"load config:{ __file__}")
 
 # openssl rand -hex 32
 os.environ["SECRET_KEY"]="34581f02dcdbab9dc176d6bb578fb15cc6b8e66159865c14e1fc81cd1d92c2a6"
-os.environ["DB_URL"]="postgresql+asyncpg://pg:pg@sv-v:5432/douyin"
+os.environ["DB_URL"]="postgresql://pg:pg@sv-v:5432/douyin"
 os.environ["CLIENT_KEY"] = 'aw6aipmfdtplwtyq'
 os.environ["CLIENT_SECRET"] = '53cf3dcd2663629e8a773ab59df0968b'
 DOUYIN_OPEN_API="https://open.douyin.com"

+ 127 - 205
db/base.py

@@ -1,218 +1,140 @@
-from typing import List, Any,Optional,Callable
+from typing import List, Any,Dict, Optional,Callable
 import datetime 
-from sqlmodel import SQLModel
-from sqlalchemy.orm import sessionmaker 
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import SQLModel,UniqueConstraint,Session,select,PrimaryKeyConstraint
 from sqlalchemy.ext.declarative import DeclarativeMeta  
-from sqlalchemy.sql.expression import select
-from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy import UniqueConstraint
-from sqlalchemy.sql import func
 import sqlmodel
+from typing import Optional
 from db.engine import engine
 from config import logger
 
-class BaseRepository:  
-    def __init__(self, model: SQLModel, engine=engine):  
-        self.model = model  
-        self.engine = engine  
-        self.session_factory = sessionmaker(  
-            bind=engine, class_=AsyncSession, expire_on_commit=False  
-        )  
-    
-    '''
-    input: 
-      - instances: SQLModel 定义的 class 实例
-    return : instances
-    '''
-    async def aadd(self, instances: List[SQLModel]): 
-        if not isinstance(instances, list):  
-            instances = [instances]  
-        async with self.session_factory() as session:  
-            session.add_all(instances)  
-            await session.commit()  
-            return instances  
-  
-    '''
-    input: 
-      - instances: SQLModel 定义的字段名
-    return : instances:SQLModel
-    '''
-    async def aget(self, **kwargs):  
-        async with self.session_factory() as session:  
-            result = await session.get(self.model, **kwargs)  
-            return result  
-  
-    async def aget_all(self):  
-        async with self.session_factory() as session:  
-            result = await session.execute(select(self.model))  
-            return result.scalars().all() 
-    
-    '''
-    根据 instance 模型判断哪些是唯一约束,执行 on_conflict_insert 事务
-    '''
-    async def aexec_inset_on_conflict(self, 
-                                    instance:SQLModel, 
-                                    session: Optional[AsyncSession],
-                                    update_field:Any) -> SQLModel:
-        # index_elements 获取 self.model 具有唯一属性的字段
-        index_elements = self._get_unique_constraint_fields()
-        instance_dict = instance.model_dump()
-        if isinstance(update_field,self.model):
-            set_values = {key:instance_dict[key] for key in instance_dict if key not in index_elements}
-            stmt = insert(self.model).values(**instance_dict).on_conflict_do_update(
-                index_elements=index_elements,
-                set_=set_values,
-            )
-        elif isinstance(update_field, dict):
-            set_values = {key:update_field[key] for key in update_field if key not in index_elements}
-            logger.debug(f"{index_elements}  {set_values}")
-            stmt = insert(self.model).values(**update_field).on_conflict_do_update(
-                index_elements=index_elements,
-                set_=set_values,
-            )
-        else:
-            stmt = insert(self.model).values(**instance_dict).on_conflict_do_nothing(
-                index_elements=index_elements,
-            )
-        stmt = stmt.returning(self.model.id)
-        res = await session.execute(stmt)
-        # 如果 id 不为 None,说明插入或更新了数据,否则
-        instance.id = res.scalar()
-        # logger.debug(f'{self.model} {"on_conflict_do_update" if update_fields else "on_conflict_do_nothing"}')
-        return instance
-            
-    '''
-    判断是否为列表,自动添加到事务 session.execute(stmt)
-    '''
-    async def aexec_instances_if_list(self, 
-                                instances: SQLModel | List[SQLModel], 
-                                session: Optional[AsyncSession],
-                                update_field:Any,
-                                ) -> SQLModel | List[SQLModel]:
-        if not isinstance(instances, list):
-            await self.aexec_inset_on_conflict(instances, session, update_field)
-        else:
-            for instance in instances:
-                await self.aexec_inset_on_conflict(instance, session, update_field)
-        return instances
-    
-    async def ais_commit(self,
-                        instances: SQLModel | List[SQLModel], 
-                        session: Optional[AsyncSession],
-                        update_field:bool=False,
-                        ) -> SQLModel | List[SQLModel]:
-        if session==None:
-            async with self.session_factory() as session:
-                await session.commit()
-        return instances
-    
-    '''
-    根据 instances 模型判断是否存在唯一约束,存在则不添加数据
-    input:
-      - session: None 说明没有上层事务,自动提交。传入 session 说明使用外部事务来 commit
-    '''
-    async def aon_conflict_do_nothing(self, 
-                                      instances: SQLModel|List[SQLModel], 
-                                      session: Optional[AsyncSession] = None,
-                                      ) -> SQLModel|List[SQLModel]:
-        if session==None:
-            async with self.session_factory() as session:
-                await self.aexec_instances_if_list(instances, session, update_field=None)
-                await session.commit()
-        else:
-            await self.aexec_instances_if_list(instances, session, update_field=None)
-        return instances
+class BaseRepository:
+    def __init__(self, model: SQLModel, engine=engine):
+        self.model = model
+        self.engine = engine
+        self.unique_constraint_fields,self.primary_key_fields = self.get_unique_constraint_fields()
+        self.non_unique_fields = self.get_non_unique_fields()
+        # logger.debug(f"主键字段:{ self.primary_key_fields}")
+        # logger.debug(f"唯一约束字段:{self.unique_constraint_fields}", )
+        # logger.debug(f"非唯一约束字段:{self.non_unique_fields}")
+
+    def create(self, obj_in: SQLModel, ex_session: Optional[Session] = None) -> Optional[int]:
+        session = ex_session or Session(bind=self.engine)
+        session.add(obj_in)
+        if not ex_session:
+            session.commit()
+        return obj_in
+
+    def get(self, id: int, ex_session: Optional[Session] = None) -> Optional[SQLModel]:
+        session = ex_session or Session(bind=self.engine)
+        return session.get(self.model, id)
+
+    def update(self, id: int, obj_in: SQLModel, ex_session: Optional[Session] = None) -> bool:
+        def session_exec(session,obj_in: SQLModel):
+            obj = session.get(self.model, id)
+            if not obj:
+                return False
+            for key, value in obj_in.model_dump().items():
+                setattr(obj, key, value)
         
+        session = ex_session or Session(bind=self.engine)
+        session_exec(session,obj_in)
+        if not ex_session:
+            session.commit()
+        return obj_in
 
-    '''
-    根据 instances 模型判断是否存在唯一约束,存在更新数据
-    input:
-      - session: None 说明没有上层事务,自动提交。传入 session 说明使用外部事务来 commit
-    '''
-    async def aon_conflict_do_update(self,
-                                    instances: SQLModel | List[SQLModel],
-                                    session: Optional[AsyncSession] = None,
-                                    update_field:dict=None
-                                    ) -> SQLModel | List[SQLModel]:
-        if not update_field:
-            update_field = instances
-        if session==None:
-            async with self.session_factory() as session:
-                await self.aexec_instances_if_list(instances, session, update_field=update_field)
-                await session.commit()
-        else:
-            await self.aexec_instances_if_list(instances, session, update_field=update_field)
-        return instances
+    def delete(self, id: int, ex_session: Optional[Session] = None) -> bool:
+        def session_exec(session: Optional[Session],obj_in: SQLModel):
+            obj = session.get(self.model, id)
+            if not obj:
+                return False
+            session.delete(obj)
+            return obj
+        session = ex_session or Session(bind=self.engine)
+        obj = session_exec(session,id)
+        if not ex_session:
+            session.commit()
+        return obj
+    
+    def model_dump_by_field(self, obj_model:SQLModel=None, fields:List[str]=None):
+        if not obj_model:
+            obj_model = self.model
+        if not fields:
+            fields = self.non_unique_fields
+        data = {}
+        for field in fields:
+            if hasattr(obj_model, field):
+                data[field] = getattr(obj_model, field)
+        return data
+    
+    def set_obj_by_dict(self, target_obj: SQLModel, dict_data: dict):
+        for key, value in dict_data.items():
+            if hasattr(target_obj, key):
+                setattr(target_obj, key, value) 
+                
+                   
+    def add_or_update_by_unique(self, obj: SQLModel, ex_session: Optional[Session] = None, update_field=Any) -> bool:
+        def session_exec(session: Optional[Session],obj_in: SQLModel, update_field=Any):
+            table_unique_keys = self.get_unique_constraint_fields()
+            existing_obj = self.check_exist(obj_in, check_field=table_unique_keys, ex_session=session)
+            # logger.debug(f"existing_obj_by_unique_keys: {existing_obj_by_unique_keys} {type(existing_obj_by_unique_keys)}")
+            if existing_obj:
+                logger.debug(f"check {obj.__class__} from fields {table_unique_keys} exist {existing_obj} ")
+                if isinstance(update_field, self.model):
+                    not_unique_keys = [k for k in obj.model_dump().keys() if k not in table_unique_keys]
+                    for key in not_unique_keys:
+                        setattr(existing_obj, key, getattr(obj, key))
+                    logger.debug(f"update non unique filed: {not_unique_keys}")
+                elif isinstance(update_field, dict):
+                    for key in update_field.keys():
+                        setattr(existing_obj, key, getattr(obj, key))
+                    logger.debug(f"update_field: {update_field.keys()}")
+                else:
+                    logger.debug(" do nothing")
+                return existing_obj
+            else:
+                logger.debug(f"add {obj}")
+                session.add(obj)
+                return obj
 
+        session = ex_session or Session(bind=self.engine)
+        obj = session_exec(session,obj,update_field)
+        if not ex_session:
+            session.commit()
+        return obj
     
-    def _get_unique_constraint_fields(self) -> List[str]:
-            constraints = getattr(self.model.__table__, 'constraints', [])
-            unique_constraints = [c for c in constraints if isinstance(c, UniqueConstraint)]
-            
-            index_elements = []
-            for uc in unique_constraints:
-                index_elements.extend([column.name for column in uc.columns])
-            
-            return list(set(index_elements))  # 去除重复字段
+    def check_exist(self, obj: SQLModel, check_field=None, ex_session=None):
+        session = ex_session or Session(bind=self.engine)
+        if not check_field:
+            check_field = self.get_unique_constraint_fields()
+        check_field_dict: Dict[str, Any] = {k: getattr(obj, k) for k in check_field}
+        # logger.debug(f"check if unique constraint: {check_field_dict}")
+        
+        query = select(self.model).where(*[getattr(self.model, k) == v for k, v in check_field_dict.items()])
+        existing_obj_by_check_field = session.scalars(query).first()
+        return existing_obj_by_check_field
+        
+        
+    def get_unique_constraint_fields(self) -> List[str]:  
+        # 获取通过 UniqueConstraint 定义的唯一约束字段  
+        constraints = getattr(self.model.__table__, 'constraints', [])  
+        unique_constraints_fields = [column.name for c in constraints if isinstance(c, UniqueConstraint) for column in c.columns]  
+  
+        # 获取主键字段  
+        primary_key_fields = [column.name for c in constraints if isinstance(c, PrimaryKeyConstraint) for column in c.columns]  
+  
+        # 合并唯一约束字段和主键字段,并去重  
+        # unique_fields = list(set(unique_constraints_fields))  
+          
+        return unique_constraints_fields,primary_key_fields
 
-class DouyinBaseRepository(BaseRepository):
-    def __init__(self, model: DeclarativeMeta, engine=engine):  
-        super().__init__(model, engine)  
+    def get_non_unique_fields(self) -> List[str]:  
+        unique_fields,primary_key_fields = self.get_unique_constraint_fields()
+        unique_fields.extend(primary_key_fields)
+        all_fields = [k for k,v in self.model.model_fields.items()]
+        non_unique_fields = [field for field in all_fields if field not in unique_fields]  
+        return non_unique_fields
 
-    '''
-    input: open_id:str
-    return : SQL row
-    '''
-    async def get_by_open_id(self, open_id):  
-        async with self.session_factory() as session:  
-            stmt = select(self.model).where(self.model.open_id == open_id)  
-            result = await session.execute(stmt)  
-            user_info = result.scalars().first()  # 获取查询结果的第一个记录,如果没有找到则返回 None  
-            return user_info
-    
-    def get_update_time(self):
+    def set_update_time(self, obj: SQLModel):
         if hasattr(self.model, 'update_time'): 
-            return {'update_time': datetime.datetime.now()}
-    '''
-    input: 
-      - data:dict  通常抖音返回数据是json格式,因此这里也用字典传参类型,如果是 SQLmodel 会自动用 data.model_dump() 方法转化成字典
-      - constraint_name 字段唯一值,如果 data 所含的字段存在于数据库则更新该行
-    return : res
-    '''
-    # async def aon_conflict_do_update(self, data: dict, session: Optional[AsyncSession] = None,) -> SQLModel | List[SQLModel]:  
-    #     if type(data) == self.model:
-    #         data = data.model_dump()
-    #     if hasattr(self.model, 'update_time'): 
-    #         import datetime 
-    #         data['update_time'] = datetime.datetime.now()
-    #     clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
-    #     return await super().aon_conflict_do_update(self.model(**clean_data))
-        # try:  
-        #     index_elements = self._get_unique_constraint_fields()
-        #     async with self.session_factory() as session:  
-        #         # 只获取 self.model 定义的字段
-        #         clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
-        #         # logger.debug(f"clean data:{clean_data} from {self.model}") 
-        #         # 如果 self.model 中有 update_time 字段,则设置其为当前时间  
-        #         if hasattr(self.model, 'update_time'):  
-        #             clean_data['update_time'] = func.now() 
-        #         # 构建 SQL 语句,实现插入或更新
-        #         insert_stmt = insert(self.model).values(**clean_data)
-        #         update_stmt = insert_stmt.on_conflict_do_update(  
-        #             index_elements=index_elements,  
-        #             set_={k: clean_data[k] for k in clean_data if k not in index_elements}  
-        #         ).returning(self.model.id)
-        #         result = await session.execute(update_stmt)  
-        #         new_id = result.scalar()
-        #         await session.commit()  
-        #         return new_id  
-        # except IntegrityError as e:  
-        #     logger.exception(f"IntegrityError occurred: {e}")  
-        #     # 如果需要,可以在这里做更多的错误处理,比如回滚事务等。  
-        #     # 但注意,由于使用了async with,session在退出with块时通常会自动回滚未提交的事务。  
-        # except Exception as e:  
-        #     # 捕获其他类型的异常  
-        #     logger.exception(f"An unexpected error occurred: {e}")  
-        #     raise  # 如果需要,可以重新抛出异常
+            obj.update_time =  datetime.datetime.now()
+            return obj.update_time

+ 104 - 65
db/docs.py

@@ -8,29 +8,25 @@ import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from sqlmodel import Field, SQLModel,Column, Integer, Sequence, UniqueConstraint 
+from sqlmodel import Field, SQLModel,Session, Integer, Sequence, UniqueConstraint,select
 from config import DB_URL,logger
 # from db.common import engine
-from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy.sql.sqltypes import Integer, String, DateTime
-from sqlalchemy.sql.schema import Column
-from sqlalchemy import UniqueConstraint
 from pydantic import UUID4
 import uuid
-from db.base import BaseRepository,DouyinBaseRepository
-from db.engine import engine,create_all
+from db.base import BaseRepository
+from db.engine import engine
 
   
 
-class Categories(SQLModel,DouyinBaseRepository, table=True):  
+class Categories(SQLModel, table=True):  
     id: UUID4 = Field(default_factory=uuid.uuid1, primary_key=True)  # 使用 UUID v1 作为主键 
     open_id: str = Field(foreign_key="useroauthtoken.open_id",index=True)  # 关联到用户表的外键  
     name: str = Field(default="default", index=True)  # 分类的名称,添加索引以优化查询性能  
     update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
     # 添加联合唯一约束  
     __table_args__ = (UniqueConstraint('open_id', 'name', name='uq_open_id_ctname'),)
-    
-        
+
+
     
 class DocumentCategories(SQLModel, table=True):
     id: UUID4 = Field(foreign_key="documents.id",index=True, primary_key=True)  # 关联到文档表的外键  
@@ -49,58 +45,81 @@ class Documents(SQLModel, table=True):
     status: int = Field(nullable=False) # 文档状态  
     update_time: datetime = Field(default_factory=datetime.now)  # 创建时间、更新时间
     __table_args__ = (UniqueConstraint('open_id', 'path', name='uq_documents'),) 
+
+class DocumentBase(BaseRepository):
+    def __init__(self, model: Documents, engine=...):
+        super().__init__(model, engine)
     
-class CategoriesRepository(DouyinBaseRepository):  
-    def __init__(self, engine=engine):  
-        super().__init__(Categories, engine)  
+    def before_update(self, obj_model: 'Documents', exist_obj: 'Documents') -> None:  
+        """在更新对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""  
+        pass  
+  
+    def before_create(self, obj_model: 'Documents') -> None:  
+        """在创建对象之前调用的钩子方法,子类可以覆盖以实现自定义逻辑"""  
+        pass      
+    
+    def get_model_dump_field_for_update(self):
+        return self.non_unique_fields
     
-    async def aexec_add_or_update_categ(self, open_id, category_name, session):
-        categ_model = Categories(open_id=open_id,name=category_name)
-        await self.aon_conflict_do_nothing(categ_model, session)
-        if categ_model.id:
-            logger.debug(f"{open_id} add new category:{category_name}")
+    def add_or_update(self, obj_model: Documents, ex_session: Optional[Session] = None) -> SQLModel:
+        session = ex_session or Session(bind=self.engine)
+        exist_obj:Documents = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
+        logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
+        if exist_obj:
+            dict_data = self.model_dump_by_field(obj_model, fields=self.get_model_dump_field_for_update())
+            if dict_data:
+                logger.debug(f"modify table '{self.model.__tablename__}' id '{exist_obj.id}' from {dict_data}")
+            else:
+                logger.debug(f"table '{self.model.__tablename__}' do nothing. exist '{exist_obj}'. update field {self.get_model_dump_field_for_update()}")
+            self.before_update(obj_model, exist_obj)
+            self.set_obj_by_dict(exist_obj,dict_data)
+            if not ex_session:
+                session.commit()
+                # commit 之后 session 随之关闭,exist_obj 被标记为过期并脱离会话,其属性无法再加载
+                return True
+            return exist_obj
         else:
-            logger.debug(f"{open_id} already have same name:{category_name}")
-        return categ_model
+            self.before_create(obj_model)
+            self.create(obj_model,ex_session)
+            if not ex_session:
+                session.commit()
+                return True
+            logger.debug(f"on table '{self.model.__tablename__}' create {obj_model}")
+            return obj_model
 
-        
-class DocumentCategoriesRepository(DouyinBaseRepository):  
+
+class CategoriesRepository(DocumentBase):  
+    def __init__(self, engine=engine):  
+        super().__init__(Categories, engine)  
+
+    # 分类表不需要更新时间
+    def get_model_dump_field_for_update(self):
+        ret:list = self.non_unique_fields
+        if "update_time" in ret:
+            ret.remove("update_time")
+        return ret
+
+class DocumentCategoriesRepository(DocumentBase):  
     def __init__(self, engine=engine):  
         super().__init__(DocumentCategories, engine)  
 
 
-class DocumentsRepository(DouyinBaseRepository):  
+class DocumentsRepository(DocumentBase):  
     def __init__(self, engine=engine):  
         super().__init__(Documents, engine)  
 
-    def get_relative_path(full_path):
-        pattern = r'docs(/.*?)$'  
-        match = re.search(pattern, full_path)  
-        if match:
-            return match.group(1)
-        else:
-            logger.error(f"Can not get rel path:{full_path}")
-    
-    # file_path = {DATA_DIR}/{open_id}/docs/xxx/example_file.pdf
-    async def add_document_with_categories(self, open_id, file_path, category_name="default"):
-        async with self.session_factory() as session:
-            doc_model:Documents = await self.aexec_add_or_update_file(open_id, file_path, session)
-            
+    def add_document_with_categories(self, open_id, file_path, category_name="default"):
+        with Session(bind=self.engine) as session:
+            doc_model:Documents = self.exec_add_or_update_file(open_id, file_path, session)
             cr = CategoriesRepository()
-            category_model = await cr.aexec_add_or_update_categ(open_id, category_name,session)
-            logger.debug(f"category_id:{category_model}")
-            
-            if doc_model.id is not None and category_model.id is not None:
-                doc_categ_model = DocumentCategories(id=doc_model.id, category_id=category_model.id)
-                dr = DocumentCategoriesRepository()
-                logger.info(doc_categ_model)
-                await dr.aon_conflict_do_nothing(doc_categ_model, session)
-                await session.commit()
-            else:
-                logger.info("DocumentCategories no change.")
-        return
+            category_model:Categories = cr.add_or_update(Categories(open_id=open_id, name=category_name),session)
+            dcr = DocumentCategoriesRepository()
+            doc_categ_model = dcr.add_or_update(DocumentCategories(id=doc_model.id, category_id=category_model.id), session)
+            session.commit()
+            return True
     
-    async def aexec_add_or_update_file(self, open_id, file_path, session):
+    def exec_add_or_update_file(self, open_id, file_path, session):
+        # file_path = {DATA_DIR}/{open_id}/docs/xxx/example_file.pdf
         relative_path = DocumentsRepository.get_relative_path(file_path)
         if relative_path == None:
             return
@@ -109,31 +128,51 @@ class DocumentsRepository(DouyinBaseRepository):
                             path=relative_path,
                             status=DocStatus.UNPROCESSED,
                             )
-        # 在同一个 open_id 用户层面上,如果 relative_path 相同,则产生冲突,仅仅更新时间。说明 file_path 同路径下覆盖了新文件
-        # 没有产生冲突,说明不同用户或不同路径下新增了文件
-        document_model:Documents = await self.aon_conflict_do_update(self.instance_model, session)
-        res = self.aget(open_id=open_id, file_path=file_path)
-        logger.info(f"get doc row:{res}")
-        if document_model.id:
-            logger.debug(f"{document_model.open_id} add new file:{document_model.path}")
+        res = self.add_or_update(self.instance_model, session)
+        return res
+
+    def get_relative_path(full_path):
+        pattern = r'docs(/.*?)$'  
+        match = re.search(pattern, full_path)  
+        if match:
+            return match.group(1)
         else:
-            logger.debug(f"{document_model.open_id} overwrite file:{document_model.path}")
-        return document_model
+            logger.error(f"Can not get rel path:{full_path}")
+    
+    
+    def get_user_files_path(self, open_id: str, category_id: Optional[UUID4] = None, category_name: Optional[str] = None) -> List[str]:  
+        with Session(self.engine) as session:  
+            # 基础查询,从 Documents 表中选择 path  
+            base_statement = select(Documents.path).where(Documents.open_id == open_id)  
+              
+            # 根据 category_id 或 category_name 进行过滤  
+            if category_id:  
+                base_statement = base_statement.where(Documents.id.in_(  
+                    select(DocumentCategories.id).where(DocumentCategories.category_id == category_id)  
+                ))  
+            elif category_name:  
+                category_subquery = select(Categories.id).where(Categories.name == category_name)  
+                doc_category_subquery = select(DocumentCategories.id).where(DocumentCategories.category_id.in_(category_subquery))  
+                base_statement = base_statement.where(Documents.id.in_(doc_category_subquery))  
+              
+            # 执行查询并返回结果  
+            results = session.exec(base_statement)
+            return results.all()
         
 # 示例使用  
-async def main():  
+def main():  
     from db.user import test_add
-    open_id = await test_add()
+    open_id = test_add()
     # 创建实例  
     documents_repo = DocumentsRepository()  
-    res = await documents_repo.aget(id=1)
-    logger.info(res)
-    # await documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
+    documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme.md")
+    documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/99readme3.md")
+    documents_repo.add_document_with_categories(open_id,"/home/user/code/open-douyin/open_id/docs/readme5.md")
+    logger.info(documents_repo.get_user_files_path(open_id))
     # 添加分类  
     # doc1 = Documents(open_id=open_id, document_name="docs_fn", status="ready", file_path="/path")
     # doc2 = Documents(open_id=open_id, document_name="docs_jj", status="ready", file_path="/path")
     # 实现有关代码
   
 if __name__ == "__main__":  
-    import asyncio  
-    asyncio.run(main())
+    main()

+ 3 - 8
db/engine.py

@@ -1,11 +1,6 @@
 import asyncio
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine  
-from sqlmodel import Field, SQLModel
+from sqlmodel import Field, SQLModel, create_engine
 from config import DB_URL,logger
+import logging
 
-engine = create_async_engine(DB_URL)  # 替换成你的 DB_URL 
-# SQLModel.metadata.create_all() 
-async def create_all():  
-    async with engine.begin() as conn:  
-        await conn.run_sync(SQLModel.metadata.create_all)  
-  
+engine = create_engine(DB_URL)  # 替换成你的 DB_URL 

+ 70 - 40
db/user.py

@@ -3,6 +3,7 @@ from typing import Optional
 import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from db.engine import engine
 
 from sqlmodel import Field, SQLModel,create_engine,Session,select,func
 import psycopg2
@@ -10,8 +11,7 @@ from config import DB_URL,logger
 # from db.common import engine
 from sqlalchemy import UniqueConstraint, Index
 from sqlalchemy.dialects.postgresql import insert
-from db.base import BaseRepository,DouyinBaseRepository
-from db.engine import engine,create_all
+from db.base import BaseRepository
 
 # 定义数据库模型,不推荐使用 __tablename__ = '"UserInfo"' 来定义包含大写的表名字,
 # 因为可能会导致与其他数据库系统不兼容,而且表查询的时候需要额外注意表格名的大小写
@@ -38,48 +38,75 @@ class UserInfo(SQLModel, table=True):
     union_id: str  
     update_time: datetime = Field(default_factory=datetime.now)  
     __table_args__ = (UniqueConstraint('open_id'),) 
-    
+
+
+
+class DouyinBaseRepository(BaseRepository):
+    def __init__(self, model: SQLModel, engine=engine):  
+        super().__init__(model, engine)  
+
+    def add_or_update(self, dict_data: dict) -> SQLModel:
+        clean_data = {k: v for k, v in dict_data.items() if hasattr(self.model, k)}
+        obj_model = self.model(**clean_data)
+        with Session(bind=self.engine) as session:
+            exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
+            logger.debug(f"check table '{self.model.__tablename__}' where {self.unique_constraint_fields}")
+            if exist_obj:
+                self.set_update_time(exist_obj)
+                dict_data = self.model_dump_by_field(obj_model, self.non_unique_fields)
+                self.set_obj_by_dict(exist_obj,dict_data)
+                logger.info(f"modify table '{self.model.__tablename__}' id '{exist_obj.open_id}' from {dict_data}")
+                session.commit()
+                return exist_obj
+            else:
+                session.commit()
+                return obj_model
+
+
 
 class UserInfoRepository(DouyinBaseRepository):  
     def __init__(self, engine=engine):  
         super().__init__(UserInfo, engine)  
         self.model:UserInfo
-        
-    async def create_user_info(self, user_info_data):  
-        if hasattr(self.model, 'update_time'): 
-            import datetime 
-            user_info_data['update_time'] = datetime.datetime.now()
-        clean_data = {k: v for k, v in user_info_data.items() if hasattr(self.model, k)} 
-        return await self.aon_conflict_do_update(self.model(**clean_data))
     
-    async def update_user_info(self, user_id, user_info_data):  
-        async with self.session_factory() as session:  
-            update_user_info = await session.get(UserInfo, user_id)  
-            if update_user_info:  
-                for key, value in user_info_data.items():  
-                    setattr(update_user_info, key, value)  
-                await session.commit()  
-                return update_user_info  
-  
-    async def delete_user_info(self, user_id):  
-        async with self.session_factory() as session:  
-            delete_user_info = await session.get(UserInfo, user_id)  
-            if delete_user_info:  
-                await session.delete(delete_user_info)  
-                await session.commit() 
-        
+    def add_or_update_by_unique(self, obj_in: dict) -> SQLModel:
+        clean_data = {k: v for k, v in obj_in.items() if hasattr(self.model, k)}
+        obj_model = self.model(**clean_data)
+        with Session(bind=self.engine) as session:
+            exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
+            if exist_obj:
+                self.set_update_time(exist_obj)
+                dict_data = self.model_dump_by_field(obj_model, self.non_unique_fields)
+                self.set_obj_by_dict(exist_obj,dict_data)
+                session.commit()
+                return exist_obj
+            else:
+                self.create(obj_model)
+                session.commit()
+                return obj_model
+    
+    
 # Database manager class
 class UserOAuthRepository(DouyinBaseRepository):  
     def __init__(self, engine=engine):  
         super().__init__(UserOAuthToken, engine)  
         self.model:UserOAuthToken
 
-    async def add_token(self, data: dict):  
-        if hasattr(self.model, 'update_time'): 
-            import datetime 
-            data['update_time'] = datetime.datetime.now()
-        clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)} 
-        return await self.aon_conflict_do_update(self.model(**clean_data))
+    # def add_token(self, data: dict):  
+    #     clean_data = {k: v for k, v in data.items() if hasattr(self.model, k)}
+    #     obj_model = self.model(**clean_data)
+    #     with Session(bind=self.engine) as session:
+    #         exist_obj = self.check_exist(obj_model, self.unique_constraint_fields, ex_session=session)
+    #         if exist_obj:
+    #             self.set_update_time(exist_obj)
+    #             dict_data = self.model_dump_by_field(obj_model, self.unique_constraint_fields)
+    #             self.set_obj_by_dict(exist_obj,dict_data)
+    #             session.commit()
+    #             return exist_obj
+    #         else:
+    #             self.create(obj_model)
+    #             session.commit()
+    #             return obj_model
 
 
     async def delete_token(self, token_id: int):  
@@ -95,8 +122,8 @@ class UserOAuthRepository(DouyinBaseRepository):
 
     
 
-async def test_add():  
-    await create_all()
+def test_add(open_id=None):  
+    SQLModel.metadata.create_all(engine)
     
     user_oauth = {'access_token': 'act.3.wl8L3DFQ3sj3uKYzQShOSs8HbOgKh0FVvjxKeaTum0ZOEXoyBI8D1N7gTBqGbrY32KP-Pm41EAvcobSheOBi8tvRdhj7m5-5ZVoprZZu_GN5J2KnH2fZ_X9_l7Q6iFyvpPoMkX3Zyom3PCkeRZp4Jg9sE2ZiwuvZVdnvft0A25uBWXvj2IEbWW_0Bf8=', 'captcha': '', 'desc_url': '', 'description': '', 'error_code': 0, 'expires_in': 1296000, 'log_id': '20240129123749239735B0529965BC6D93', 'open_id': '_000QadFMhmU1jNCI3JdPnyVDL6XavC70dFy', 'refresh_expires_in': 2592000, 'refresh_token': 'rft.c29d64456ea3d5e4c932247ee93dd735aq5OhtcYNXNFAD70XHKrdntpE6U0', 'scope': 'user_info,trial.whitelist'}
     user_info = {
@@ -104,7 +131,7 @@ async def test_add():
     "avatar_larger": "https://p3.douyinpic.com/aweme/1080x1080/aweme-avatar/tos-cn-i-0813_66c4e34ae8834399bbf967c3d3c919db.jpeg?from=4010531038",
     "captcha": "",
     "city": "",
-    "client_key": "123",
+    "client_key": "55",
     "country": "",
     "desc_url": "",
     "description": "",
@@ -118,14 +145,17 @@ async def test_add():
     "province": "",
     "union_id": "123-01ae-59bd-978a-1de8566186a8"
   }
+    if open_id:
+        user_oauth["open_id"] = open_id
+        user_info["open_id"] = open_id
+        user_info["nickname"] = "user" + open_id[:5]
     db_manager = UserOAuthRepository()
-    res = await db_manager.add_token(user_oauth)
+    res = db_manager.add_or_update(user_oauth)
     # logger.debug(res)
     db_user_info = UserInfoRepository()
-    res = await db_user_info.create_user_info(user_info)
+    res = db_user_info.add_or_update(user_info)
     # logger.debug(res)
     return user_oauth["open_id"]
 
-if __name__ == "__main__":  
-    import asyncio  
-    asyncio.run(test_add())
+if __name__ == "__main__":
+    test_add()