Browse Source

Fix Graceful cleanup of session manager (#4306)

tofarr 1 year ago
parent
commit
f867fda2f9
2 changed files with 22 additions and 2 deletions
  1. 10 1
      openhands/server/listen.py
  2. 12 1
      openhands/server/session/manager.py

+ 10 - 1
openhands/server/listen.py

@@ -5,6 +5,7 @@ import re
 import tempfile
 import uuid
 import warnings
+from contextlib import asynccontextmanager
 
 import requests
 from pathspec import PathSpec
@@ -64,7 +65,15 @@ config = load_app_config()
 file_store = get_file_store(config.file_store, config.file_store_path)
 session_manager = SessionManager(config, file_store)
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global session_manager
+    async with session_manager:
+        yield
+
+
+app = FastAPI(lifespan=lifespan)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'],

+ 12 - 1
openhands/server/session/manager.py

@@ -1,5 +1,6 @@
 import asyncio
 import time
+from typing import Optional
 
 from fastapi import WebSocket
 
@@ -14,12 +15,22 @@ class SessionManager:
     _sessions: dict[str, Session] = {}
     cleanup_interval: int = 300
     session_timeout: int = 600
+    _session_cleanup_task: Optional[asyncio.Task] = None
 
     def __init__(self, config: AppConfig, file_store: FileStore):
-        asyncio.create_task(self._cleanup_sessions())
         self.config = config
         self.file_store = file_store
 
+    async def __aenter__(self):
+        if not self._session_cleanup_task:
+            self._session_cleanup_task = asyncio.create_task(self._cleanup_sessions())
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        if self._session_cleanup_task:
+            self._session_cleanup_task.cancel()
+            self._session_cleanup_task = None
+
     def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
         if sid in self._sessions:
             asyncio.create_task(self._sessions[sid].close())