Przeglądaj źródła

allow reconnecting to a runtime (#4223)

Robert Brennan 1 rok temu
rodzic
commit
45fb4fb9bc

+ 1 - 0
evaluation/swe_bench/run_infer.py

@@ -133,6 +133,7 @@ def get_config(
             timeout=300,
             timeout=300,
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
+            keep_remote_runtime_alive=False,
         ),
         ),
         # do not mount workspace
         # do not mount workspace
         workspace_base=None,
         workspace_base=None,

+ 7 - 5
frontend/src/components/AgentStatusBar.tsx

@@ -94,12 +94,14 @@ function AgentStatusBar() {
   const [statusMessage, setStatusMessage] = React.useState<string>("");
   const [statusMessage, setStatusMessage] = React.useState<string>("");
 
 
   React.useEffect(() => {
   React.useEffect(() => {
-    const trimmedCustomMessage = curStatusMessage.status.trim();
-    if (trimmedCustomMessage) {
-      setStatusMessage(t(trimmedCustomMessage));
-    } else {
-      setStatusMessage(AgentStatusMap[curAgentState].message);
+    if (curAgentState === AgentState.LOADING) {
+      const trimmedCustomMessage = curStatusMessage.status.trim();
+      if (trimmedCustomMessage) {
+        setStatusMessage(t(trimmedCustomMessage));
+        return;
+      }
     }
     }
+    setStatusMessage(AgentStatusMap[curAgentState].message);
   }, [curAgentState, curStatusMessage.status]);
   }, [curAgentState, curStatusMessage.status]);
 
 
   return (
   return (

+ 1 - 0
openhands/core/config/sandbox_config.py

@@ -34,6 +34,7 @@ class SandboxConfig:
 
 
     remote_runtime_api_url: str = 'http://localhost:8000'
     remote_runtime_api_url: str = 'http://localhost:8000'
     local_runtime_url: str = 'http://localhost'
     local_runtime_url: str = 'http://localhost'
+    keep_remote_runtime_alive: bool = True
     api_key: str | None = None
     api_key: str | None = None
     base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22'  # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
     base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22'  # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
     runtime_container_image: str | None = None
     runtime_container_image: str | None = None

+ 125 - 84
openhands/runtime/remote/runtime.py

@@ -1,7 +1,6 @@
 import os
 import os
 import tempfile
 import tempfile
 import threading
 import threading
-import uuid
 from typing import Callable, Optional
 from typing import Callable, Optional
 from zipfile import ZipFile
 from zipfile import ZipFile
 
 
@@ -11,7 +10,7 @@ from tenacity import (
     retry,
     retry,
     retry_if_exception_type,
     retry_if_exception_type,
     stop_after_attempt,
     stop_after_attempt,
-    wait_exponential,
+    wait_fixed,
 )
 )
 
 
 from openhands.core.config import AppConfig
 from openhands.core.config import AppConfig
@@ -60,14 +59,13 @@ class RemoteRuntime(Runtime):
         status_message_callback: Optional[Callable] = None,
         status_message_callback: Optional[Callable] = None,
     ):
     ):
         self.config = config
         self.config = config
+        self.status_message_callback = status_message_callback
 
 
         if self.config.sandbox.api_key is None:
         if self.config.sandbox.api_key is None:
             raise ValueError(
             raise ValueError(
                 'API key is required to use the remote runtime. '
                 'API key is required to use the remote runtime. '
                 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
                 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
             )
             )
-        self.status_message_callback = status_message_callback
-        self.send_status_message('STATUS$STARTING_RUNTIME')
         self.session = requests.Session()
         self.session = requests.Session()
         self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
         self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
         self.action_semaphore = threading.Semaphore(1)
         self.action_semaphore = threading.Semaphore(1)
@@ -83,61 +81,116 @@ class RemoteRuntime(Runtime):
         self.runtime_id: str | None = None
         self.runtime_id: str | None = None
         self.runtime_url: str | None = None
         self.runtime_url: str | None = None
 
 
-        self.instance_id = (
-            sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
+        self.instance_id = sid
+
+        self._start_or_attach_to_runtime(plugins)
+
+        # Initialize the eventstream and env vars
+        super().__init__(
+            config, event_stream, sid, plugins, env_vars, status_message_callback
         )
         )
-        self.container_name = 'oh-remote-runtime-' + self.instance_id
-        if self.config.sandbox.runtime_container_image is not None:
-            logger.info(
-                f'Running remote runtime with image: {self.config.sandbox.runtime_container_image}'
-            )
-            self.container_image = self.config.sandbox.runtime_container_image
+        self._wait_until_alive()
+        self.setup_initial_env()
+
+    def _start_or_attach_to_runtime(self, plugins: list[PluginRequirement] | None):
+        existing_runtime = self._check_existing_runtime()
+        if existing_runtime:
+            logger.info(f'Using existing runtime with ID: {self.runtime_id}')
         else:
         else:
