from datetime import datetime
from pathlib import Path
from typing import Optional, List

from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import relationship
from sqlalchemy.sql import text
from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select, delete, func, distinct

from config.settings import DB_URL


class KeywordTask(SQLModel, table=True):
    """One search keyword together with its completion state and counters."""

    __table_args__ = (UniqueConstraint("keyword", name="uq_keyword"),)

    id: Optional[int] = Field(default=None, primary_key=True)
    keyword: str = Field(index=True, unique=True)
    # Filled in by SearchResultManager.mark_task_completed (sum of page counts).
    total_results: Optional[int] = None
    is_completed: Optional[bool] = Field(default=False)
    fail_count: Optional[int] = 0
    created_at: Optional[datetime] = Field(default_factory=datetime.now)

    pages: List["SearchPageResult"] = Relationship(back_populates="keyword_task")
    items: List["SearchResultItem"] = Relationship(back_populates="keyword_task")


class SearchPageResult(SQLModel, table=True):
    """One result page of a keyword search; unique per (keyword_id, page_number)."""

    __table_args__ = (UniqueConstraint("keyword_id", "page_number", name="uq_keyword_page"),)

    id: Optional[int] = Field(default=None, primary_key=True)
    keyword_id: int = Field(foreign_key="keywordtask.id")
    # Keyword text stored alongside keyword_id so pages can be filtered by it directly.
    keyword: str = Field(index=True)
    page_number: int
    results_count: int
    has_next_page: bool
    html_path: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.now)

    keyword_task: Optional[KeywordTask] = Relationship(back_populates="pages")
    items: List["SearchResultItem"] = Relationship(back_populates="search_page")


class SearchResultItem(SQLModel, table=True):
    """A single search hit on a result page; unique per (url, page_id)."""

    __table_args__ = (UniqueConstraint("url", "page_id", name="uq_url_page"),)

    id: Optional[int] = Field(default=None, primary_key=True)
    url: str
    title: Optional[str] = None
    content: Optional[str] = None
    html_path: Optional[str] = None
    keyword_id: int = Field(foreign_key="keywordtask.id")
    keyword: str = Field(index=True)
    page_id: int = Field(foreign_key="searchpageresult.id")
    created_at: datetime = Field(default_factory=datetime.now)

    keyword_task: Optional[KeywordTask] = Relationship(back_populates="items")
    search_page: Optional[SearchPageResult] = Relationship(back_populates="items")


class VerificationItem(SQLModel, table=True):
    """Verification record for a SearchResultItem; at most one per result item."""

    __table_args__ = (UniqueConstraint("result_item_id", name="uq_verification_item"),)

    id: Optional[int] = Field(default=None, primary_key=True)
    result_item_id: int = Field(foreign_key="searchresultitem.id")
    # Joined loading: the referenced result item is fetched in the same query.
    search_result_item: Optional[SearchResultItem] = Relationship(
        sa_relationship=relationship("SearchResultItem", lazy="joined")
    )
    verified: bool = Field(default=False)
    created_at: datetime = Field(default_factory=datetime.now)


