```python
import asyncio
from io import BytesIO
from pathlib import Path
from typing import List

import pandas as pd
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from sqlalchemy.dialects.sqlite import insert
from sqlmodel import Session

from config.settings import DB_URL
from mylib.logu import logger
from worker.search_engine.search_result_db import KeywordTask, SearchResultManager

app = APIRouter()


@app.get("/tasks/statistics", summary="Get task statistics")
async def get_task_statistics():
    try:
        db_manager = SearchResultManager()
        stats = db_manager.get_task_statistics()
        return JSONResponse(content=stats)
    except Exception as e:
        logger.error(f"Failed to fetch task statistics: {e}")
        raise HTTPException(status_code=500, detail=str(e))
```
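`get_task_statistics` lives on `SearchResultManager` in `worker.search_engine.search_result_db` and is not shown here. As a rough sketch of the shape this endpoint might return — the function body, and the `status` field it groups by, are assumptions, not the real implementation — it could aggregate `KeywordTask` rows like this:

```python
from sqlalchemy import func
from sqlmodel import Session, select

def get_task_statistics(engine) -> dict:
    """Hypothetical sketch of the manager method; assumes KeywordTask
    has a `status` column. The real implementation may differ."""
    with Session(engine) as session:
        rows = session.exec(
            select(KeywordTask.status, func.count()).group_by(KeywordTask.status)
        ).all()
    stats = {status: count for status, count in rows}
    stats["total"] = sum(stats.values())
    return stats
```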
- @app.post("/upload",
- summary="导入关键词文件",
- description="支持Excel和CSV文件,读取第一列作为关键词")
- async def import_keywords(file: UploadFile = File(...)):
- try:
- db_manager = SearchResultManager()
- logger.info(f"数据库连接: {DB_URL}")
-
- # 异步读取文件内容并使用线程池处理
- content = await file.read()
- loop = asyncio.get_event_loop()
-
- # 在线程池中处理文件读取
- file_path = Path(file.filename)
- df = await loop.run_in_executor(
- None,
- lambda: pd.read_excel(BytesIO(content), header=0) if file_path.suffix.lower() in ('.xlsx', '.xls')
- else pd.read_csv(BytesIO(content), sep='\t' if file_path.suffix == '.tsv' else ',')
- )
- keywords = df.iloc[:, 0].astype(str).tolist()
-
- # 在另一个线程中处理数据库插入
- inserted_count = await asyncio.to_thread(
- bulk_import_keywords_to_db,
- keywords,
- db_manager
- )
-
- return JSONResponse(content={
- "total_keywords": len(keywords),
- "inserted_count": inserted_count,
- "message": "导入成功"
- })
-
- except Exception as e:
- logger.error(f"文件处理失败: {str(e)}")
- raise HTTPException(status_code=500, detail=str(e))
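For a quick end-to-end check, FastAPI's `TestClient` can drive the route without a running server. This is a sketch: the mounting code and the inline CSV are illustrative only, and the router in this module is named `app`, so it still has to be included in a real `FastAPI` application:

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient

api = FastAPI()
api.include_router(app)  # "app" is the APIRouter defined above

client = TestClient(api)

# The first column is read as keywords; on a fresh database the duplicate
# "fastapi" should be skipped by the UPSERT, so inserted_count comes back as 2.
csv_bytes = b"keyword\nfastapi\nsqlite\nfastapi\n"
resp = client.post(
    "/upload",
    files={"file": ("keywords.csv", csv_bytes, "text/csv")},
)
print(resp.json())  # {"total_keywords": 3, "inserted_count": 2, "message": "Import succeeded"}

# The statistics endpoint takes no payload.
print(client.get("/tasks/statistics").json())
```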
```python
def bulk_import_keywords_to_db(keywords: List[str], db_manager: SearchResultManager) -> int:
    """
    Bulk-insert keywords using SQLite's UPSERT syntax, skipping duplicates.
    """
    try:
        # The session context manager rolls back any uncommitted work and
        # closes the session if an exception escapes the block, so no
        # explicit rollback is needed in the except clause.
        with Session(db_manager.engine) as session:
            # ON CONFLICT DO NOTHING resolves against the unique index
            # on "keyword"; conflicting rows are silently dropped.
            stmt = insert(KeywordTask).values(
                [{"keyword": kw} for kw in keywords]
            ).on_conflict_do_nothing(index_elements=["keyword"])

            result = session.exec(stmt)
            session.commit()

            # rowcount reflects only the rows actually inserted,
            # i.e. duplicates are not counted.
            inserted_count = result.rowcount

        logger.info(f"Imported {inserted_count} new keywords")
        return inserted_count

    except Exception as e:
        logger.error(f"Bulk import failed: {e}")
        raise
```
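The `ON CONFLICT` clause above only deduplicates if the `keyword` column carries a unique index; without one, SQLite raises an `OperationalError`. The real `KeywordTask` lives in `worker.search_engine.search_result_db` and may have more fields — this is just a minimal sketch of the constraint the UPSERT relies on:

```python
from typing import Optional
from sqlmodel import Field, SQLModel

class KeywordTask(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    # This unique index is what on_conflict_do_nothing(index_elements=["keyword"])
    # resolves against.
    keyword: str = Field(unique=True, index=True)
```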
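One caveat: the multi-row `VALUES` insert binds one parameter per keyword, and SQLite caps bound parameters per statement (999 in builds before 3.32, 32766 after). Very large uploads can exceed that in a single statement. A thin chunking wrapper avoids this; the function name and the chunk size of 500 are just a conservative sketch:

```python
def bulk_import_keywords_in_chunks(
    keywords: List[str],
    db_manager: SearchResultManager,
    chunk_size: int = 500,
) -> int:
    """Split the keyword list so each INSERT stays under SQLite's
    bound-parameter limit, and sum the per-chunk insert counts."""
    total = 0
    for start in range(0, len(keywords), chunk_size):
        total += bulk_import_keywords_to_db(keywords[start:start + chunk_size], db_manager)
    return total
```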