from fastapi import FastAPI, File, UploadFile, HTTPException, APIRouter
import asyncio
from sqlalchemy.dialects.sqlite import insert
from pathlib import Path
import pandas as pd
from typing import List
from sqlmodel import select, SQLModel, Session
from mylib.logu import logger
from worker.search_engine.search_result_db import SearchResultManager, KeywordTask
from fastapi.responses import JSONResponse
from io import BytesIO
from config.settings import DB_URL

app = APIRouter()


@app.get("/tasks/statistics", summary="获取任务统计信息")
async def get_task_statistics():
    """Return aggregate task statistics from the search-result database as JSON.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        db_manager = SearchResultManager()
        stats = db_manager.get_task_statistics()
        return JSONResponse(content=stats)
    except Exception as e:
        logger.error(f"获取统计信息失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/upload", summary="导入关键词文件", description="支持Excel和CSV文件,读取第一列作为关键词")
async def import_keywords(file: UploadFile = File(...)):
    """Import keywords from an uploaded spreadsheet file.

    Reads the first column of an .xlsx/.xls/.csv/.tsv upload as keywords and
    bulk-inserts them (duplicates skipped via SQLite UPSERT).

    Args:
        file: The uploaded file; format is chosen by filename extension.

    Returns:
        JSONResponse with total_keywords, inserted_count and a message.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        db_manager = SearchResultManager()
        logger.info(f"数据库连接: {DB_URL}")

        # Read the upload fully, then do the (blocking) pandas parse in a
        # worker thread so the event loop stays responsive.
        content = await file.read()
        # file.filename may be None for some clients; fall back to "" so
        # Path() does not raise and the CSV branch is used by default.
        suffix = Path(file.filename or "").suffix.lower()

        def _parse_dataframe() -> pd.DataFrame:
            # Excel by extension; otherwise CSV, tab-separated for .tsv.
            # (suffix is lowered once so .TSV/.XLSX behave like .tsv/.xlsx.)
            if suffix in ('.xlsx', '.xls'):
                return pd.read_excel(BytesIO(content), header=0)
            return pd.read_csv(BytesIO(content), sep='\t' if suffix == '.tsv' else ',')

        df = await asyncio.to_thread(_parse_dataframe)
        keywords = df.iloc[:, 0].astype(str).tolist()

        # Blocking DB work also goes to a worker thread.
        inserted_count = await asyncio.to_thread(
            bulk_import_keywords_to_db, keywords, db_manager
        )

        return JSONResponse(content={
            "total_keywords": len(keywords),
            "inserted_count": inserted_count,
            "message": "导入成功"
        })
    except Exception as e:
        logger.error(f"文件处理失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


def bulk_import_keywords_to_db(keywords: List[str], db_manager: SearchResultManager) -> int:
    """Bulk-insert keywords using SQLite UPSERT (ON CONFLICT DO NOTHING).

    Keywords already present in the table are skipped silently.

    Args:
        keywords: Raw keyword strings; blanks and in-batch duplicates are dropped.
        db_manager: Provides the SQLAlchemy engine to write through.

    Returns:
        Number of rows actually inserted.

    Raises:
        Exception: Re-raises any database error after logging it.
    """
    # De-duplicate up front (dict.fromkeys preserves first-seen order) and
    # drop empty strings; conflicts with rows already in the table are still
    # resolved by the UPSERT clause.
    unique_keywords = list(dict.fromkeys(kw for kw in keywords if kw))
    if not unique_keywords:
        # insert().values([]) raises in SQLAlchemy, so short-circuit here.
        return 0

    # SQLite caps the number of bound parameters per statement, so insert in
    # chunks instead of one giant VALUES list for large uploads.
    chunk_size = 500
    inserted_count = 0
    try:
        with Session(db_manager.engine) as session:
            for start in range(0, len(unique_keywords), chunk_size):
                chunk = unique_keywords[start:start + chunk_size]
                stmt = insert(KeywordTask).values(
                    [{"keyword": kw} for kw in chunk]
                ).on_conflict_do_nothing(index_elements=["keyword"])
                result = session.exec(stmt)
                # rowcount counts only rows actually inserted (conflicts
                # contribute nothing).
                inserted_count += result.rowcount
            session.commit()
        logger.info(f"成功导入 {inserted_count} 个新关键词")
        return inserted_count
    except Exception as e:
        logger.error(f"批量导入失败: {str(e)}")
        # NOTE: no manual rollback — the Session context manager discards any
        # uncommitted work when the `with` block exits on an exception.
        raise