|
|
@@ -6,6 +6,7 @@ import httpx
|
|
|
from pydantic import BaseModel,field_validator,Field
|
|
|
import redis
|
|
|
import asyncio
|
|
|
+import psutil
|
|
|
from utils.config import config, APP_PATH
|
|
|
from utils.process_mgr import process_manager
|
|
|
from utils.logu import logger
|
|
|
@@ -17,7 +18,6 @@ py_client: Optional[Dict[str,Any]] = {
|
|
|
'convert': WORKER_DIR_BASE / r'worker\celery\html_convert_tasks.py'
|
|
|
}
|
|
|
|
|
|
-
|
|
|
class WorkerModel(BaseModel):
|
|
|
name: str
|
|
|
queue_name: Optional[str] = Field(default=None,validate_default=True)
|
|
|
@@ -33,7 +33,6 @@ class WorkerModel(BaseModel):
|
|
|
return f"{values.data['name']}_queue"
|
|
|
return v
|
|
|
|
|
|
-
|
|
|
class CeleryWorker:
|
|
|
def __init__(self, python_exe: str=sys.executable):
|
|
|
self.python_exe = python_exe
|
|
|
@@ -46,30 +45,29 @@ class CeleryWorker:
|
|
|
async def wait_for_worker_online(self, name: str, timeout: int = 30) -> bool:
    """Wait until worker ``name`` shows up in the worker status listing.

    Polls ``self.check_worker_status()`` about once per second. A worker
    counts as online when an entry's ``hostname`` starts with ``"<name>@"``;
    on a match, the PID reported for that entry is stored on the matching
    model in ``self.workers_model`` (if there is one).

    Args:
        name: Worker name, i.e. the part of the hostname before ``@``.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as the worker is listed, False once ``timeout``
        seconds have elapsed without a match.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout

    while True:
        try:
            worker_status = await self.check_worker_status()
            if worker_status["err"] == 0:
                workers_data = worker_status.get("workers", {}).get("data", [])
                for worker in workers_data:
                    if worker["hostname"].startswith(f"{name}@"):
                        worker_model = self.workers_model.get(name)
                        if worker_model:
                            worker_model.pid = worker["pid"]
                        return True
        except Exception:
            # Polling is best-effort: a failed or malformed status response
            # just means "not online yet"; try again on the next tick.
            pass

        # The deadline is checked on every iteration, including failed
        # polls, so an unreachable status endpoint cannot make this method
        # spin forever.
        if loop.time() > deadline:
            return False

        await asyncio.sleep(1)
|
|
|
@@ -108,7 +106,7 @@ class CeleryWorker:
|
|
|
if not worker_model:
|
|
|
raise ValueError(f"Invalid worker name: {name}")
|
|
|
if in_cmd_windows:
|
|
|
- cmd = ['start','cmd', '/k' ]
|
|
|
+ cmd = ['start','cmd', '/c' ]
|
|
|
sub_cmd = ' '.join(worker_model.cmd)
|
|
|
cmd.append(f'{sub_cmd}')
|
|
|
logger.info(f"run {' '.join(cmd)}")
|
|
|
@@ -125,36 +123,34 @@ class CeleryWorker:
|
|
|
|
|
|
async def stop_worker(self, name: str):
    """Terminate the OS process recorded for worker ``name``.

    Sends a terminate signal to the PID stored on the worker model, waits
    up to 5 seconds for the process to exit, and force-kills it if it is
    still running after that. The blocking psutil calls run in a thread
    so the event loop is not stalled.

    Args:
        name: Key of the worker in ``self.workers_model``.

    Raises:
        ValueError: If ``name`` is not a known worker.
    """
    worker_model = self.workers_model.get(name)
    if not worker_model:
        raise ValueError(f"Invalid worker name: {name}")
    logger.info(f"{worker_model}")

    pid = worker_model.pid
    if pid is None:
        # No PID has been recorded for this worker, so there is nothing to
        # signal. (psutil.pid_exists(None) would raise TypeError.)
        return

    try:
        # psutil.Process raises NoSuchProcess for a dead PID, which makes a
        # separate pid_exists() pre-check both redundant and race-prone.
        proc = await asyncio.to_thread(psutil.Process, pid)
        await asyncio.to_thread(proc.terminate)
        try:
            await asyncio.to_thread(proc.wait, 5)
        except psutil.TimeoutExpired:
            # Graceful shutdown did not finish in time; force-kill.
            await asyncio.to_thread(proc.kill)
            logger.warning(f"{worker_model} 进程 {pid} 强制终止")
    except psutil.NoSuchProcess:
        # The process exited on its own before or while we were stopping it.
        pass

    # Drop the stale PID so a later stop cannot hit an unrelated process
    # that happens to reuse the same number.
    worker_model.pid = None
|
|
|
+
|
|
|
async def check_worker_status(self) -> Dict[str, Any]:
|
|
|
flower_url = "http://127.0.0.1:5555/workers?json=1"
|
|
|
async with httpx.AsyncClient() as client:
|
|
|
try:
|
|
|
- # 检查 Redis 状态
|
|
|
response = await client.get(flower_url)
|
|
|
response.raise_for_status()
|
|
|
workers_status = response.json()
|
|
|
-
|
|
|
- # 合并 Worker 状态和 Redis 状态
|
|
|
+
|
|
|
+ # Update worker models with current pids
|
|
|
+ workers_data = workers_status.get("data", [])
|
|
|
+ for worker in workers_data:
|
|
|
+ worker_name = worker["hostname"].split("@")[0]
|
|
|
+ if worker_name in self.workers_model:
|
|
|
+ self.workers_model[worker_name].pid = worker["pid"]
|
|
|
+
|
|
|
return {"err": 0, "msg": "success", "workers": workers_status}
|
|
|
except httpx.HTTPStatusError as e:
|
|
|
return {"err": 1, "msg": f"HTTP error occurred: {e}"}
|