Complete the batch task submission tests

mrh 9 months ago
parent
commit
011e13244b

+ 1 - 1
tests/mytest/redis_celery_t.py

@@ -1,7 +1,7 @@
 from mylib.logu import logger
 from worker.celery.app import app as celery_app
 from worker.celery.crawl_client import submit_page_crawl_tasks
-from worker.celery.client import get_uncompleted_keywords,submit_tasks
+from worker.celery.search_client import get_uncompleted_keywords, submit_tasks
 
 
 # Submit tasks to the specified queue

+ 1 - 1
tests/mytest/t.py

@@ -11,7 +11,7 @@ import sys
 from pathlib import Path
 sys.path.append(Path(r'G:\code\upwork\zhang_crawl_bio'))
 os.environ['DB_URL'] = 'sqlite:///' + str(Path(r'G:\code\upwork\zhang_crawl_bio\output\search_results copy.db'))
-from worker.celery.client import get_uncompleted_keywords
+from worker.celery.search_client import get_uncompleted_keywords
 import yaml
 import socket
 

+ 123 - 0
ui/backend/tests/mytests/t_celery.py

@@ -0,0 +1,123 @@
+from celery import Celery
+from typing import List, Dict, Optional
+import asyncio
+from pathlib import Path
+import sys
+import time
+sys.path.append(str(Path(r'G:\code\upwork\zhang_crawl_bio\ui\backend')))
+from utils.process_mgr import ProcessManager
+from utils.logu import get_logger
+logger = get_logger('mytests', file=True)
+
+app = Celery('search_worker', broker='redis://127.0.0.1:6379/1')
+import redis
+redis_client = redis.Redis(host='127.0.0.1', port=6379, db=1)
+
+def get_pending_task_count() -> int:
+    """
+    获取当前待处理任务的数量
+    """
+    # 获取 Celery 使用的默认队列名称
+    queue_name = app.conf.task_default_queue
+    logger.info(f"{queue_name}")
+    # Get the number of tasks in the queue
+    task_count = redis_client.llen(queue_name)
+    logger.info(f"{task_count}")
+    logger.info(f"{app.conf.task_routes}")
+    return task_count
+
+def submit_tasks(keywords: List[str], browser_config: Optional[Dict] = None):
+    """提交所有关键词任务"""
+    for keyword in keywords:
+        try:
+            task_data = {
+                'keyword': keyword.strip(),
+                'max_result_items': 200,
+                'skip_existing': True,
+                'browser_config': browser_config or {}
+            }
+            result = app.send_task('search_worker.drission_search', kwargs=task_data, queue='search_queue')
+            logger.info(f"任务已提交: {keyword} (任务ID: {result.id})")
+        except Exception as e:
+            logger.error(f"提交任务失败 [{keyword}]: {str(e)}")
+
+def submit_all_tasks(browser_config: Optional[Dict] = None):
+    """提交所有关键词任务
+    def search_all_uncompleted_keywords_task(max_result_items: int = 200, skip_existing: bool = True, browser_config: dict = {}, proxy_pool: List[str] = None):
+    """
+    clear_specific_queue('search_queue')
+    task_data = {
+        'max_result_items': 1,
+        'skip_existing': True,
+        'browser_config': browser_config or {},
+        'proxy_pool': None,
+        'dry_run': True
+    }
+    result = app.send_task('search_worker.search_all_uncompleted_keywords', kwargs=task_data, queue='search_queue')
+    logger.info(f"任务已提交: (任务ID: {result.id})")
+
+def get_queue():
+    keys = redis_client.keys('*')
+    logger.info(f"{keys}")
+    # Inspect the tasks in each queue
+    queue_tasks = redis_client.lrange('search_queue', 0, -1)
+    logger.info(f"len search_queue {len(queue_tasks)}")
+    queue_tasks = redis_client.lrange('crawl_queue', 0, -1)
+    logger.info(f"len crawl_queue {len(queue_tasks)}")
+    default_queue_tasks = redis_client.lrange('default', 0, -1)
+    logger.info(f"len default {len(default_queue_tasks)}")
+    return
+    # for task in queue_tasks:
+    #     print(task.decode('utf-8'))
+
+    # Clear the queue
+    # redis_client.delete('search_queue')
+
+# View the tasks in a queue
+def inspect_queue(queue_name: str = 'search_queue'):
+    # Get Celery's inspect object
+    inspector = app.control.inspect()
+
+    # View tasks reserved by workers (inspect() reports per worker, not per queue)
+    reserved_tasks = inspector.reserved()
+    if reserved_tasks:
+        print(f"Tasks in queue {queue_name}:")
+        for worker, tasks in reserved_tasks.items():
+            for task in tasks:
+                print(f"Task ID: {task['id']}, name: {task['name']}, kwargs: {task['kwargs']}")
+    else:
+        print(f"No tasks in queue {queue_name}.")
+
+def clear_all_queues():
+    """
+    Clear all queue data from the current Redis database.
+    """
+    try:
+        # flushdb() wipes every key in the current Redis database
+        redis_client.flushdb()
+        logger.info("All queue data cleared")
+    except Exception as e:
+        logger.error(f"Failed to clear queue data: {str(e)}")
+
+def clear_specific_queue(queue_name: str):
+    """
+    Clear the specified queue.
+    """
+    try:
+        # Deleting the Redis list key empties the queue
+        redis_client.delete(queue_name)
+        logger.info(f"Queue {queue_name} cleared")
+    except Exception as e:
+        logger.error(f"Failed to clear queue {queue_name}: {str(e)}")
+
+def main():
+    # submit_tasks(['test', 'test2', 'test3', 'test4', 'test5'])
+    # clear_all_queues()
+    # clear_specific_queue('default')
+    # submit_all_tasks()
+    # Connect to Redis
+    get_queue()
+    # inspect_queue()
+    # get_pending_task_count()
+
+
+if __name__ == "__main__":
+    main()
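
Side note on this scratch script: main() toggles the helpers above by hand. If a test needs to wait for workers to catch up, a minimal polling sketch can reuse the redis_client, logger, and time import already in this file (the queue name and interval here are assumptions, not part of the commit):

def wait_for_queue_drain(queue_name: str = 'search_queue', interval: float = 2.0):
    """Poll the Redis list backing a Celery queue until it is empty."""
    while True:
        pending = redis_client.llen(queue_name)  # broker queues are plain Redis lists
        logger.info(f"{queue_name}: {pending} pending")
        if pending == 0:
            break
        time.sleep(interval)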

+ 1 - 1
worker/api/worker_router.py

@@ -8,7 +8,7 @@ from mylib.drission_page import load_chrome_from_ini
 from mylib.logu import logger
 from worker.celery.app import app as celery_app
 from worker.celery.crawl_client import submit_page_crawl_tasks
-from worker.celery.client import get_uncompleted_keywords
+from worker.celery.search_client import get_uncompleted_keywords
 
 
 app = APIRouter()

+ 6 - 1
worker/celery/celeryconfig.py

@@ -21,4 +21,9 @@ worker_send_task_events=True
 worker_pool = 'solo'
 # worker_pool = 'eventlet'
 
-broker_connection_retry_on_startup=True
+broker_connection_retry_on_startup=True
+task_routes = {
+    'worker.celery.tasks.search_task': {'queue': 'search_queue'},
+    'worker.celery.tasks.crawl_task': {'queue': 'crawl_queue'},
+    'worker.celery.tasks.html_convert_tasks': {'queue': 'convert_queue'},
+}
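
A note on the added task_routes: Celery matches these keys against the registered task name (the name= given at registration), and a queue passed explicitly at call time takes precedence over task_routes, which is why the send_task(..., queue='search_queue') calls elsewhere in this commit are unaffected. A minimal sketch, with task names and kwargs assumed for illustration only:

from worker.celery.app import app as celery_app

# Routed via the task_routes mapping above (key must match the registered name):
celery_app.send_task('worker.celery.tasks.crawl_task', kwargs={'url': 'https://example.com'})

# Explicit queue= overrides task_routes for this one call:
celery_app.send_task('search_worker.drission_search', kwargs={'keyword': 'demo'}, queue='search_queue')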

+ 1 - 1
worker/celery/html_convert_tasks.py

@@ -7,7 +7,7 @@ from worker.search_engine.valid_google_search import ValidSearchResult
 
 logger = get_logger('pandoc_tasks')
 
-@current_app.task(name='html_convert_tasks_worker.convert_single_result')
+@current_app.task(name='html_convert_tasks.convert_single_result')
 def convert_single_result_task(result_id: int):
     """
     Celery task to convert a single SearchResultItem using Pandoc.
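
Because this changes the registered task name, any producer still sending the old 'html_convert_tasks_worker.convert_single_result' string would make the worker log "Received unregistered task" when it consumes the message. For reference, a matching call would look like this (the result id and queue name are placeholders, not from this commit):

from worker.celery.app import app as celery_app

celery_app.send_task('html_convert_tasks.convert_single_result', args=[42], queue='convert_queue')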

+ 0 - 0
worker/celery/client.py → worker/celery/search_client.py


+ 46 - 2
worker/celery/search_tasks.py

@@ -1,4 +1,5 @@
 import random
+from typing import List
 from worker.celery.app import app
 from worker.search_engine.drission_google_search import search_keyword_drission
 from mylib.logu import logger
@@ -7,9 +8,47 @@ import asyncio
 import httpx
 from utils.proxy_pool import get_random_proxy
 from config.settings import PROXIES
+from worker.search_engine.search_result_db import SearchResultManager, KeywordTask, SearchResultItem, SearchPageResult
+from sqlmodel import select, Session, exists, distinct
+from celery import group
+import redis
+# redis_client = redis.Redis(host='127.0.0.1', port=6379, db=1)
+
+
+@app.task(name='search_worker.search_all_uncompleted_keywords')
+def search_all_uncompleted_keywords_task(max_result_items: int = 200, skip_existing: bool = True, browser_config: dict = {}, proxy_pool: List[str] = None, dry_run:bool=False):
+    """异步任务:搜索所有未完成的关键词"""
+    try:
+        # redis_client.delete('search_queue')
+        # logger.info(f"删除旧有search_queue队列")
+        # 获取所有未完成的关键词
+        manager = SearchResultManager()
+        uncompleted_keywords = manager.get_uncompleted_keywords()
+        
+        if not uncompleted_keywords:
+            logger.info("没有未完成的关键词需要搜索。")
+            return {"status": "success", "message": "没有未完成的关键词需要搜索。"}
+        
+        logger.info(f"找到 {len(uncompleted_keywords)} 个未完成的关键词,开始批量搜索...")
+        
+        # 创建任务组,每个关键词对应一个 drission_search_task
+        task_group = group([
+            drission_search_task.s(
+                keyword, max_result_items, skip_existing, browser_config, proxy_pool, dry_run
+                ).set(queue='search_queue')
+            for keyword in uncompleted_keywords
+        ])
+        
+        # Dispatch the task group
+        result = task_group.apply_async()
+        
+        return {"status": "success", "task_id": result.id, "message": f"已启动 {len(uncompleted_keywords)} 个关键词搜索任务。"}
+    except Exception as e:
+        logger.error(f"批量搜索任务失败: {str(e)}")
+        raise
 
 @app.task(name='search_worker.drission_search')
-def drission_search_task(keyword: str, max_result_items: int=200, skip_existing: bool=True, browser_config: dict={}, proxy_pool:list[str]=None):
+def drission_search_task(keyword: str, max_result_items: int=200, skip_existing: bool=True, browser_config: dict={}, proxy_pool:list[str]=None, dry_run:bool=False):
     """异步关键词搜索任务"""
     if sys.platform == 'win32':
         asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
@@ -20,7 +59,12 @@ def drission_search_task(keyword: str, max_result_items: int=200, skip_existing:
                 browser_config.update({'proxy': get_random_proxy()})
             logger.info(f"browser_config: {browser_config}")
             logger.info(f"开始处理关键词搜索任务: {keyword}")
-            result = search_keyword_drission(keyword, max_result_items=max_result_items, skip_existing=skip_existing, browser_config=browser_config)
+            if dry_run:
+                await asyncio.sleep(3)
+                result = []
+            else:
+                result = await search_keyword_drission(
+                    keyword, max_result_items=max_result_items, skip_existing=skip_existing, browser_config=browser_config)
             return {"keyword": keyword, "result": result}
         except Exception as e:
             logger.error(f"关键词搜索任务失败: {str(e)}")

+ 14 - 2
worker/search_engine/search_result_db.py

@@ -1,6 +1,6 @@
 from datetime import datetime
 from typing import Optional, List
-from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select, delete, func
+from sqlmodel import SQLModel, Field, Relationship, create_engine, Session, select, delete, func, distinct
 from sqlalchemy.orm import relationship
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.sql import text
@@ -73,7 +73,19 @@ class SearchResultManager:
                 select(KeywordTask)
                 .where(KeywordTask.keyword == keyword)
             ).first()
-    
+
+    def get_uncompleted_keywords(self) -> list[str]:
+        """从数据库获取已完成搜索但未完成爬取的关键词"""
+        with Session(self.engine) as session:
+            # Select distinct keywords whose task is not marked completed (no join needed)
+            query = (
+                select(distinct(KeywordTask.keyword))
+                .where(KeywordTask.is_completed != True)
+            )
+            keywords = session.exec(query).all()
+            return keywords
+
+
     def delete_keyword_task(self, keyword: str):
         """删除关键词及其所有关联数据"""
         with Session(self.engine) as session:
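
One caveat on get_uncompleted_keywords: under SQL three-valued logic, is_completed != TRUE evaluates to NULL (not true) for rows where the column is NULL, so such rows are silently excluded. If the column is non-nullable with a default, the query above is fine; if NULL should count as uncompleted, an explicit OR is safer. A sketch assuming the KeywordTask model from this commit:

from sqlalchemy import or_

query = (
    select(distinct(KeywordTask.keyword))
    .where(or_(KeywordTask.is_completed == False, KeywordTask.is_completed == None))
)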