
Complete database storage

mrh 10 months ago
commit bf847093e0
4 changed files with 221 additions and 50 deletions
  1. config/settings.py (+3 -1)
  2. database/search_result_db.py (+144 -0)
  3. database/sqlite_engine.py (+2 -4)
  4. worker/search_engine/google_search.py (+72 -45)

+ 3 - 1
config/settings.py

@@ -5,4 +5,6 @@ CONFIG_DIR = WORK_DIR / "config" / "conf"
 GOOGLE_SEARCH_DIR = OUTPUT_DIR / "google_search"
 
 LOG_LEVEL='info'
-LOG_DIR = OUTPUT_DIR / "logs"
+LOG_DIR = OUTPUT_DIR / "logs"
+
+DB_URL = f"sqlite:///{OUTPUT_DIR}/search_results.db"
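
For context, DB_URL follows SQLAlchemy's sqlite:///<path> convention, so every module can build its engine from this one setting. A minimal sketch, assuming an illustrative OUTPUT_DIR (the real value is defined earlier in settings.py):

    from pathlib import Path
    from sqlmodel import create_engine

    OUTPUT_DIR = Path("output")  # illustrative stand-in for the real setting
    DB_URL = f"sqlite:///{OUTPUT_DIR}/search_results.db"
    engine = create_engine(DB_URL)  # same pattern database/sqlite_engine.py now uses
    print(DB_URL)  # sqlite:///output/search_results.db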

+ 144 - 0
database/search_result_db.py

@@ -0,0 +1,144 @@
+from datetime import datetime
+from typing import Optional, List
+from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select
+from pathlib import Path
+from config.settings import DB_URL
+
+class KeywordTask(SQLModel, table=True):
+    id: Optional[int] = Field(default=None, primary_key=True)
+    keyword: str = Field(index=True, unique=True)
+    total_results: Optional[int] = None
+    is_completed: bool = Field(default=False)
+    created_at: datetime = Field(default_factory=datetime.now)
+    
+    pages: List["SearchPageResult"] = Relationship(back_populates="keyword_task")
+    items: List["SearchResultItem"] = Relationship(back_populates="keyword_task")
+
+class SearchPageResult(SQLModel, table=True):
+    id: Optional[int] = Field(default=None, primary_key=True)
+    keyword_id: int = Field(foreign_key="keywordtask.id")
+    keyword: str = Field(index=True)
+    page_number: int
+    results_count: int
+    has_next_page: bool
+    html_path: Optional[str] = None
+    created_at: datetime = Field(default_factory=datetime.now)
+    
+    keyword_task: Optional[KeywordTask] = Relationship(back_populates="pages")
+    items: List["SearchResultItem"] = Relationship(back_populates="search_page")
+
+class SearchResultItem(SQLModel, table=True):
+    id: Optional[int] = Field(default=None, primary_key=True)
+    url: str
+    title: Optional[str] = None
+    content: Optional[str] = None
+    html_path: Optional[str] = None
+    keyword_id: int = Field(foreign_key="keywordtask.id")
+    keyword: str = Field(index=True)
+    page_id: int = Field(foreign_key="searchpageresult.id")
+    created_at: datetime = Field(default_factory=datetime.now)
+    
+    keyword_task: Optional[KeywordTask] = Relationship(back_populates="items")
+    search_page: Optional[SearchPageResult] = Relationship(back_populates="items")
+
+class SearchResultManager:
+    def __init__(self, db_url: str = DB_URL):
+        self.engine = create_engine(db_url)
+        SQLModel.metadata.create_all(self.engine)
+    
+    def create_keyword_task(self, keyword: str) -> KeywordTask:
+        with Session(self.engine) as session:
+            # Create the task, or fetch the existing one
+            task = session.exec(
+                select(KeywordTask)
+                .where(KeywordTask.keyword == keyword)
+            ).first()
+            
+            if not task:
+                task = KeywordTask(keyword=keyword)
+                session.add(task)
+                session.commit()
+                session.refresh(task)
+            return task
+    
+    def save_page_results(
+        self,
+        keyword: str,
+        page_number: int,
+        results_count: int,
+        has_next_page: bool,
+        html_path: Optional[Path] = None
+    ) -> SearchPageResult:
+        with Session(self.engine) as session:
+            # Fetch the associated keyword task
+            keyword_task = session.exec(
+                select(KeywordTask)
+                .where(KeywordTask.keyword == keyword)
+            ).first()
+            
+            if not keyword_task:
+                keyword_task = self.create_keyword_task(keyword)
+            
+            # Create the page result record
+            page_result = SearchPageResult(
+                keyword_id=keyword_task.id,
+                keyword=keyword,
+                page_number=page_number,
+                results_count=results_count,
+                has_next_page=has_next_page,
+                html_path=str(html_path) if html_path else None
+            )
+            
+            session.add(page_result)
+            session.commit()
+            session.refresh(page_result)
+            return page_result
+    
+    def save_result_items(
+        self,
+        keyword: str,
+        page_id: int,
+        items: List[SearchResultItem],
+        html_dir: Optional[Path] = None
+    ):
+        with Session(self.engine) as session:
+            # Fetch the associated keyword task
+            keyword_task = session.exec(
+                select(KeywordTask)
+                .where(KeywordTask.keyword == keyword)
+            ).first()
+            
+            if not keyword_task:
+                keyword_task = self.create_keyword_task(keyword)
+            
+            # Save the result items in bulk. Copy the parsed fields
+            # explicitly rather than unpacking item.dict(), which would
+            # collide with the keyword_id/keyword/page_id kwargs below;
+            # items only need url/title/content attributes.
+            for item in items:
+                html_path = None
+                if html_dir and item.url:
+                    html_path = html_dir / f"{item.url.replace('://', '_').replace('/', '_')}.html"
+
+                result_item = SearchResultItem(
+                    url=item.url,
+                    title=item.title,
+                    content=item.content,
+                    keyword_id=keyword_task.id,
+                    keyword=keyword,
+                    page_id=page_id,
+                    html_path=str(html_path) if html_path else None
+                )
+                session.add(result_item)
+            
+            session.commit()
+    
+    def mark_task_completed(self, keyword: str, total_results: int):
+        with Session(self.engine) as session:
+            task = session.exec(
+                select(KeywordTask)
+                .where(KeywordTask.keyword == keyword)
+            ).first()
+            
+            if task:
+                task.is_completed = True
+                task.total_results = total_results
+                session.add(task)
+                session.commit()
+                session.refresh(task)
+            return task
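
Taken together, the manager is driven in the order create_keyword_task, then save_page_results, then save_result_items, and finally mark_task_completed. A usage sketch under assumed values (the in-memory URL and ParsedItem are illustrative; real items come from the Google parser and only need url/title/content attributes):

    from database.search_result_db import SearchResultManager

    class ParsedItem:  # stand-in for the parser-side pydantic SearchResultItem
        url = "https://example.com/paper"
        title = "Example result"
        content = "snippet text"

    manager = SearchResultManager(db_url="sqlite://")  # in-memory DB for illustration
    manager.create_keyword_task("Acampe carinata essential oil")
    page = manager.save_page_results(
        keyword="Acampe carinata essential oil",
        page_number=1,
        results_count=1,
        has_next_page=True,
    )
    manager.save_result_items(
        keyword="Acampe carinata essential oil",
        page_id=page.id,
        items=[ParsedItem()],
    )
    manager.mark_task_completed("Acampe carinata essential oil", total_results=1)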

+ 2 - 4
database/sqlite_engine.py

@@ -1,13 +1,11 @@
 from sqlite3 import connect
-from config.settings import OUTPUT_DIR, WORK_DIR
+from config.settings import OUTPUT_DIR, WORK_DIR, DB_URL
 from datetime import datetime
 from sqlmodel import Field, SQLModel, create_engine, Session, select
 from sqlalchemy import text  # Add this import
 
 
-sqlite_file_name = OUTPUT_DIR / "database.db"
-sqlite_url = f"sqlite:///{sqlite_file_name}"
-engine = create_engine(sqlite_url, echo=False)
+engine = create_engine(DB_URL, echo=False)
 
 
 def create_db_and_tables():
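
Since both modules now build their engine from the same DB_URL, every table lands in the single search_results.db file. A quick verification sketch (assumes the project packages are importable):

    from sqlalchemy import inspect
    from database.sqlite_engine import engine
    from database.search_result_db import SearchResultManager

    SearchResultManager()  # side effect: creates the three new tables if missing
    print(inspect(engine).get_table_names())
    # expected to include: 'keywordtask', 'searchpageresult', 'searchresultitem'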

+ 72 - 45
worker/search_engine/google_search.py

@@ -11,6 +11,8 @@ from playwright.sync_api import sync_playwright
 from mylib.logu import logger
 from mylib.base import save_to_file
 from config.settings import OUTPUT_DIR
+from database.search_result_db import SearchResultManager  # new import; don't import the DB SearchResultItem here, the parser model below would shadow it
+
 class SearchResultItem(BaseModel):
     url: str | None = None
     title: str | None = None
@@ -26,11 +28,11 @@ class SearchResultEle(BaseModel):
     current_page: int | None = None
     results: SearchResult | None = None
 
-# ------------------- Search Engine Implementation -------------------
 class GoogleSearchHandler():
     """搜索引擎专用处理器(通过CDP连接)"""
-    def __init__(self, page:Page):
+    def __init__(self, page: Page):
         self.page = page
+        self.db_manager = SearchResultManager()  # initialize the database manager
         
     async def goto_home_page(self):
         url = "https://www.google.com"
@@ -43,23 +45,19 @@ class GoogleSearchHandler():
             await self.goto_home_page()
             await self.page.fill('textarea[aria-label="Search"]', query, timeout=10000)
             await self.page.press('textarea[aria-label="Search"]', 'Enter')
-            # Wait for the page load to finish
             await self.page.wait_for_load_state(state='load', timeout=10000)
             return await self.page.content()
         except Exception as e:
             logger.exception(f"Search failed: {str(e)}")
             return {"status": "error", "message": str(e)}
+
     def get_current_page_num(self) -> int:
         if '/search?q=' in self.page.url:
-            # Look for &start=20 in self.page.url; if present, return 20/10 + 1, otherwise return 1
             match = re.search(r'&start=(\d+)', self.page.url)
-            if match:
-                return int(match.group(1)) // 10 + 1
-            else:
-                return 1
+            return int(match.group(1)) // 10 + 1 if match else 1
         raise ValueError("Invalid URL")
 
-    def get_search_result_ele(self, html_content:str):
+    def get_search_result_ele(self, html_content: str):
         res = SearchResultEle(
             search_div=None,
             next_ele=None,
@@ -74,47 +72,75 @@ class GoogleSearchHandler():
         res.search_div = bool(search_div)
         res.next_ele = bool(next_ele)
 
-        if search_div:
-            result_list = search_div.xpath('//*[@data-snc]')
-            logger.info(f"result_list {len(result_list)}")
-            
-            search_res = SearchResult(total_count=len(result_list))
-            
-            for result_item in result_list:
-                if len(result_item.children) < 2:
-                    continue
-                    
-                result = SearchResultItem()
-                title_ele = result_item.children[0]
-                if title_ele:
-                    result.url = title_ele.xpath_first('.//a/@href')
-                    result.title = title_ele.xpath_first('.//h3/text()')
-
-                content_ele = result_item.children[1]
-                if content_ele:
-                    content_list = content_ele.xpath('.//span/text()')
-                    result.content = ''.join(content_list) if content_list else None
-
-                if any([result.url, result.title, result.content]):
-                    search_res.results.append(result)
-            
-            res.results = search_res
+        if not search_div:
+            return res
+
+        result_list = search_div.xpath('//*[@data-snc]')
+        logger.info(f"result_list {len(result_list)}")
+        
+        search_res = SearchResult(total_count=len(result_list))
+        
+        for result_item in result_list:
+            if len(result_item.children) < 2:
+                continue
+                
+            result = SearchResultItem()
+            title_ele = result_item.children[0]
+            if title_ele:
+                result.url = title_ele.xpath_first('.//a/@href')
+                result.title = title_ele.xpath_first('.//h3/text()')
+
+            content_ele = result_item.children[1]
+            if content_ele:
+                content_list = content_ele.xpath('.//span/text()')
+                result.content = ''.join(content_list) if content_list else None
 
+            if any([result.url, result.title, result.content]):
+                search_res.results.append(result)
+        
+        res.results = search_res
         return res
+
+    async def save_search_results(self, keyword: str, html_content: str, html_dir: Path):
+        """保存搜索结果到数据库"""
+        result_ele = self.get_search_result_ele(html_content)
+        html_path = save_to_file(html_content, html_dir / f"page_{result_ele.current_page}.html")
+
+        # Save the page-level result
+        page_result = self.db_manager.save_page_results(
+            keyword=keyword,
+            page_number=result_ele.current_page,
+            results_count=result_ele.results.total_count if result_ele.results else 0,
+            has_next_page=bool(result_ele.next_ele),
+            html_path=html_path
+        )
+
+        # Save the individual result items
+        if result_ele.results and result_ele.results.results:
+            self.db_manager.save_result_items(
+                keyword=keyword,
+                page_id=page_result.id,
+                items=result_ele.results.results,
+                html_dir=html_dir
+            )
+
+        return page_result
+
 async def aio_main(config: BrowserConfig = BrowserConfig()):
     try:
         core = await BrowserCore.get_instance(config)
         search_handler = GoogleSearchHandler(core.page)
+        keyword = 'Acampe carinata essential oil'
+        html_dir = OUTPUT_DIR / 'results' / keyword.replace(' ', '_')
+        html_dir.mkdir(parents=True, exist_ok=True)
+
+        # Create the keyword task
+        search_handler.db_manager.create_keyword_task(keyword)
+        
+        # Run the search and persist the results
+        content = await search_handler.search(keyword)
+        await search_handler.save_search_results(keyword, content, html_dir)
         
-        # Test the search function
-        content = await search_handler.search('Acampe carinata essential oil')
-        save_path = save_to_file(content, OUTPUT_DIR /'analyze'/ 'test.html')
-        logger.info(f"save_path {save_path}")
-        logger.info(f"当前页面: {search_handler.page.url}")
-        res = search_handler.get_search_result_ele(content)
-        # Pretty-print the result
-        logger.info(f"{json.dumps(res.dict(), indent=4, ensure_ascii=False)}")
-        # html_save = 
         # Keep the connection alive
         while True:
             await asyncio.sleep(5)
@@ -131,14 +157,15 @@ def connet_ws():
         match = re.search(ws_url_pattern, content)
         ws_url = match.group(0)
         browser = p.firefox.connect(ws_url)
-        # browser = p.firefox.connect('ws://localhost:11250/f75c64655c4913c727feff7cf8e6d242')
         page = browser.new_page()
         print(page.url)
     return
 
 def analyze():
     html_file = Path(r"K:\code\upwork\zhang_crawl_bio\output\analyze\test.html")
-    search_handler = GoogleSearchHandler(None)
+    class TestPage:
+        # Stub page: get_current_page_num() expects a Google search URL;
+        # a plain file:// URI would raise ValueError("Invalid URL").
+        url = "https://www.google.com/search?q=test"
+    search_handler = GoogleSearchHandler(page=TestPage())
     res = search_handler.get_search_result_ele(html_file.read_text())
     logger.info(f"{json.dumps(res.model_dump(), indent=4, ensure_ascii=False)}")
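
To check what a run has persisted, a query sketch against the new models (the keyword mirrors the hard-coded test keyword in aio_main; model and field names come from search_result_db.py):

    from sqlmodel import Session, create_engine, select
    from config.settings import DB_URL
    from database.search_result_db import KeywordTask, SearchPageResult, SearchResultItem

    engine = create_engine(DB_URL)
    with Session(engine) as session:
        task = session.exec(
            select(KeywordTask).where(KeywordTask.keyword == "Acampe carinata essential oil")
        ).first()
        if task:
            pages = session.exec(
                select(SearchPageResult).where(SearchPageResult.keyword_id == task.id)
            ).all()
            items = session.exec(
                select(SearchResultItem).where(SearchResultItem.keyword_id == task.id)
            ).all()
            print(task.is_completed, len(pages), len(items))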