Browse source

Refactor sessions a bit, and fix issue where runtimes get killed (#4900)

Robert Brennan 1 year ago
Parent
Commit
17f4c6e1a9

+ 0 - 2
.github/workflows/ghcr-build.yml

@@ -286,7 +286,6 @@ jobs:
           image_name=ghcr.io/${{ github.repository_owner }}/runtime:${{ env.RELEVANT_SHA }}-${{ matrix.base_image }}
           image_name=$(echo $image_name | tr '[:upper:]' '[:lower:]')
 
-          SKIP_CONTAINER_LOGS=true \
           TEST_RUNTIME=eventstream \
           SANDBOX_USER_ID=$(id -u) \
           SANDBOX_RUNTIME_CONTAINER_IMAGE=$image_name \
@@ -364,7 +363,6 @@ jobs:
           image_name=ghcr.io/${{ github.repository_owner }}/runtime:${{ env.RELEVANT_SHA }}-${{ matrix.base_image }}
           image_name=$(echo $image_name | tr '[:upper:]' '[:lower:]')
 
-          SKIP_CONTAINER_LOGS=true \
           TEST_RUNTIME=eventstream \
           SANDBOX_USER_ID=$(id -u) \
           SANDBOX_RUNTIME_CONTAINER_IMAGE=$image_name \

+ 1 - 1
docs/modules/usage/runtimes.md

@@ -59,7 +59,7 @@ docker run # ...
     -e RUNTIME=remote \
     -e SANDBOX_REMOTE_RUNTIME_API_URL="https://runtime.app.all-hands.dev" \
     -e SANDBOX_API_KEY="your-all-hands-api-key" \
-    -e SANDBOX_KEEP_REMOTE_RUNTIME_ALIVE="true" \
+    -e SANDBOX_KEEP_RUNTIME_ALIVE="true" \
     # ...
 ```
 

+ 1 - 1
evaluation/miniwob/run_infer.py

@@ -66,7 +66,7 @@ def get_config(
             browsergym_eval_env=env_id,
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
-            keep_remote_runtime_alive=False,
+            keep_runtime_alive=False,
         ),
         # do not mount workspace
         workspace_base=None,

+ 1 - 1
evaluation/scienceagentbench/run_infer.py

@@ -72,7 +72,7 @@ def get_config(
             timeout=300,
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
-            keep_remote_runtime_alive=False,
+            keep_runtime_alive=False,
         ),
         # do not mount workspace
         workspace_base=None,

+ 1 - 1
evaluation/swe_bench/run_infer.py

@@ -145,7 +145,7 @@ def get_config(
             platform='linux/amd64',
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
-            keep_remote_runtime_alive=False,
+            keep_runtime_alive=False,
             remote_runtime_init_timeout=1800,
         ),
         # do not mount workspace

+ 1 - 1
openhands/core/config/sandbox_config.py

@@ -36,7 +36,7 @@ class SandboxConfig:
 
     remote_runtime_api_url: str = 'http://localhost:8000'
     local_runtime_url: str = 'http://localhost'
-    keep_remote_runtime_alive: bool = True
+    keep_runtime_alive: bool = True
     api_key: str | None = None
     base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22'  # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
     runtime_container_image: str | None = None

+ 18 - 0
openhands/runtime/impl/eventstream/containers.py

@@ -0,0 +1,18 @@
+import docker
+
+
+def remove_all_containers(prefix: str):
+    docker_client = docker.from_env()
+
+    try:
+        containers = docker_client.containers.list(all=True)
+        for container in containers:
+            try:
+                if container.name.startswith(prefix):
+                    container.remove(force=True)
+            except docker.errors.APIError:
+                pass
+            except docker.errors.NotFound:
+                pass
+    except docker.errors.NotFound:  # yes, this can happen!
+        pass

+ 33 - 45
openhands/runtime/impl/eventstream/eventstream_runtime.py

@@ -1,8 +1,9 @@
+import atexit
 import os
-from pathlib import Path
 import tempfile
 import threading
 from functools import lru_cache
+from pathlib import Path
 from typing import Callable
 from zipfile import ZipFile
 
@@ -35,6 +36,7 @@ from openhands.events.serialization import event_to_dict, observation_from_dict
 from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.runtime.base import Runtime
 from openhands.runtime.builder import DockerRuntimeBuilder
+from openhands.runtime.impl.eventstream.containers import remove_all_containers
 from openhands.runtime.plugins import PluginRequirement
 from openhands.runtime.utils import find_available_tcp_port
 from openhands.runtime.utils.request import send_request
@@ -42,6 +44,15 @@ from openhands.runtime.utils.runtime_build import build_runtime_image
 from openhands.utils.async_utils import call_sync_from_async
 from openhands.utils.tenacity_stop import stop_if_should_exit
 
+CONTAINER_NAME_PREFIX = 'openhands-runtime-'
+
+
+def remove_all_runtime_containers():
+    remove_all_containers(CONTAINER_NAME_PREFIX)
+
+
+atexit.register(remove_all_runtime_containers)
+
 
 class LogBuffer:
     """Synchronous buffer for Docker container logs.
@@ -114,8 +125,6 @@ class EventStreamRuntime(Runtime):
         env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
     """
 
-    container_name_prefix = 'openhands-runtime-'
-
     # Need to provide this method to allow inheritors to init the Runtime
     # without initting the EventStreamRuntime.
     def init_base_runtime(
@@ -158,7 +167,7 @@ class EventStreamRuntime(Runtime):
         self.docker_client: docker.DockerClient = self._init_docker_client()
         self.base_container_image = self.config.sandbox.base_container_image
         self.runtime_container_image = self.config.sandbox.runtime_container_image
-        self.container_name = self.container_name_prefix + sid
+        self.container_name = CONTAINER_NAME_PREFIX + sid
         self.container = None
         self.action_semaphore = threading.Semaphore(1)  # Ensure one action at a time
 
@@ -173,10 +182,6 @@ class EventStreamRuntime(Runtime):
                 f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}',
             )
 
-        self.skip_container_logs = (
-            os.environ.get('SKIP_CONTAINER_LOGS', 'false').lower() == 'true'
-        )
-
         self.init_base_runtime(
             config,
             event_stream,
@@ -189,7 +194,15 @@ class EventStreamRuntime(Runtime):
 
     async def connect(self):
         self.send_status_message('STATUS$STARTING_RUNTIME')
-        if not self.attach_to_existing:
+        try:
+            await call_sync_from_async(self._attach_to_container)
+        except docker.errors.NotFound as e:
+            if self.attach_to_existing:
+                self.log(
+                    'error',
+                    f'Container {self.container_name} not found.',
+                )
+                raise e
             if self.runtime_container_image is None:
                 if self.base_container_image is None:
                     raise ValueError(
@@ -210,13 +223,12 @@ class EventStreamRuntime(Runtime):
             await call_sync_from_async(self._init_container)
             self.log('info', f'Container started: {self.container_name}')
 
-        else:
-            await call_sync_from_async(self._attach_to_container)
-
         if not self.attach_to_existing:
             self.log('info', f'Waiting for client to become ready at {self.api_url}...')
-        self.send_status_message('STATUS$WAITING_FOR_CLIENT')
+            self.send_status_message('STATUS$WAITING_FOR_CLIENT')
+
         await call_sync_from_async(self._wait_until_alive)
+
         if not self.attach_to_existing:
             self.log('info', 'Runtime is ready.')
 
@@ -227,7 +239,8 @@ class EventStreamRuntime(Runtime):
             'debug',
             f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}',
         )
-        self.send_status_message(' ')
+        if not self.attach_to_existing:
+            self.send_status_message(' ')
 
     @staticmethod
     @lru_cache(maxsize=1)
@@ -332,13 +345,12 @@ class EventStreamRuntime(Runtime):
             self.log('debug', f'Container started. Server url: {self.api_url}')
             self.send_status_message('STATUS$CONTAINER_STARTED')
         except docker.errors.APIError as e:
-            # check 409 error
             if '409' in str(e):
                 self.log(
                     'warning',
                     f'Container {self.container_name} already exists. Removing...',
                 )
-                self._close_containers(rm_all_containers=True)
+                remove_all_containers(self.container_name)
                 return self._init_container()
 
             else:
@@ -414,42 +426,18 @@ class EventStreamRuntime(Runtime):
         Parameters:
         - rm_all_containers (bool): Whether to remove all containers with the 'openhands-sandbox-' prefix
         """
-
         if self.log_buffer:
             self.log_buffer.close()
 
         if self.session:
             self.session.close()
 
-        if self.attach_to_existing:
+        if self.config.sandbox.keep_runtime_alive or self.attach_to_existing:
             return
-        self._close_containers(rm_all_containers)
-
-    def _close_containers(self, rm_all_containers: bool = True):
-        try:
-            containers = self.docker_client.containers.list(all=True)
-            for container in containers:
-                try:
-                    # If the app doesn't shut down properly, it can leave runtime containers on the system. This ensures
-                    # that all 'openhands-sandbox-' containers are removed as well.
-                    if rm_all_containers and container.name.startswith(
-                        self.container_name_prefix
-                    ):
-                        container.remove(force=True)
-                    elif container.name == self.container_name:
-                        if not self.skip_container_logs:
-                            logs = container.logs(tail=1000).decode('utf-8')
-                            self.log(
-                                'debug',
-                                f'==== Container logs on close ====\n{logs}\n==== End of container logs ====',
-                            )
-                        container.remove(force=True)
-                except docker.errors.APIError:
-                    pass
-                except docker.errors.NotFound:
-                    pass
-        except docker.errors.NotFound:  # yes, this can happen!
-            pass
+        close_prefix = (
+            CONTAINER_NAME_PREFIX if rm_all_containers else self.container_name
+        )
+        remove_all_containers(close_prefix)
 
     def run_action(self, action: Action) -> Observation:
         if isinstance(action, FileEditAction):

+ 1 - 2
openhands/runtime/impl/remote/remote_runtime.py

@@ -288,7 +288,6 @@ class RemoteRuntime(Runtime):
         assert runtime_data['runtime_id'] == self.runtime_id
         assert 'pod_status' in runtime_data
         pod_status = runtime_data['pod_status']
-        self.log('debug', runtime_data)
         self.log('debug', f'Pod status: {pod_status}')
 
         # FIXME: We should fix it at the backend of /start endpoint, make sure
@@ -333,7 +332,7 @@ class RemoteRuntime(Runtime):
         raise RuntimeNotReadyError()
 
     def close(self, timeout: int = 10):
-        if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing:
+        if self.config.sandbox.keep_runtime_alive or self.attach_to_existing:
             self.session.close()
             return
         if self.runtime_id and self.session:

+ 4 - 2
openhands/runtime/impl/runloop/runloop_runtime.py

@@ -21,6 +21,8 @@ from openhands.runtime.utils.command import get_remote_startup_command
 from openhands.runtime.utils.request import send_request
 from openhands.utils.tenacity_stop import stop_if_should_exit
 
+CONTAINER_NAME_PREFIX = 'openhands-runtime-'
+
 
 class RunloopLogBuffer(LogBuffer):
     """Synchronous buffer for Runloop devbox logs.
@@ -115,7 +117,7 @@ class RunloopRuntime(EventStreamRuntime):
             bearer_token=config.runloop_api_key,
         )
         self.session = requests.Session()
-        self.container_name = self.container_name_prefix + sid
+        self.container_name = CONTAINER_NAME_PREFIX + sid
         self.action_semaphore = threading.Semaphore(1)  # Ensure one action at a time
         self.init_base_runtime(
             config,
@@ -190,7 +192,7 @@ class RunloopRuntime(EventStreamRuntime):
             prebuilt='openhands',
             launch_parameters=LaunchParameters(
                 available_ports=[self._sandbox_port],
-                resource_size_request="LARGE",
+                resource_size_request='LARGE',
             ),
             metadata={'container-name': self.container_name},
         )

+ 1 - 9
openhands/server/listen.py

@@ -5,7 +5,6 @@ import tempfile
 import time
 import uuid
 import warnings
-from contextlib import asynccontextmanager
 
 import jwt
 import requests
@@ -74,14 +73,7 @@ file_store = get_file_store(config.file_store, config.file_store_path)
 session_manager = SessionManager(config, file_store)
 
 
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    global session_manager
-    async with session_manager:
-        yield
-
-
-app = FastAPI(lifespan=lifespan)
+app = FastAPI()
 app.add_middleware(
     LocalhostCORSMiddleware,
     allow_credentials=True,

+ 7 - 65
openhands/server/session/manager.py

@@ -1,14 +1,11 @@
-import asyncio
 import time
-from dataclasses import dataclass, field
-from typing import Optional
+from dataclasses import dataclass
 
 from fastapi import WebSocket
 
 from openhands.core.config import AppConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.stream import session_exists
-from openhands.runtime.utils.shutdown_listener import should_continue
 from openhands.server.session.conversation import Conversation
 from openhands.server.session.session import Session
 from openhands.storage.files import FileStore
@@ -18,78 +15,23 @@ from openhands.storage.files import FileStore
 class SessionManager:
     config: AppConfig
     file_store: FileStore
-    cleanup_interval: int = 300
-    session_timeout: int = 600
-    _sessions: dict[str, Session] = field(default_factory=dict)
-    _session_cleanup_task: Optional[asyncio.Task] = None
-
-    async def __aenter__(self):
-        if not self._session_cleanup_task:
-            self._session_cleanup_task = asyncio.create_task(self._cleanup_sessions())
-        return self
-
-    async def __aexit__(self, exc_type, exc_value, traceback):
-        if self._session_cleanup_task:
-            self._session_cleanup_task.cancel()
-            self._session_cleanup_task = None
 
     def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
-        if sid in self._sessions:
-            self._sessions[sid].close()
-        self._sessions[sid] = Session(
+        return Session(
             sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
         )
-        return self._sessions[sid]
-
-    def get_session(self, sid: str) -> Session | None:
-        if sid not in self._sessions:
-            return None
-        return self._sessions.get(sid)
 
     async def attach_to_conversation(self, sid: str) -> Conversation | None:
+        start_time = time.time()
         if not await session_exists(sid, self.file_store):
             return None
         c = Conversation(sid, file_store=self.file_store, config=self.config)
         await c.connect()
+        end_time = time.time()
+        logger.info(
+            f'Conversation {c.sid} connected in {end_time - start_time} seconds'
+        )
         return c
 
     async def detach_from_conversation(self, conversation: Conversation):
         await conversation.disconnect()
-
-    async def send(self, sid: str, data: dict[str, object]) -> bool:
-        """Sends data to the client."""
-        session = self.get_session(sid)
-        if session is None:
-            logger.error(f'*** No session found for {sid}, skipping message ***')
-            return False
-        return await session.send(data)
-
-    async def send_error(self, sid: str, message: str) -> bool:
-        """Sends an error message to the client."""
-        return await self.send(sid, {'error': True, 'message': message})
-
-    async def send_message(self, sid: str, message: str) -> bool:
-        """Sends a message to the client."""
-        return await self.send(sid, {'message': message})
-
-    async def _cleanup_sessions(self):
-        while should_continue():
-            current_time = time.time()
-            session_ids_to_remove = []
-            for sid, session in list(self._sessions.items()):
-                # if session inactive for a long time, remove it
-                if (
-                    not session.is_alive
-                    and current_time - session.last_active_ts > self.session_timeout
-                ):
-                    session_ids_to_remove.append(sid)
-
-            for sid in session_ids_to_remove:
-                to_del_session: Session | None = self._sessions.pop(sid, None)
-                if to_del_session is not None:
-                    to_del_session.close()
-                    logger.debug(
-                        f'Session {sid} and related resource have been removed due to inactivity.'
-                    )
-
-            await asyncio.sleep(self.cleanup_interval)

+ 1 - 0
tests/runtime/conftest.py

@@ -224,6 +224,7 @@ def _load_runtime(
     config = load_app_config()
     config.run_as_openhands = run_as_openhands
     config.sandbox.force_rebuild_runtime = force_rebuild_runtime
+    config.sandbox.keep_runtime_alive = False
     # Folder where all tests create their own folder
     global test_mount_path
     if use_workspace:

+ 1 - 1
tests/runtime/test_stress_remote_runtime.py

@@ -64,7 +64,7 @@ def get_config(
             timeout=300,
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
-            keep_remote_runtime_alive=False,
+            keep_runtime_alive=False,
         ),
         # do not mount workspace
         workspace_base=None,