| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- from datetime import datetime
- from typing import Optional, List
- from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select, delete, func
- from sqlalchemy.orm import relationship
- from sqlalchemy import UniqueConstraint
- from sqlalchemy.sql import text
- from pathlib import Path
- from config.settings import DB_URL
class KeywordTask(SQLModel, table=True):
    """A search keyword tracked as a unit of work.

    Parent row for the ``SearchPageResult`` and ``SearchResultItem`` rows
    that point back at it through their ``keyword_id`` foreign key.
    """

    # Table-level unique constraint on ``keyword``.
    # NOTE(review): the ``keyword`` column below also sets ``unique=True``;
    # the two declarations look redundant -- confirm before removing either,
    # since doing so changes the generated schema.
    __table_args__ = (UniqueConstraint("keyword", name="uq_keyword"),)

    # Primary key; defaults to None on a freshly constructed instance.
    id: Optional[int] = Field(default=None, primary_key=True)
    # The keyword text; indexed and unique.
    keyword: str = Field(index=True, unique=True)
    # Written by SearchResultManager.mark_task_completed as the sum of
    # ``results_count`` over this keyword's pages; None until then.
    total_results: Optional[int] = None
    # Set to True by SearchResultManager.mark_task_completed.
    is_completed: bool = Field(default=False)
    # ``datetime.now`` without a tz argument yields a naive local timestamp.
    created_at: datetime = Field(default_factory=datetime.now)

    # Reverse sides of SearchPageResult.keyword_task / SearchResultItem.keyword_task.
    pages: List["SearchPageResult"] = Relationship(back_populates="keyword_task")
    items: List["SearchResultItem"] = Relationship(back_populates="keyword_task")
class SearchPageResult(SQLModel, table=True):
    """One fetched result page for a keyword.

    A keyword task can have at most one row per page number.
    """

    # At most one row per (keyword_id, page_number) pair.
    __table_args__ = (UniqueConstraint("keyword_id", "page_number", name="uq_keyword_page"),)

    # Primary key; defaults to None on a freshly constructed instance.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Owning KeywordTask row.
    keyword_id: int = Field(foreign_key="keywordtask.id")
    # Keyword text stored alongside ``keyword_id``;
    # SearchResultManager.save_page_results looks up existing pages by this column.
    keyword: str = Field(index=True)
    page_number: int
    # Summed per keyword by SearchResultManager.mark_task_completed.
    results_count: int
    has_next_page: bool
    # Stored as a string by SearchResultManager.save_page_results
    # (``str(path)``), or None when no path was given.
    html_path: Optional[str] = None
    # ``datetime.now`` without a tz argument yields a naive local timestamp.
    created_at: datetime = Field(default_factory=datetime.now)

    # Counterpart of KeywordTask.pages.
    keyword_task: Optional[KeywordTask] = Relationship(back_populates="pages")
    # Counterpart of SearchResultItem.search_page.
    items: List["SearchResultItem"] = Relationship(back_populates="search_page")
class SearchResultItem(SQLModel, table=True):
    """A single result entry (URL plus optional title/content) found on a result page."""

    # The same URL may be stored only once per page.
    __table_args__ = (UniqueConstraint("url", "page_id", name="uq_url_page"),)

    # Primary key; defaults to None on a freshly constructed instance.
    id: Optional[int] = Field(default=None, primary_key=True)
    url: str
    title: Optional[str] = None
    content: Optional[str] = None
    # Stored as a string by SearchResultManager.save_result_items
    # (``str(path)``), or None when no path was given.
    html_path: Optional[str] = None
    # Owning KeywordTask row.
    keyword_id: int = Field(foreign_key="keywordtask.id")
    # Keyword text stored alongside ``keyword_id``.
    keyword: str = Field(index=True)
    # SearchPageResult row this item was found on.
    page_id: int = Field(foreign_key="searchpageresult.id")
    # ``datetime.now`` without a tz argument yields a naive local timestamp.
    created_at: datetime = Field(default_factory=datetime.now)

    # Counterpart of KeywordTask.items.
    keyword_task: Optional[KeywordTask] = Relationship(back_populates="items")
    # Counterpart of SearchPageResult.items.
    search_page: Optional[SearchPageResult] = Relationship(back_populates="items")
class VerificationItem(SQLModel, table=True):
    """Verification entry for a ``SearchResultItem``; at most one per result item."""

    # A result item can be queued for verification only once.
    __table_args__ = (UniqueConstraint("result_item_id", name="uq_verification_item"),)

    # Primary key; defaults to None on a freshly constructed instance.
    id: Optional[int] = Field(default=None, primary_key=True)
    # The SearchResultItem row being verified.
    result_item_id: int = Field(foreign_key="searchresultitem.id")
    # Plain SQLAlchemy relationship with ``lazy="joined"`` (joined eager
    # loading), passed through SQLModel's ``sa_relationship`` hook.
    search_result_item: Optional[SearchResultItem] = Relationship(
        sa_relationship=relationship("SearchResultItem", lazy="joined")
    )
    verified: bool = Field(default=False)
    # ``datetime.now`` without a tz argument yields a naive local timestamp.
    created_at: datetime = Field(default_factory=datetime.now)
class SearchResultManager:
    """Database access layer for keyword tasks, result pages, result items
    and verification entries.

    Each public method opens its own short-lived ``Session`` on ``self.engine``.
    """

    def __init__(self, db_url: str = DB_URL):
        """Create the engine for ``db_url`` and create any missing tables."""
        self.engine = create_engine(db_url)
        SQLModel.metadata.create_all(self.engine)

    @staticmethod
    def _find_task(session: Session, keyword: str) -> KeywordTask | None:
        """Return the ``KeywordTask`` for ``keyword`` using the caller's session."""
        return session.exec(
            select(KeywordTask).where(KeywordTask.keyword == keyword)
        ).first()

    def get_keyword_task(self, keyword: str) -> KeywordTask | None:
        """Return the task stored for ``keyword``, or ``None`` if there is none."""
        with Session(self.engine) as session:
            return self._find_task(session, keyword)

    def delete_keyword_task(self, keyword: str) -> None:
        """Delete a keyword task together with all rows that depend on it.

        Removes, in dependency order, the verification entries of the
        keyword's result items, the result items, the result pages and
        finally the task itself. Does nothing when the keyword is unknown.
        """
        with Session(self.engine) as session:
            keyword_task = self._find_task(session, keyword)
            if keyword_task is None:
                return

            # Verification rows hold a foreign key to the result items, so
            # they must go first; otherwise they are orphaned (or the delete
            # below fails where foreign keys are enforced).
            item_ids = select(SearchResultItem.id).where(
                SearchResultItem.keyword_id == keyword_task.id
            )
            session.execute(
                delete(VerificationItem)
                .where(VerificationItem.result_item_id.in_(item_ids))
                # No VerificationItem objects are loaded in this session, and
                # the subquery criterion cannot be evaluated in Python.
                .execution_options(synchronize_session=False)
            )
            # Result items of this keyword.
            session.execute(
                delete(SearchResultItem)
                .where(SearchResultItem.keyword_id == keyword_task.id)
            )
            # Result pages of this keyword.
            session.execute(
                delete(SearchPageResult)
                .where(SearchPageResult.keyword_id == keyword_task.id)
            )
            # Finally the task itself.
            session.delete(keyword_task)
            session.commit()

    def create_keyword_task(self, keyword: str) -> KeywordTask:
        """Insert a new task for ``keyword`` and return the refreshed row.

        No existence check is made here; inserting a keyword that is already
        stored violates the unique constraint on ``keyword``.
        """
        with Session(self.engine) as session:
            # NOTE(review): the original comment says stale data for the
            # keyword is removed beforehand in ``process_keyword`` (not part
            # of this class) -- confirm against the caller.
            task = KeywordTask(keyword=keyword)
            session.add(task)
            session.commit()
            session.refresh(task)
            return task

    def save_page_results(
        self,
        keyword: str,
        page_number: int,
        results_count: int,
        has_next_page: bool,
        html_path: Optional[Path] = None
    ) -> SearchPageResult | None:
        """Store one result page for ``keyword`` unless it is already stored.

        Returns the existing row when a page with the same keyword and page
        number is found, otherwise the newly inserted (refreshed) row.

        Raises:
            ValueError: if no task exists for ``keyword``.
        """
        with Session(self.engine) as session:
            existing = session.exec(
                select(SearchPageResult)
                .where(SearchPageResult.keyword == keyword)
                .where(SearchPageResult.page_number == page_number)
            ).first()
            if existing:
                return existing

            # Look the task up in the session that is already open instead of
            # opening a second one through get_keyword_task().
            keyword_task = self._find_task(session, keyword)
            if not keyword_task:
                raise ValueError("Keyword task must exist before saving page results")

            page_result = SearchPageResult(
                keyword_id=keyword_task.id,
                keyword=keyword,
                page_number=page_number,
                results_count=results_count,
                has_next_page=has_next_page,
                html_path=str(html_path) if html_path else None
            )
            session.add(page_result)
            session.commit()
            session.refresh(page_result)
            return page_result

    def save_result_items(
        self,
        keyword: str,
        page_id: int,
        items: List[SearchResultItem],
        html_path: Optional[Path] = None
    ) -> List[SearchResultItem]:
        """Store the result items of one page, skipping URLs already stored.

        Only ``url``, ``title`` and ``content`` are read from each entry of
        ``items``; a fresh row is built for every URL that is not yet stored
        for ``page_id``. A URL repeated inside ``items`` is stored once.

        Returns:
            The newly inserted rows (an empty list when nothing was new).

        Raises:
            ValueError: if no task exists for ``keyword``.
        """
        stored_path = str(html_path) if html_path else None

        # expire_on_commit=False keeps the returned rows readable after the
        # session is closed; with the default they would be expired by
        # commit() and unusable once detached.
        with Session(self.engine, expire_on_commit=False) as session:
            keyword_task = self._find_task(session, keyword)
            if not keyword_task:
                raise ValueError(f"Keyword task not found for keyword: {keyword}")

            # One query for the URLs already on this page, instead of one
            # existence query per incoming item.
            seen_urls = set(
                session.exec(
                    select(SearchResultItem.url)
                    .where(SearchResultItem.page_id == page_id)
                ).all()
            )

            new_items = []
            for item in items:
                if item.url in seen_urls:
                    continue
                # Also guards against duplicates within the incoming batch,
                # which would violate the (url, page_id) unique constraint.
                seen_urls.add(item.url)
                new_items.append(
                    SearchResultItem(
                        url=item.url,
                        title=item.title,
                        content=item.content,
                        html_path=stored_path,
                        keyword_id=keyword_task.id,
                        keyword=keyword,
                        page_id=page_id
                    )
                )

            session.add_all(new_items)
            session.commit()
            return new_items

    def mark_task_completed(self, keyword: str) -> KeywordTask:
        """Mark the task for ``keyword`` as completed and record its total.

        ``total_results`` is set to the sum of ``results_count`` over the
        keyword's stored pages (0 when there are none).

        Raises:
            ValueError: if no task exists for ``keyword``.
        """
        with Session(self.engine) as session:
            task = self._find_task(session, keyword)
            if not task:
                raise ValueError(f"Keyword task {keyword} not found")

            # Plain SUM aggregate over this keyword's pages; SUM over zero
            # rows yields NULL, hence the ``or 0``.
            # NOTE(review): ``compile_kwargs`` is normally an argument to
            # ``compile()``; as an execution option it appears to have no
            # effect -- confirm and drop.
            total_results = session.scalar(
                select(func.sum(SearchPageResult.results_count))
                .where(SearchPageResult.keyword_id == task.id)
                .execution_options(compile_kwargs={"literal_binds": True})
            ) or 0

            task.is_completed = True
            task.total_results = total_results
            session.commit()
            session.refresh(task)
            return task

    def is_task_completed(self, keyword: str) -> bool:
        """Return the task's ``is_completed`` flag; ``False`` if no task exists."""
        task = self.get_keyword_task(keyword)
        return task.is_completed if task else False

    def get_all_search_result_items(self) -> List[SearchResultItem]:
        """Return every ``SearchResultItem`` stored in the database."""
        with Session(self.engine) as session:
            return session.exec(select(SearchResultItem)).all()

    def add_to_verification(self, result_item_id: int) -> VerificationItem:
        """Queue a result item for verification, without creating duplicates.

        Returns the existing ``VerificationItem`` for ``result_item_id`` when
        there is one, otherwise the newly inserted (refreshed) row.
        """
        with Session(self.engine) as session:
            exists = session.exec(
                select(VerificationItem)
                .where(VerificationItem.result_item_id == result_item_id)
            ).first()
            if exists:
                return exists

            verification_item = VerificationItem(result_item_id=result_item_id)
            session.add(verification_item)
            session.commit()
            session.refresh(verification_item)
            return verification_item
|