search_result_db.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. from datetime import datetime
  2. from typing import Optional, List
  3. from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select, delete, func,distinct
  4. from sqlalchemy.orm import relationship
  5. from sqlalchemy import UniqueConstraint
  6. from sqlalchemy.sql import text
  7. from pathlib import Path
  8. from config.settings import DB_URL
class KeywordTask(SQLModel, table=True):
    """One search keyword together with the bookkeeping recorded for it."""

    # Table-level unique constraint on `keyword`; the column itself is also
    # declared `unique=True` below.
    __table_args__ = (UniqueConstraint("keyword", name="uq_keyword"),)
    # Primary key; None until the row has been inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    # The keyword text; indexed and unique.
    keyword: str = Field(index=True, unique=True)
    # Written by SearchResultManager.mark_task_completed as the sum of
    # SearchPageResult.results_count over this keyword's pages.
    total_results: Optional[int] = None
    # Set to True by SearchResultManager.mark_task_completed.
    is_completed: Optional[bool] = Field(default=False)
    # Defaults to 0; presumably a retry/failure counter maintained by callers
    # outside this file -- TODO confirm.
    fail_count: Optional[int] = 0
    # Creation timestamp taken from datetime.now() (naive local time).
    created_at: Optional[datetime] = Field(default_factory=datetime.now)
    # Reverse sides of SearchPageResult.keyword_task / SearchResultItem.keyword_task.
    pages: List["SearchPageResult"] = Relationship(back_populates="keyword_task")
    items: List["SearchResultItem"] = Relationship(back_populates="keyword_task")
class SearchPageResult(SQLModel, table=True):
    """One fetched result page of a keyword search."""

    # A given page number may be stored only once per keyword task.
    __table_args__ = (UniqueConstraint("keyword_id", "page_number", name="uq_keyword_page"),)
    # Primary key; None until the row has been inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Foreign key to the owning KeywordTask row.
    keyword_id: int = Field(foreign_key="keywordtask.id")
    # Keyword text stored alongside keyword_id (indexed);
    # SearchResultManager.save_page_results looks pages up by this column.
    keyword: str = Field(index=True)
    page_number: int
    # Summed per keyword by SearchResultManager.mark_task_completed.
    results_count: int
    has_next_page: bool
    # String form of the path passed to save_page_results, or None.
    html_path: Optional[str] = None
    # Creation timestamp taken from datetime.now() (naive local time).
    created_at: datetime = Field(default_factory=datetime.now)
    # Owning task (reverse side: KeywordTask.pages).
    keyword_task: Optional[KeywordTask] = Relationship(back_populates="pages")
    # Result items found on this page (reverse side: SearchResultItem.search_page).
    items: List["SearchResultItem"] = Relationship(back_populates="search_page")
class SearchResultItem(SQLModel, table=True):
    """One result URL found on a search result page."""

    # The same URL may be stored only once per page.
    __table_args__ = (UniqueConstraint("url", "page_id", name="uq_url_page"),)
    # Primary key; None until the row has been inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    url: str
    title: Optional[str] = None
    content: Optional[str] = None
    content_type: Optional[str] = None
    # A NULL save_path is what the manager's "unprocessed" / "uncomplete"
    # queries filter on; a non-NULL value marks the item as complete.
    save_path: Optional[str] = None
    markdown_path: Optional[str] = None
    # Foreign key to the owning KeywordTask row.
    keyword_id: int = Field(foreign_key="keywordtask.id")
    # Keyword text stored alongside keyword_id (indexed).
    keyword: str = Field(index=True)
    # Foreign key to the SearchPageResult this item was found on.
    page_id: int = Field(foreign_key="searchpageresult.id")
    # Creation timestamp taken from datetime.now() (naive local time).
    created_at: datetime = Field(default_factory=datetime.now)
    # Owning task (reverse side: KeywordTask.items).
    keyword_task: Optional[KeywordTask] = Relationship(back_populates="items")
    # Owning page (reverse side: SearchPageResult.items).
    search_page: Optional[SearchPageResult] = Relationship(back_populates="items")