-            logger.info(
-                f'Building remote runtime with base image: {self.config.sandbox.base_container_image}'
-            )
-            logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
+            self.send_status_message('STATUS$STARTING_CONTAINER')
+            if self.config.sandbox.runtime_container_image is None:
+                logger.info(
+                    f'Building remote runtime with base image: {self.config.sandbox.base_container_image}'
+                )
+                self._build_runtime()
+            else:
+                logger.info(
+                    f'Running remote runtime with image: {self.config.sandbox.runtime_container_image}'
+                )
+            self._start_runtime(plugins)
+        assert (
+            self.runtime_id is not None
+        ), 'Runtime ID is not set. This should never happen.'
+        assert (
+            self.runtime_url is not None
+        ), 'Runtime URL is not set. This should never happen.'
+        self.send_status_message('STATUS$WAITING_FOR_CLIENT')
+        self._wait_until_alive()
+
+    def _check_existing_runtime(self) -> bool:
+        try:
             response = send_request_with_retry(
             response = send_request_with_retry(
                 self.session,
                 self.session,
                 'GET',
                 'GET',
-                f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
-                timeout=30,
-            )
-            response_json = response.json()
-            registry_prefix = response_json['registry_prefix']
-            os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = (
-                registry_prefix.rstrip('/') + '/runtime'
-            )
-            logger.info(
-                f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}'
+                f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.instance_id}',
+                timeout=5,
             )
             )
+        except Exception as e:
+            logger.error(f'Error while looking for remote runtime: {e}')
+            return False
+
+        if response.status_code == 200:
+            data = response.json()
+            status = data.get('status')
+            if status == 'running':
+                self._parse_runtime_response(response)
+                return True
+            elif status == 'stopped':
+                logger.info('Found existing remote runtime, but it is stopped')
+                return False
+            elif status == 'paused':
+                logger.info('Found existing remote runtime, but it is paused')
+                self._parse_runtime_response(response)
+                self._resume_runtime()
+                return True
+            else:
+                logger.error(f'Invalid response from runtime API: {data}')
+                return False
+        else:
+            logger.info('Could not find existing remote runtime')
+            return False
 
 
-            if self.config.sandbox.runtime_extra_deps:
-                logger.info(
-                    f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
-                )
+    def _build_runtime(self):
+        logger.debug(f'RemoteRuntime `{self.instance_id}` config:\n{self.config}')
+        response = send_request_with_retry(
+            self.session,
+            'GET',
+            f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
+            timeout=30,
+        )
+        response_json = response.json()
+        registry_prefix = response_json['registry_prefix']
+        os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = (
+            registry_prefix.rstrip('/') + '/runtime'
+        )
+        logger.info(
+            f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}'
+        )
 
 
-            # Build the container image
-            self.send_status_message('STATUS$STARTING_CONTAINER')
-            self.container_image = build_runtime_image(
-                self.config.sandbox.base_container_image,
-                self.runtime_builder,
-                extra_deps=self.config.sandbox.runtime_extra_deps,
-                force_rebuild=self.config.sandbox.force_rebuild_runtime,
+        if self.config.sandbox.runtime_extra_deps:
+            logger.info(
+                f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
             )
             )
 
 
-            response = send_request_with_retry(
-                self.session,
-                'GET',
-                f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
-                params={'image': self.container_image},
-                timeout=30,
-            )
-            if response.status_code != 200 or not response.json()['exists']:
-                raise RuntimeError(
-                    f'Container image {self.container_image} does not exist'
-                )
+        # Build the container image
+        self.container_image = build_runtime_image(
+            self.config.sandbox.base_container_image,
+            self.runtime_builder,
+            extra_deps=self.config.sandbox.runtime_extra_deps,
+            force_rebuild=self.config.sandbox.force_rebuild_runtime,
+        )
 
 
+        response = send_request_with_retry(
+            self.session,
+            'GET',
+            f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
+            params={'image': self.container_image},
+            timeout=30,
+        )
+        if response.status_code != 200 or not response.json()['exists']:
+            raise RuntimeError(f'Container image {self.container_image} does not exist')
+
+    def _start_runtime(self, plugins: list[PluginRequirement] | None):
         # Prepare the request body for the /start endpoint
         # Prepare the request body for the /start endpoint
         plugin_arg = ''
         plugin_arg = ''
         if plugins is not None and len(plugins) > 0:
         if plugins is not None and len(plugins) > 0:
@@ -160,11 +213,10 @@ class RemoteRuntime(Runtime):
                 f'{browsergym_arg}'
                 f'{browsergym_arg}'
             ),
             ),
             'working_dir': '/openhands/code/',
             'working_dir': '/openhands/code/',
-            'name': self.container_name,
             'environment': {'DEBUG': 'true'} if self.config.debug else {},
             'environment': {'DEBUG': 'true'} if self.config.debug else {},
+            'runtime_id': self.instance_id,
         }
         }
 
 
-        self.send_status_message('STATUS$WAITING_FOR_CLIENT')
         # Start the sandbox using the /start endpoint
         # Start the sandbox using the /start endpoint
         response = send_request_with_retry(
         response = send_request_with_retry(
             self.session,
             self.session,
@@ -175,45 +227,35 @@ class RemoteRuntime(Runtime):
         )
         )
         if response.status_code != 201:
         if response.status_code != 201:
             raise RuntimeError(f'Failed to start sandbox: {response.text}')
             raise RuntimeError(f'Failed to start sandbox: {response.text}')
