Переглянути джерело

修复 process_mgr 无法 Ctrl C 退出问题

mrh 9 місяців тому
батько
коміт
4c4ce66666
3 змінених файлів з 71 додано та 29 видалено
  1. 11 2
      ui/backend/main.py
  2. 43 10
      ui/backend/routers/proxy.py
  3. 17 17
      ui/backend/utils/process_mgr.py

+ 11 - 2
ui/backend/main.py

@@ -13,6 +13,7 @@ from utils.process_mgr import process_manager
 from utils.config import config,WORKER_DIR_BASE
 from src.services.celery_worker import CeleryWorker
 from utils.logu import get_logger,logger
+
 async def startup():
     """应用启动时执行的操作"""
     global process_manager,config
@@ -20,14 +21,21 @@ async def startup():
     logger.info(f"startup")
     tasks.append(asyncio.create_task(CeleryWorker().run()))
     tasks.append(asyncio.create_task(sub_mgr.startup()))
+    return tasks  # 确保返回任务列表
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """应用生命周期管理"""
     tasks = await startup()
     yield
-    await process_manager.cleanup()
+    # 先取消所有任务再清理进程
     for task in tasks:
         task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+    await process_manager.cleanup()
 
 # 创建 FastAPI 应用实例
 app = FastAPI(
@@ -35,6 +43,7 @@ app = FastAPI(
     version="1.0.0",
     lifespan=lifespan
 )
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],  # 允许所有域名访问
@@ -42,11 +51,11 @@ app.add_middleware(
     allow_methods=["*"],  # 允许所有方法(GET, POST, PUT, DELETE 等)
     allow_headers=["*"],  # 允许所有头部
 )
+
 # 将 gpt_router 挂载到应用中
 app.include_router(router, prefix="/api/proxy", tags=["chat"])
 app.include_router(worker_router, prefix="/api/worker", tags=["worker"])
 
-# 如果你需要运行这个应用,可以使用以下代码
 if __name__ == "__main__":
     import uvicorn
     # 注意: 由于涉及到多进程管理机制,因此不能使用 reload 。例如 uvicorn main:app --port 5835 --reload

+ 43 - 10
ui/backend/routers/proxy.py

@@ -28,6 +28,10 @@ class SysProxyResponse(BaseModel):
 class SubUrlPost(BaseModel):
     sub_url: str
 
+class ProxyPoolResponse(BaseModel):
+    proxies: List[str]
+    cached: bool
+
 from fastapi.requests import Request
 
 @router.get("/sys")
@@ -81,7 +85,7 @@ async def get_all_proxy_response(use_cache: bool = True) -> List[ProxyResponse]:
     ret = []
     tasks = []
     for port,porxy_model in sub_mgr.sub.proxies.items():
-        tasks.append(get_proxy_response(port, use_cache))
+        tasks.append(get_proxy_response(port))
     ret = await asyncio.gather(*tasks)
     return ret
 
@@ -100,6 +104,24 @@ async def ping_proxies() -> Dict[str, int]:
         logger.info(f"use cache: {cache_key}")
     return cache[cache_key]
 
+@router.get("/proxies-pool")
+async def get_proxies_pool(force_refresh: bool = False):
+    global cache
+    cache_key = "proxy_pool"
+    
+    if not force_refresh and cache_key in cache:
+        return ProxyPoolResponse(proxies=cache[cache_key], cached=True)
+    
+    proxies = []
+    all_proxies = await get_all_proxy_response()
+    for p in all_proxies:
+        if p.reachable:  # 健康检查
+            proxies.append(f"127.0.0.1:{p.port}")
+    
+    # 更新缓存并返回
+    cache[cache_key] = proxies
+    return ProxyPoolResponse(proxies=proxies, cached=False)
+
 @router.get("/proxies/{port}")
 @router.get("/proxies")
 async def get_proxies(port: int = None):
