# excel_load.py — FastAPI router: bulk-import keywords from uploaded Excel/CSV files.
  1. from fastapi import FastAPI, File, UploadFile, HTTPException, APIRouter
  2. import asyncio
  3. from sqlalchemy.dialects.sqlite import insert
  4. from pathlib import Path
  5. import pandas as pd
  6. from typing import List
  7. from sqlmodel import select, SQLModel, Session
  8. from mylib.logu import logger
  9. from worker.search_engine.search_result_db import SearchResultManager, KeywordTask
  10. from fastapi.responses import JSONResponse
  11. from io import BytesIO
  12. from config.settings import DB_URL
  13. app = APIRouter()
  14. @app.get("/tasks/statistics", summary="获取任务统计信息")
  15. async def get_task_statistics():
  16. try:
  17. db_manager = SearchResultManager()
  18. stats = db_manager.get_task_statistics()
  19. return JSONResponse(content=stats)
  20. except Exception as e:
  21. logger.error(f"获取统计信息失败: {str(e)}")
  22. raise HTTPException(status_code=500, detail=str(e))
  23. @app.post("/upload",
  24. summary="导入关键词文件",
  25. description="支持Excel和CSV文件,读取第一列作为关键词")
  26. async def import_keywords(file: UploadFile = File(...)):
  27. try:
  28. db_manager = SearchResultManager()
  29. logger.info(f"数据库连接: {DB_URL}")
  30. # 异步读取文件内容并使用线程池处理
  31. content = await file.read()
  32. loop = asyncio.get_event_loop()
  33. # 在线程池中处理文件读取
  34. file_path = Path(file.filename)
  35. df = await loop.run_in_executor(
  36. None,
  37. lambda: pd.read_excel(BytesIO(content), header=0) if file_path.suffix.lower() in ('.xlsx', '.xls')
  38. else pd.read_csv(BytesIO(content), sep='\t' if file_path.suffix == '.tsv' else ',')
  39. )
  40. keywords = df.iloc[:, 0].astype(str).tolist()
  41. # 在另一个线程中处理数据库插入
  42. inserted_count = await asyncio.to_thread(
  43. bulk_import_keywords_to_db,
  44. keywords,
  45. db_manager
  46. )
  47. return JSONResponse(content={
  48. "total_keywords": len(keywords),
  49. "inserted_count": inserted_count,
  50. "message": "导入成功"
  51. })
  52. except Exception as e:
  53. logger.error(f"文件处理失败: {str(e)}")
  54. raise HTTPException(status_code=500, detail=str(e))
  55. def bulk_import_keywords_to_db(keywords: List[str], db_manager: SearchResultManager):
  56. """
  57. 使用 SQLite 的 UPSERT 语法进行批量插入,避免重复
  58. """
  59. try:
  60. with Session(db_manager.engine) as session:
  61. # 使用 SQL 原生批量插入语法
  62. stmt = insert(KeywordTask).values(
  63. [{"keyword": kw} for kw in keywords]
  64. ).on_conflict_do_nothing(index_elements=["keyword"])
  65. # 执行批量插入
  66. result = session.exec(stmt)
  67. session.commit()
  68. # 获取实际插入数量
  69. inserted_count = result.rowcount
  70. logger.info(f"成功导入 {inserted_count} 个新关键词")
  71. return inserted_count
  72. except Exception as e:
  73. logger.error(f"批量导入失败: {str(e)}")
  74. if 'session' in locals():
  75. session.rollback()
  76. raise