-        start_response = response.json()
-        self.runtime_id = start_response['runtime_id']
-        self.runtime_url = start_response['url']
-
+        self._parse_runtime_response(response)
         logger.info(
         logger.info(
             f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
             f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
         )
         )
 
 
+    def _resume_runtime(self):
+        response = send_request_with_retry(
+            self.session,
+            'POST',
+            f'{self.config.sandbox.remote_runtime_api_url}/resume',
+            json={'runtime_id': self.runtime_id},
+            timeout=30,
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f'Failed to resume sandbox: {response.text}')
+        logger.info(f'Sandbox resumed. Runtime ID: {self.runtime_id}')
+
+    def _parse_runtime_response(self, response: requests.Response):
+        start_response = response.json()
+        self.runtime_id = start_response['runtime_id']
+        self.runtime_url = start_response['url']
         if 'session_api_key' in start_response:
         if 'session_api_key' in start_response:
             self.session.headers.update(
             self.session.headers.update(
                 {'X-Session-API-Key': start_response['session_api_key']}
                 {'X-Session-API-Key': start_response['session_api_key']}
             )
             )
 
 
-        # Initialize the eventstream and env vars
-        super().__init__(
-            config, event_stream, sid, plugins, env_vars, status_message_callback
-        )
-
-        logger.info(
-            f'Runtime initialized with plugins: {[plugin.name for plugin in self.plugins]}'
-        )
-        logger.info(f'Runtime initialized with env vars: {env_vars}')
-        assert (
-            self.runtime_id is not None
-        ), 'Runtime ID is not set. This should never happen.'
-        assert (
-            self.runtime_url is not None
-        ), 'Runtime URL is not set. This should never happen.'
-
-        self._wait_until_alive()
-
-        self.send_status_message(' ')
-
-        self._wait_until_alive()
-        self.setup_initial_env()
-
     @retry(
     @retry(
-        stop=stop_after_attempt(10) | stop_if_should_exit(),
-        wait=wait_exponential(multiplier=1, min=4, max=60),
+        stop=stop_after_attempt(60) | stop_if_should_exit(),
+        wait=wait_fixed(2),
         retry=retry_if_exception_type(RuntimeError),
         retry=retry_if_exception_type(RuntimeError),
         reraise=True,
         reraise=True,
     )
     )
@@ -236,6 +278,9 @@ class RemoteRuntime(Runtime):
             raise RuntimeError(msg)
             raise RuntimeError(msg)
 
 
     def close(self, timeout: int = 10):
     def close(self, timeout: int = 10):
+        if self.config.sandbox.keep_remote_runtime_alive:
+            self.session.close()
+            return
         if self.runtime_id:
         if self.runtime_id:
             try:
             try:
                 response = send_request_with_retry(
                 response = send_request_with_retry(
@@ -268,8 +313,6 @@ class RemoteRuntime(Runtime):
                     f'Action {action_type} is not supported in the current runtime.'
                     f'Action {action_type} is not supported in the current runtime.'
                 )
                 )
 
 
-            self._wait_until_alive()
-
             assert action.timeout is not None
             assert action.timeout is not None
 
 
             try:
             try:
@@ -331,7 +374,6 @@ class RemoteRuntime(Runtime):
         if not os.path.exists(host_src):
         if not os.path.exists(host_src):
             raise FileNotFoundError(f'Source file {host_src} does not exist')
             raise FileNotFoundError(f'Source file {host_src} does not exist')
 
 
-        self._wait_until_alive()
         try:
         try:
             if recursive:
             if recursive:
                 with tempfile.NamedTemporaryFile(
                 with tempfile.NamedTemporaryFile(
@@ -383,7 +425,6 @@ class RemoteRuntime(Runtime):
             logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
             logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
 
 
     def list_files(self, path: str | None = None) -> list[str]:
     def list_files(self, path: str | None = None) -> list[str]:
-        self._wait_until_alive()
         try:
         try:
             data = {}
             data = {}
             if path is not None:
             if path is not None:
@@ -397,7 +438,7 @@ class RemoteRuntime(Runtime):
                 retry_exceptions=list(
                 retry_exceptions=list(
                     filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
                     filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
                 ),
                 ),
-                timeout=30,  # The runtime sbould already be running here
+                timeout=30,
             )
             )
             if response.status_code == 200:
             if response.status_code == 200:
                 response_json = response.json()
                 response_json = response.json()

+ 1 - 1
openhands/server/session/agent_session.py

@@ -180,7 +180,7 @@ class AgentSession:
                 status_message_callback=status_message_callback,
                 status_message_callback=status_message_callback,
             )
             )
         except Exception as e:
         except Exception as e:
-            logger.error(f'Runtime initialization failed: {e}')
+            logger.error(f'Runtime initialization failed: {e}', exc_info=True)
             raise
             raise
 
 
         if self.runtime is not None:
         if self.runtime is not None: