search_result_db.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from datetime import datetime
  2. from typing import Optional, List
  3. from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select, delete, func,distinct
  4. from sqlalchemy.orm import relationship
  5. from sqlalchemy import UniqueConstraint
  6. from sqlalchemy.sql import text
  7. from pathlib import Path
  8. from config.settings import DB_URL
  9. class KeywordTask(SQLModel, table=True):
  10. __table_args__ = (UniqueConstraint("keyword", name="uq_keyword"),)
  11. id: Optional[int] = Field(default=None, primary_key=True)
  12. keyword: str = Field(index=True, unique=True)
  13. total_results: Optional[int] = None
  14. is_completed: Optional[bool] = Field(default=False)
  15. fail_count: Optional[int] = 0
  16. created_at: Optional[datetime] = Field(default_factory=datetime.now)
  17. pages: List["SearchPageResult"] = Relationship(back_populates="keyword_task")
  18. items: List["SearchResultItem"] = Relationship(back_populates="keyword_task")
  19. class SearchPageResult(SQLModel, table=True):
  20. __table_args__ = (UniqueConstraint("keyword_id", "page_number", name="uq_keyword_page"),)
  21. id: Optional[int] = Field(default=None, primary_key=True)
  22. keyword_id: int = Field(foreign_key="keywordtask.id")
  23. keyword: str = Field(index=True)
  24. page_number: int
  25. results_count: int
  26. has_next_page: bool
  27. html_path: Optional[str] = None
  28. created_at: datetime = Field(default_factory=datetime.now)
  29. keyword_task: Optional[KeywordTask] = Relationship(back_populates="pages")
  30. items: List["SearchResultItem"] = Relationship(back_populates="search_page")
  31. class SearchResultItem(SQLModel, table=True):
  32. __table_args__ = (UniqueConstraint("url", "page_id", name="uq_url_page"),)
  33. id: Optional[int] = Field(default=None, primary_key=True)
  34. url: str
  35. title: Optional[str] = None
  36. content: Optional[str] = None
  37. html_path: Optional[str] = None
  38. keyword_id: int = Field(foreign_key="keywordtask.id")
  39. keyword: str = Field(index=True)
  40. page_id: int = Field(foreign_key="searchpageresult.id")
  41. created_at: datetime = Field(default_factory=datetime.now)
  42. keyword_task: Optional[KeywordTask] = Relationship(back_populates="items")
  43. search_page: Optional[SearchPageResult] = Relationship(back_populates="items")
  44. class VerificationItem(SQLModel, table=True):
  45. __table_args__ = (UniqueConstraint("result_item_id", name="uq_verification_item"),)
  46. id: Optional[int] = Field(default=None, primary_key=True)
  47. result_item_id: int = Field(foreign_key="searchresultitem.id")
  48. search_result_item: Optional[SearchResultItem] = Relationship(
  49. sa_relationship=relationship("SearchResultItem", lazy="joined")
  50. )
  51. verified: bool = Field(default=False)
  52. created_at: datetime = Field(default_factory=datetime.now)
  53. class SearchResultManager:
  54. def __init__(self, db_url: str = DB_URL):
  55. self.engine = create_engine(db_url)
  56. SQLModel.metadata.create_all(self.engine)
  57. def get_keyword_task(self, keyword: str) -> KeywordTask | None:
  58. with Session(self.engine) as session:
  59. return session.exec(
  60. select(KeywordTask)
  61. .where(KeywordTask.keyword == keyword)
  62. ).first()
  63. def get_uncompleted_keywords(self) -> list[str]:
  64. """从数据库获取已完成搜索但未完成爬取的关键词"""
  65. with Session(self.engine) as session:
  66. # 使用JOIN优化查询,避免子查询
  67. query = (
  68. select(distinct(KeywordTask.keyword))
  69. .where(KeywordTask.is_completed != True)
  70. )
  71. keywords = session.exec(query).all()
  72. return keywords
  73. def delete_keyword_task(self, keyword: str):
  74. """删除关键词及其所有关联数据"""
  75. with Session(self.engine) as session:
  76. # 先获取关键词任务
  77. keyword_task = session.exec(
  78. select(KeywordTask)
  79. .where(KeywordTask.keyword == keyword)
  80. ).first()
  81. if keyword_task:
  82. # 删除关联的SearchResultItem
  83. session.exec(
  84. delete(SearchResultItem)
  85. .where(SearchResultItem.keyword_id == keyword_task.id)
  86. )
  87. # 删除关联的SearchPageResult
  88. session.exec(
  89. delete(SearchPageResult)
  90. .where(SearchPageResult.keyword_id == keyword_task.id)
  91. )
  92. # 删除KeywordTask
  93. session.delete(keyword_task)
  94. session.commit()
  95. def create_keyword_task(self, keyword: str) -> KeywordTask:
  96. with Session(self.engine) as session:
  97. # 先删除可能存在的旧数据(在process_keyword中已处理)
  98. task = KeywordTask(keyword=keyword)
  99. session.add(task)
  100. session.commit()
  101. session.refresh(task)
  102. return task
  103. def save_page_results(
  104. self,
  105. keyword: str,
  106. page_number: int,
  107. results_count: int,
  108. has_next_page: bool,
  109. html_path: Optional[Path] = None
  110. ) -> SearchPageResult | None:
  111. with Session(self.engine) as session:
  112. existing = session.exec(
  113. select(SearchPageResult)
  114. .where(SearchPageResult.keyword == keyword)
  115. .where(SearchPageResult.page_number == page_number)
  116. ).first()
  117. if existing:
  118. return existing
  119. keyword_task = self.get_keyword_task(keyword)
  120. if not keyword_task:
  121. raise ValueError("Keyword task must exist before saving page results")
  122. page_result = SearchPageResult(
  123. keyword_id=keyword_task.id,
  124. keyword=keyword,
  125. page_number=page_number,
  126. results_count=results_count,
  127. has_next_page=has_next_page,
  128. html_path=str(html_path) if html_path else None
  129. )
  130. session.add(page_result)
  131. session.commit()
  132. session.refresh(page_result)
  133. return page_result
  134. def save_result_items(
  135. self,
  136. keyword: str,
  137. page_id: int,
  138. items: List[SearchResultItem],
  139. html_path: Optional[Path] = None
  140. ) -> int:
  141. with Session(self.engine) as session:
  142. keyword_task = session.exec(
  143. select(KeywordTask)
  144. .where(KeywordTask.keyword == keyword)
  145. ).first()
  146. if not keyword_task:
  147. raise ValueError(f"Keyword task not found for keyword: {keyword}")
  148. new_items = []
  149. for item in items:
  150. exists = session.exec(
  151. select(SearchResultItem)
  152. .where(SearchResultItem.url == item.url)
  153. .where(SearchResultItem.page_id == page_id)
  154. ).first()
  155. if not exists:
  156. new_item = SearchResultItem(
  157. url=item.url,
  158. title=item.title,
  159. content=item.content,
  160. html_path=str(html_path) if html_path else None,
  161. keyword_id=keyword_task.id,
  162. keyword=keyword,
  163. page_id=page_id
  164. )
  165. new_items.append(new_item)
  166. session.add_all(new_items)
  167. session.commit()
  168. return new_items
  169. def mark_task_completed(self, keyword: str):
  170. with Session(self.engine) as session:
  171. task = self.get_keyword_task(keyword)
  172. if not task:
  173. raise ValueError(f"Keyword task {keyword} not found")
  174. # 使用窗口函数确保统计准确性
  175. total_results = session.scalar(
  176. select(func.sum(SearchPageResult.results_count))
  177. .where(SearchPageResult.keyword_id == task.id)
  178. .execution_options(compile_kwargs={"literal_binds": True})
  179. ) or 0
  180. task.is_completed = True
  181. task.total_results = total_results
  182. session.add(task)
  183. session.commit()
  184. session.refresh(task)
  185. return task
  186. def is_task_completed(self, keyword: str) -> bool:
  187. task = self.get_keyword_task(keyword)
  188. return task.is_completed if task else False
  189. def get_all_search_result_items(self) -> List[SearchResultItem]:
  190. """
  191. 获取数据库中所有的 SearchResultItem。
  192. """
  193. with Session(self.engine) as session:
  194. return session.exec(select(SearchResultItem)).all()
  195. def get_task_statistics(self) -> dict:
  196. """获取任务统计信息"""
  197. with Session(self.engine) as session:
  198. total = session.scalar(select(func.count(KeywordTask.id)))
  199. completed = session.scalar(
  200. select(func.count(KeywordTask.id))
  201. .where(KeywordTask.is_completed == True)
  202. )
  203. return {
  204. "total_tasks": total or 0,
  205. "completed_tasks": completed or 0,
  206. "pending_tasks": (total or 0) - (completed or 0)
  207. }
  208. def add_to_verification(self, result_item_id: int):
  209. """
  210. 将 SearchResultItem 添加到 VerificationItem 表中,避免重复添加。
  211. """
  212. with Session(self.engine) as session:
  213. exists = session.exec(
  214. select(VerificationItem)
  215. .where(VerificationItem.result_item_id == result_item_id)
  216. ).first()
  217. if not exists:
  218. verification_item = VerificationItem(result_item_id=result_item_id)
  219. session.add(verification_item)
  220. session.commit()
  221. session.refresh(verification_item)
  222. return verification_item
  223. return exists