@@ -142,6 +164,10 @@ async def delete_proxy(port: int):
         if port in sub_mgr.sub.proxies:
             del sub_mgr.sub.proxies[port]
             sub_mgr.save_config()
+        
+        # 清除代理池缓存
+        if 'proxy_pool' in cache:
+            del cache['proxy_pool']
             
         return await get_all_proxy_response()
     except Exception as e:
@@ -161,28 +187,26 @@ async def create_proxy(request:ProxyPost):
             porxy_port = request.port
             proxy_mgr = sub_mgr.get_proxy_manager(porxy_port)
             if proxy_mgr and proxy_mgr.running:
-                # return {'err': 0, "msg": f"已开启,跳过 {porxy_port} ", "data": await get_proxy_response(porxy_port)}
                 return ProxyPostResponse(err=0, msg=f"已开启,跳过 {porxy_port} ", data=await get_proxy_response(porxy_port))
             porxy_port_is_using = await port_is_using(porxy_port)
             controler_port = request.port + 1
             if porxy_port_is_using:
-                # return ProxyPostResponse(err=1, msg=f"porxy_port={porxy_port} 端口已被占用")
                 raise HTTPException(status_code=400, detail=ProxyPostResponse(err=1, msg=f"porxy_port={porxy_port} 端口已被占用"))
             if await port_is_using(controler_port):
-                # return {"err": 1, "msg": f"controler_port={controler_port} 端口已被占用"}
-                # return ProxyPostResponse(err=1, msg=f"controler_port={controler_port} 端口已被占用")
                 raise HTTPException(status_code=400, detail=ProxyPostResponse(err=1, msg=f"controler_port={controler_port} 端口已被占用"))
         else:
-            # return ProxyPostResponse(err=1, msg="port 或 auto 必须有一个")
             raise HTTPException(status_code=400, detail=ProxyPostResponse(err=1, msg="port 或 auto 必须有一个"))
         await sub_mgr.create_custom_config(porxy_port, controler_port)
         await sub_mgr.start_proxy(porxy_port)
         await auto_select_proxy(porxy_port)
-        # return {"err": 0, "msg": "ok", "data": await get_proxy_response(porxy_port)}
         res = ProxyPostResponse(err=0, msg="ok", data=await get_proxy_response(porxy_port))
         logger.info(f"{res}")
+        
+        # 清除代理池缓存
+        if 'proxy_pool' in cache:
+            del cache['proxy_pool']
+            
         return res
-    # return ProxyPostResponse(err=1, msg="proxy_lock error", data=sub_mgr.sub)
     return HTTPException(status_code=500, detail=ProxyPostResponse(err=1, msg="proxy_lock error", data=sub_mgr.sub))
 
 @router.post("/proxies/{port}/stop")
@@ -192,7 +216,13 @@ async def stop_proxy(port: int):
     if not proxy_mgr:
         raise HTTPException(status_code=404, detail=f"Proxy with port {port} not found")
     await sub_mgr.stop_proxy(port)
+    
+    # 清除代理池缓存
+    if 'proxy_pool' in cache:
+        del cache['proxy_pool']
+        
     return await get_proxy_response(port)
+
 @router.get("/subs")
 async def get_subscriptions():
     global sub_mgr
@@ -207,7 +237,6 @@ async def create_subscription(request: SubUrlPost):
     except Exception as e:
         return {"err": 1, "msg": str(e)}
 
-
 class StartupRequest(BaseModel):
     auto_start: bool
 
@@ -215,6 +244,11 @@ class StartupRequest(BaseModel):
 async def startup(request: StartupRequest):
     global sub_mgr,config
     sub_mgr.save_startup(request.auto_start)
+    
+    # 清除代理池缓存
+    if 'proxy_pool' in cache:
+        del cache['proxy_pool']
+        
     return {"err": 0, "msg": "ok", "data": config}
 
 def main():