class VerificationItem(SQLModel, table=True):
    """Verification record attached to a single SearchResultItem."""

    # At most one verification row per result item.
    __table_args__ = (UniqueConstraint("result_item_id", name="uq_verification_item"),)
    # Primary key; None until the row has been inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Foreign key to the SearchResultItem being verified.
    result_item_id: int = Field(foreign_key="searchresultitem.id")
    # Plain SQLAlchemy relationship configured with lazy="joined", i.e. the
    # related SearchResultItem is loaded with a JOIN whenever this row is loaded.
    search_result_item: Optional[SearchResultItem] = Relationship(
        sa_relationship=relationship("SearchResultItem", lazy="joined")
    )
    # New rows start unverified.
    verified: bool = Field(default=False)
    # Creation timestamp taken from datetime.now() (naive local time).
    created_at: datetime = Field(default_factory=datetime.now)
  55. class SearchResultManager:
  56. def __init__(self, db_url: str = DB_URL):
  57. self.engine = create_engine(db_url)
  58. SQLModel.metadata.create_all(self.engine)
  59. def get_keyword_task(self, keyword: str) -> KeywordTask | None:
  60. with Session(self.engine) as session:
  61. return session.exec(
  62. select(KeywordTask)
  63. .where(KeywordTask.keyword == keyword)
  64. ).first()
  65. def get_uncompleted_keywords(self) -> list[str]:
  66. """从数据库获取已完成搜索但未完成爬取的关键词"""
  67. with Session(self.engine) as session:
  68. # 使用JOIN优化查询,避免子查询
  69. query = (
  70. select(distinct(KeywordTask.keyword))
  71. .where(KeywordTask.is_completed != True)
  72. )
  73. keywords = session.exec(query).all()
  74. return keywords
  75. def delete_keyword_task(self, keyword: str):
  76. """删除关键词及其所有关联数据"""
  77. ret = None
  78. with Session(self.engine) as session:
  79. # 先获取关键词任务
  80. keyword_task = session.exec(
  81. select(KeywordTask)
  82. .where(KeywordTask.keyword == keyword)
  83. ).first()
  84. if keyword_task:
  85. ret = keyword_task
  86. # 删除关联的SearchResultItem
  87. session.exec(
  88. delete(SearchResultItem)
  89. .where(SearchResultItem.keyword_id == keyword_task.id)
  90. )
  91. # 删除关联的SearchPageResult
  92. session.exec(
  93. delete(SearchPageResult)
  94. .where(SearchPageResult.keyword_id == keyword_task.id)
  95. )
  96. # 删除KeywordTask
  97. session.delete(keyword_task)
  98. session.commit()
  99. return ret
  100. def create_keyword_task(self, keyword: str) -> KeywordTask:
  101. with Session(self.engine) as session:
  102. # 先删除可能存在的旧数据(在process_keyword中已处理)
  103. task = KeywordTask(keyword=keyword)
  104. session.add(task)
  105. session.commit()
  106. session.refresh(task)
  107. return task
  108. def save_page_results(
  109. self,
  110. keyword: str,
  111. page_number: int,
  112. results_count: int,
  113. has_next_page: bool,
  114. html_path: Optional[Path] = None
  115. ) -> SearchPageResult | None:
  116. with Session(self.engine) as session:
  117. existing = session.exec(
  118. select(SearchPageResult)
  119. .where(SearchPageResult.keyword == keyword)
  120. .where(SearchPageResult.page_number == page_number)
  121. ).first()
  122. if existing:
  123. return existing
  124. keyword_task = self.get_keyword_task(keyword)
  125. if not keyword_task:
  126. raise ValueError("Keyword task must exist before saving page results")
  127. page_result = SearchPageResult(
  128. keyword_id=keyword_task.id,
  129. keyword=keyword,
  130. page_number=page_number,
  131. results_count=results_count,
  132. has_next_page=has_next_page,
  133. html_path=str(html_path) if html_path else None
  134. )
  135. session.add(page_result)
  136. session.commit()
  137. session.refresh(page_result)
  138. return page_result
  139. def save_result_items(
  140. self,
  141. keyword: str,
  142. page_id: int,
  143. items: List[SearchResultItem],
  144. html_path: Optional[Path] = None
  145. ) -> int:
  146. with Session(self.engine) as session:
  147. keyword_task = session.exec(
  148. select(KeywordTask)
  149. .where(KeywordTask.keyword == keyword)
  150. ).first()
  151. if not keyword_task:
  152. raise ValueError(f"Keyword task not found for keyword: {keyword}")
  153. new_items = []
  154. for item in items:
  155. exists = session.exec(
  156. select(SearchResultItem)
  157. .where(SearchResultItem.url == item.url)
  158. .where(SearchResultItem.page_id == page_id)
  159. ).first()
  160. if not exists:
  161. new_item = SearchResultItem(
  162. url=item.url,
  163. title=item.title,
  164. content=item.content,
  165. save_path=str(html_path) if html_path else None,
  166. keyword_id=keyword_task.id,
  167. keyword=keyword,
  168. page_id=page_id
  169. )
  170. new_items.append(new_item)
  171. session.add_all(new_items)
  172. session.commit()
  173. return new_items
  174. def mark_task_completed(self, keyword: str):
  175. with Session(self.engine) as session:
  176. task = self.get_keyword_task(keyword)
  177. if not task:
  178. raise ValueError(f"Keyword task {keyword} not found")
  179. # 使用窗口函数确保统计准确性
  180. total_results = session.scalar(
  181. select(func.sum(SearchPageResult.results_count))
  182. .where(SearchPageResult.keyword_id == task.id)
  183. .execution_options(compile_kwargs={"literal_binds": True})
  184. ) or 0
  185. task.is_completed = True
  186. task.total_results = total_results
  187. session.add(task)
  188. session.commit()
  189. session.refresh(task)
  190. return task
  191. def is_task_completed(self, keyword: str) -> bool:
  192. task = self.get_keyword_task(keyword)
  193. return task.is_completed if task else False
  194. def get_all_search_result_items(self) -> List[SearchResultItem]:
  195. """
  196. 获取数据库中所有的 SearchResultItem。
  197. """
  198. with Session(self.engine) as session:
  199. return session.exec(select(SearchResultItem)).all()
  200. def get_pages_with_unprocessed_urls(self) -> list[int]:
  201. """获取包含未处理URL的页面ID"""
  202. with Session(self.engine) as session:
  203. query = (
  204. select(distinct(SearchPageResult.id))
  205. .join(SearchResultItem, SearchPageResult.id == SearchResultItem.page_id)
  206. .where(SearchResultItem.save_path.is_(None))
  207. )
  208. page_ids = session.exec(query).all()
  209. return page_ids
  210. def get_task_statistics(self) -> dict:
  211. """获取任务统计信息"""
  212. with Session(self.engine) as session:
  213. total = session.scalar(select(func.count(KeywordTask.id)))
  214. completed = session.scalar(
  215. select(func.count(KeywordTask.id))
  216. .where(KeywordTask.is_completed == True)
  217. )
  218. return {
  219. "total_tasks": total or 0,
  220. "completed_tasks": completed or 0,
  221. "pending_tasks": (total or 0) - (completed or 0)
  222. }
  223. def add_to_verification(self, result_item_id: int):
  224. """
  225. 将 SearchResultItem 添加到 VerificationItem 表中,避免重复添加。
  226. """
  227. with Session(self.engine) as session:
  228. exists = session.exec(
  229. select(VerificationItem)
  230. .where(VerificationItem.result_item_id == result_item_id)
  231. ).first()
  232. if not exists:
  233. verification_item = VerificationItem(result_item_id=result_item_id)
  234. session.add(verification_item)
  235. session.commit()
  236. session.refresh(verification_item)
  237. return verification_item
  238. return exists
  239. def get_complete_search_result_items(self) -> list[SearchResultItem]:
  240. """Get all successful search result items"""
  241. with Session(self.engine) as session:
  242. return session.exec(
  243. select(SearchResultItem)
  244. .where(SearchResultItem.save_path.is_not(None))
  245. ).all()
  246. def get_uncomplete_search_result_items(self) -> list[SearchResultItem]:
  247. """Get all unsuccessful search result items"""
  248. with Session(self.engine) as session:
  249. return session.exec(
  250. select(SearchResultItem)
  251. .where(SearchResultItem.save_path.is_(None))
  252. ).all()
  253. def add_or_update_search_result_item(self, search_result_item: SearchResultItem):
  254. with Session(self.engine) as session:
  255. session.add(search_result_item)
  256. session.commit()
  257. session.refresh(search_result_item)
  258. return search_result_item
# Module-level shared manager instance. Constructing it here means that merely
# importing this module creates the engine for DB_URL and runs
# SQLModel.metadata.create_all against it.
db_manager = SearchResultManager()