Просмотр исходного кода

refactor: Convert excel_import functions to class-based approach

mrh (aider) 1 год назад
Родитель
Сommit
77bd06e99f
1 измененных файлов с 77 добавлено и 72 удалено
  1. 77 72
      database/excel_import.py

+ 77 - 72
database/excel_import.py

@@ -8,23 +8,6 @@ from sqlalchemy import text
 from database.sqlite_engine import engine
 from typing import Optional
 
-def read_excel_file(file_path: str):
-    return pd.read_excel(file_path)
-
-def process_excel_data(df):
-    """处理Excel数据"""
-    records = []
-    for _, row in df.iterrows():
-        record = {
-            "key_word": row[0],  # 第一列作为key_word
-            "total_pages": row[1] if len(row) > 1 else None,  # 第二列作为total_pages
-            "current_page": 0,  # 新增当前页码
-            "done": False,  # 默认done为False
-            "last_updated": datetime.now()  # 新增最后更新时间
-        }
-        records.append(record)
-    return records
-
 class Keyword(SQLModel, table=True):
     id: int = Field(default=None, primary_key=True)
     key_word: str = Field(unique=True)
@@ -33,69 +16,91 @@ class Keyword(SQLModel, table=True):
     done: bool = Field(default=False)
     last_updated: datetime = Field(default_factory=datetime.now)
 
-def insert_or_update_record(record):
-    with Session(engine) as session:
-        statement = select(Keyword).where(Keyword.key_word == record["key_word"])
-        existing = session.exec(statement).first()
+class ExcelDatabaseManager:
+    def __init__(self):
+        self.engine = engine
         
-        if existing:
-            # 更新现有记录
-            existing.total_pages = record.get("total_pages", existing.total_pages)
-            existing.current_page = record.get("current_page", existing.current_page)
-            existing.done = record.get("done", existing.done)
-            existing.last_updated = datetime.now()
-        else:
-            # 插入新记录
-            new_record = Keyword(**record)
-            session.add(new_record)
-        session.commit()
+    def read_excel_file(self, file_path: str):
+        return pd.read_excel(file_path)
 
-def add_or_update(file_path: str):
-    """从Excel文件导入数据到数据库"""
-    df = read_excel_file(file_path)
-    records = process_excel_data(df)
-    for record in records:
-        insert_or_update_record(record)
+    def process_excel_data(self, df):
+        """处理Excel数据"""
+        records = []
+        for _, row in df.iterrows():
+            record = {
+                "key_word": row[0],  # 第一列作为key_word
+                "total_pages": row[1] if len(row) > 1 else None,  # 第二列作为total_pages
+                "current_page": 0,  # 新增当前页码
+                "done": False,  # 默认done为False
+                "last_updated": datetime.now()  # 新增最后更新时间
+            }
+            records.append(record)
+        return records
 
-def get_keywords_from_db(done: bool = False, limit: int = None):
-    """从数据库获取关键词"""
-    with Session(engine) as session:
-        statement = select(Keyword).where(Keyword.done == done)
-        if limit:
-            statement = statement.limit(limit)
-        keywords = session.exec(statement.order_by(Keyword.last_updated)).all()
-        return keywords
-
-def mark_keyword_done(keyword: str):
-    """标记关键词为已完成"""
-    with Session(engine) as session:
-        statement = select(Keyword).where(Keyword.key_word == keyword)
-        keyword_record = session.exec(statement).first()
-        if keyword_record:
-            keyword_record.done = True
-            keyword_record.last_updated = datetime.now()
+    def insert_or_update_record(self, record):
+        with Session(self.engine) as session:
+            statement = select(Keyword).where(Keyword.key_word == record["key_word"])
+            existing = session.exec(statement).first()
+            
+            if existing:
+                # 更新现有记录
+                existing.total_pages = record.get("total_pages", existing.total_pages)
+                existing.current_page = record.get("current_page", existing.current_page)
+                existing.done = record.get("done", existing.done)
+                existing.last_updated = datetime.now()
+            else:
+                # 插入新记录
+                new_record = Keyword(**record)
+                session.add(new_record)
             session.commit()
 
-def update_keyword_progress(keyword: str, current_page: int):
-    """更新关键词的当前进度"""
-    with Session(engine) as session:
-        statement = select(Keyword).where(Keyword.key_word == keyword)
-        keyword_record = session.exec(statement).first()
-        if keyword_record:
-            keyword_record.current_page = current_page
-            keyword_record.last_updated = datetime.now()
-            session.commit()
+    def add_or_update(self, file_path: str):
+        """从Excel文件导入数据到数据库"""
+        df = self.read_excel_file(file_path)
+        records = self.process_excel_data(df)
+        for record in records:
+            self.insert_or_update_record(record)
+
+    def get_keywords(self, done: bool = False, limit: int = None):
+        """从数据库获取关键词"""
+        with Session(self.engine) as session:
+            statement = select(Keyword).where(Keyword.done == done)
+            if limit:
+                statement = statement.limit(limit)
+            keywords = session.exec(statement.order_by(Keyword.last_updated)).all()
+            return keywords
+
+    def mark_keyword_done(self, keyword: str):
+        """标记关键词为已完成"""
+        with Session(self.engine) as session:
+            statement = select(Keyword).where(Keyword.key_word == keyword)
+            keyword_record = session.exec(statement).first()
+            if keyword_record:
+                keyword_record.done = True
+                keyword_record.last_updated = datetime.now()
+                session.commit()
+
+    def update_keyword_progress(self, keyword: str, current_page: int):
+        """更新关键词的当前进度"""
+        with Session(self.engine) as session:
+            statement = select(Keyword).where(Keyword.key_word == keyword)
+            keyword_record = session.exec(statement).first()
+            if keyword_record:
+                keyword_record.current_page = current_page
+                keyword_record.last_updated = datetime.now()
+                session.commit()
 
-def get_next_keyword():
-    """获取下一个未完成的关键词"""
-    with Session(engine) as session:
-        statement = select(Keyword).where(Keyword.done == False)
-        keyword = session.exec(statement.order_by(Keyword.last_updated)).first()
-        return keyword
+    def get_next_keyword(self):
+        """获取下一个未完成的关键词"""
+        with Session(self.engine) as session:
+            statement = select(Keyword).where(Keyword.done == False)
+            keyword = session.exec(statement.order_by(Keyword.last_updated)).first()
+            return keyword
 
 def main():
-    # add_or_update(r"G:\weixin\WeChat Files\wxid_1fmirgx3vudo21\FileStorage\File\2025-01\测试-精油-2000.xlsx")
-    keywords = get_keywords_from_db()
+    manager = ExcelDatabaseManager()
+    # manager.add_or_update(r"G:\weixin\WeChat Files\wxid_1fmirgx3vudo21\FileStorage\File\2025-01\测试-精油-2000.xlsx")
+    keywords = manager.get_keywords()
     print([k.key_word for k in keywords[:50]])
 
 if __name__ == "__main__":