@@ -226,6 +260,5 @@ def main():
     else:
         print("代理未启用")
 
-
 if __name__ == "__main__":
     main()

+ 17 - 17
ui/backend/utils/process_mgr.py

@@ -23,7 +23,7 @@ class ProcessManager:
     def __init__(self):
         self.processes: MutableMapping[str, dict] = {}
         self.job_object = None
-        self.lock = asyncio.Lock()
+        self.lock = asyncio.Lock()  # 直接初始化锁
         
         if platform.system() == 'Windows':
             self._create_windows_job()
@@ -57,6 +57,7 @@ class ProcessManager:
             self._setup_windows_ctrl_handler()
         else:
             self._setup_unix_signal_handlers()
+
     def _setup_windows_ctrl_handler(self):
         """Windows控制台事件处理"""
         try:
@@ -70,9 +71,20 @@ class ProcessManager:
     def _windows_ctrl_handler(self, dwCtrlType):
         """Windows控制台事件回调"""
         if dwCtrlType in {win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT}:
-            asyncio.create_task(self.cleanup())
-            return True  # 表示已处理该事件
-        return False  # 继续传递事件
+            try:
+                # 使用独立事件循环处理清理
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+                loop.run_until_complete(self.cleanup())
+                loop.close()
+                
+                # 退出应用程序
+                os._exit(0)
+                return True
+            except Exception as e:
+                logger.error(f"Critical error during cleanup: {str(e)}")
+                os._exit(1)
+        return False
 
     def _setup_unix_signal_handlers(self):
         """Unix信号处理配置"""
@@ -87,9 +99,6 @@ class ProcessManager:
         """Unix信号处理回调"""
         logger.info(f"Received signal {signum.name}")
         asyncio.create_task(self.cleanup())
-    def _signal_handler(self):
-        """信号处理入口"""
-        asyncio.create_task(self.cleanup())
 
     async def start_process(
         self,
@@ -108,20 +117,17 @@ class ProcessManager:
             log_file = log_dir / f"{name}.log"
 
             try:
-                # 使用二进制追加模式打开日志文件
                 log_fd = open(log_file, "ab")
                 
-                # 创建子进程
                 process = await asyncio.create_subprocess_exec(
                     *command,
                     stdout=log_fd,
                     cwd=cwd,
                     stderr=subprocess.STDOUT,
                     stdin=subprocess.DEVNULL,
-                    start_new_session=True  # 重要:创建新会话/进程组
+                    start_new_session=True
                 )
 
-                # Windows作业对象绑定
                 if platform.system() == 'Windows' and self.job_object:
                     self._bind_to_windows_job(process.pid)
 
@@ -164,13 +170,11 @@ class ProcessManager:
             log_fd = proc_info["log_fd"]
 
             try:
-                # 进程已自然退出
                 if process.returncode is not None:
                     del self.processes[name]
                     log_fd.close()
                     return True
 
-                # 跨平台终止逻辑
                 if platform.system() == 'Windows':
                     subprocess.run(
                         ["taskkill", "/F", "/T", "/PID", str(process.pid)],
@@ -179,10 +183,8 @@ class ProcessManager:
                         stderr=subprocess.DEVNULL
                     )
                 else:
-                    # 发送信号到整个进程组
                     os.killpg(os.getpgid(process.pid), signal.SIGTERM)
 
-                # 等待进程终止
                 await process.wait()
                 logger.info(f"Stopped process {name} (PID: {process.pid})")
                 return True
@@ -206,7 +208,6 @@ class ProcessManager:
             for name in list(self.processes.keys()):
                 await self.stop_process(name)
             
-            # 清理Windows作业对象
             if platform.system() == 'Windows' and self.job_object:
                 win32api.CloseHandle(self.job_object)
                 self.job_object = None
@@ -226,4 +227,3 @@ class ProcessManager:
                 pass
 
 process_manager = ProcessManager()
-