class SearchResultManager:
    """Persistence helper for keyword tasks, result pages, items and verifications.

    Every method opens its own short-lived ``Session``. Model instances are
    returned after that session has been closed.
    """

    def __init__(self, db_url: str = DB_URL):
        """Create the engine for ``db_url`` and create any missing tables."""
        self.engine = create_engine(db_url)
        SQLModel.metadata.create_all(self.engine)

    def get_keyword_task(self, keyword: str) -> KeywordTask | None:
        """Return the task stored for ``keyword``, or ``None`` if there is none."""
        with Session(self.engine) as session:
            return session.exec(
                select(KeywordTask).where(KeywordTask.keyword == keyword)
            ).first()

    def get_uncompleted_keywords(self) -> list[str]:
        """Return the keywords of all tasks that are not marked completed."""
        with Session(self.engine) as session:
            # ``is_completed`` is nullable. A bare ``!= True`` evaluates to NULL
            # for NULL rows and silently drops them, although
            # get_task_statistics counts those rows as pending. Match both.
            not_completed = (
                (KeywordTask.is_completed == False)  # noqa: E712
                | (KeywordTask.is_completed.is_(None))
            )
            # ``keyword`` is unique, so no DISTINCT is required.
            query = select(KeywordTask.keyword).where(not_completed)
            return list(session.exec(query).all())

    def delete_keyword_task(self, keyword: str):
        """Delete the task for ``keyword`` and every row that depends on it.

        Does nothing when no task exists for ``keyword``.
        """
        with Session(self.engine) as session:
            keyword_task = session.exec(
                select(KeywordTask).where(KeywordTask.keyword == keyword)
            ).first()
            if not keyword_task:
                return

            # Children are removed before their parents. Verification rows
            # reference result items, so they must go first; otherwise they
            # are orphaned or block the delete on FK-enforcing backends.
            task_item_ids = select(SearchResultItem.id).where(
                SearchResultItem.keyword_id == keyword_task.id
            )
            session.exec(
                delete(VerificationItem).where(
                    VerificationItem.result_item_id.in_(task_item_ids)
                )
            )
            session.exec(
                delete(SearchResultItem).where(
                    SearchResultItem.keyword_id == keyword_task.id
                )
            )
            session.exec(
                delete(SearchPageResult).where(
                    SearchPageResult.keyword_id == keyword_task.id
                )
            )
            session.delete(keyword_task)
            session.commit()

    def create_keyword_task(self, keyword: str) -> KeywordTask:
        """Insert a new task for ``keyword`` and return it with its id loaded.

        No existence check is done here; the original note says stale data is
        expected to have been removed by the caller (``process_keyword``).
        """
        with Session(self.engine) as session:
            task = KeywordTask(keyword=keyword)
            session.add(task)
            session.commit()
            session.refresh(task)
            return task

    def save_page_results(
        self,
        keyword: str,
        page_number: int,
        results_count: int,
        has_next_page: bool,
        html_path: Optional[Path] = None
    ) -> SearchPageResult | None:
        """Store one result page for ``keyword`` unless it is already stored.

        Returns:
            The existing row for (keyword, page_number) if there is one,
            otherwise the newly inserted row.

        Raises:
            ValueError: if no KeywordTask exists for ``keyword``.
        """
        with Session(self.engine) as session:
            existing = session.exec(
                select(SearchPageResult)
                .where(SearchPageResult.keyword == keyword)
                .where(SearchPageResult.page_number == page_number)
            ).first()
            if existing:
                return existing

            # Look the task up in this session instead of opening a second one.
            keyword_task = session.exec(
                select(KeywordTask).where(KeywordTask.keyword == keyword)
            ).first()
            if not keyword_task:
                raise ValueError("Keyword task must exist before saving page results")

            page_result = SearchPageResult(
                keyword_id=keyword_task.id,
                keyword=keyword,
                page_number=page_number,
                results_count=results_count,
                has_next_page=has_next_page,
                html_path=str(html_path) if html_path else None
            )
            session.add(page_result)
            session.commit()
            session.refresh(page_result)
            return page_result

    def save_result_items(
        self,
        keyword: str,
        page_id: int,
        items: List[SearchResultItem],
        html_path: Optional[Path] = None
    ) -> List[SearchResultItem]:
        """Insert the items of ``items`` whose URL is not yet stored for ``page_id``.

        Each input item is copied into a fresh row (url, title, content) bound
        to the keyword's task and to ``page_id``. URLs repeated within
        ``items`` are stored once, so the (url, page_id) unique constraint
        cannot be violated by a single batch.

        Returns:
            The list of newly created rows (empty if nothing was new).
            NOTE(review): the rows are returned after commit and session
            close, so callers should presumably rely on the list length
            rather than on attribute access; confirm against callers.

        Raises:
            ValueError: if no KeywordTask exists for ``keyword``.
        """
        with Session(self.engine) as session:
            keyword_task = session.exec(
                select(KeywordTask).where(KeywordTask.keyword == keyword)
            ).first()
            if not keyword_task:
                raise ValueError(f"Keyword task not found for keyword: {keyword}")

            # One query for all URLs already on this page instead of one
            # SELECT per incoming item.
            seen_urls = set(
                session.exec(
                    select(SearchResultItem.url)
                    .where(SearchResultItem.page_id == page_id)
                ).all()
            )
            stored_html_path = str(html_path) if html_path else None

            new_items = []
            for item in items:
                if item.url in seen_urls:
                    continue
                # Remember the URL so a later duplicate in this batch is skipped.
                seen_urls.add(item.url)
                new_items.append(
                    SearchResultItem(
                        url=item.url,
                        title=item.title,
                        content=item.content,
                        html_path=stored_html_path,
                        keyword_id=keyword_task.id,
                        keyword=keyword,
                        page_id=page_id
                    )
                )

            session.add_all(new_items)
            session.commit()
            return new_items

    def mark_task_completed(self, keyword: str) -> KeywordTask:
        """Mark the task for ``keyword`` completed and store its total result count.

        ``total_results`` is set to the sum of ``results_count`` over the
        task's stored pages (0 when it has no pages).

        Raises:
            ValueError: if no KeywordTask exists for ``keyword``.
        """
        with Session(self.engine) as session:
            task = session.exec(
                select(KeywordTask).where(KeywordTask.keyword == keyword)
            ).first()
            if not task:
                raise ValueError(f"Keyword task {keyword} not found")

            # Plain aggregate SUM; yields NULL (-> 0) when the task has no pages.
            total_results = session.scalar(
                select(func.sum(SearchPageResult.results_count))
                .where(SearchPageResult.keyword_id == task.id)
            ) or 0

            task.is_completed = True
            task.total_results = total_results
            session.add(task)
            session.commit()
            session.refresh(task)
            return task

    def is_task_completed(self, keyword: str) -> bool:
        """Return True only if a task exists for ``keyword`` and is marked completed."""
        task = self.get_keyword_task(keyword)
        # bool() maps a NULL ``is_completed`` to False instead of leaking None.
        return bool(task.is_completed) if task else False

    def get_all_search_result_items(self) -> List[SearchResultItem]:
        """Return every SearchResultItem stored in the database."""
        with Session(self.engine) as session:
            return session.exec(select(SearchResultItem)).all()

    def get_task_statistics(self) -> dict:
        """Return task counts under the keys total_tasks, completed_tasks, pending_tasks.

        ``pending_tasks`` is total minus completed.
        """
        with Session(self.engine) as session:
            total = session.scalar(select(func.count(KeywordTask.id)))
            completed = session.scalar(
                select(func.count(KeywordTask.id))
                .where(KeywordTask.is_completed == True)  # noqa: E712
            )
            return {
                "total_tasks": total or 0,
                "completed_tasks": completed or 0,
                "pending_tasks": (total or 0) - (completed or 0)
            }

    def add_to_verification(self, result_item_id: int):
        """Ensure a VerificationItem exists for ``result_item_id`` and return it.

        Returns the existing row when one is already present, so no duplicate
        is added; otherwise inserts a new row and returns it.
        """
        with Session(self.engine) as session:
            exists = session.exec(
                select(VerificationItem)
                .where(VerificationItem.result_item_id == result_item_id)
            ).first()
            if not exists:
                verification_item = VerificationItem(result_item_id=result_item_id)
                session.add(verification_item)
                session.commit()
                session.refresh(verification_item)
                return verification_item
            return exists