Parcourir la source

Refactor runtime to add a `connect` method (#4410)

Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
Robert Brennan il y a 1 an
Parent
commit
8d2b2d4318

+ 1 - 0
openhands/core/cli.py

@@ -114,6 +114,7 @@ async def main():
         sid=sid,
         plugins=agent_cls.sandbox_plugins,
     )
+    await runtime.connect()
 
     controller = AgentController(
         agent=agent,

+ 1 - 0
openhands/core/main.py

@@ -122,6 +122,7 @@ async def run_controller(
 
     if runtime is None:
         runtime = create_runtime(config, sid=sid)
+    await runtime.connect()
 
     event_stream = runtime.event_stream
     # restore cli session if enabled

+ 4 - 0
openhands/runtime/base.py

@@ -172,6 +172,10 @@ class Runtime(FileEditRuntimeMixin):
     def __exit__(self, exc_type, exc_value, traceback) -> None:
         self.close()
 
+    @abstractmethod
+    async def connect(self) -> None:
+        pass
+
     # ====================================================================
     # Action execution
     # ====================================================================

+ 19 - 16
openhands/runtime/impl/eventstream/eventstream_runtime.py

@@ -154,7 +154,6 @@ class EventStreamRuntime(Runtime):
         self.session = requests.Session()
         self.status_message_callback = status_message_callback
 
-        self.send_status_message('STATUS$STARTING_RUNTIME')
         self.docker_client: docker.DockerClient = self._init_docker_client()
         self.base_container_image = self.config.sandbox.base_container_image
         self.runtime_container_image = self.config.sandbox.runtime_container_image
@@ -175,7 +174,20 @@ class EventStreamRuntime(Runtime):
         self.skip_container_logs = (
             os.environ.get('SKIP_CONTAINER_LOGS', 'false').lower() == 'true'
         )
-        if not attach_to_existing:
+
+        self.init_base_runtime(
+            config,
+            event_stream,
+            sid,
+            plugins,
+            env_vars,
+            status_message_callback,
+            attach_to_existing,
+        )
+
+    async def connect(self):
+        self.send_status_message('STATUS$STARTING_RUNTIME')
+        if not self.attach_to_existing:
             if self.runtime_container_image is None:
                 if self.base_container_image is None:
                     raise ValueError(
@@ -194,27 +206,18 @@ class EventStreamRuntime(Runtime):
             self._init_container(
                 sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox,  # e.g. /workspace
                 mount_dir=self.config.workspace_mount_path,  # e.g. /opt/openhands/_test_workspace
-                plugins=plugins,
+                plugins=self.plugins,
             )
+
         else:
             self._attach_to_container()
 
-        # Will initialize both the event stream and the env vars
-        self.init_base_runtime(
-            config,
-            event_stream,
-            sid,
-            plugins,
-            env_vars,
-            status_message_callback,
-            attach_to_existing,
-        )
-
         logger.info('Waiting for client to become ready...')
         self.send_status_message('STATUS$WAITING_FOR_CLIENT')
-
         self._wait_until_alive()
-        self.setup_initial_env()
+
+        if not self.attach_to_existing:
+            self.setup_initial_env()
 
         logger.info(
             f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'

+ 67 - 57
openhands/runtime/impl/modal/modal_runtime.py

@@ -1,7 +1,6 @@
 import os
 import tempfile
 import threading
-import uuid
 from pathlib import Path
 from typing import Callable, Generator
 
@@ -12,11 +11,19 @@ import tenacity
 from openhands.core.config import AppConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.events import EventStream
-from openhands.runtime.client.runtime import EventStreamRuntime, LogBuffer
+from openhands.runtime.impl.eventstream.eventstream_runtime import (
+    EventStreamRuntime,
+    LogBuffer,
+)
 from openhands.runtime.plugins import PluginRequirement
+from openhands.runtime.utils.command import get_remote_startup_command
 from openhands.runtime.utils.runtime_build import (
     prep_build_folder,
 )
+from openhands.utils.async_utils import call_sync_from_async
+
+# FIXME: this will not work in HA mode. We need a better way to track IDs
+MODAL_RUNTIME_IDS: dict[str, str] = {}
 
 
 # Modal's log generator returns strings, but the upstream LogBuffer expects bytes.
@@ -60,6 +67,7 @@ class ModalRuntime(EventStreamRuntime):
     """
 
     container_name_prefix = 'openhands-sandbox-'
+    sandbox: modal.Sandbox | None
 
     def __init__(
         self,
@@ -69,11 +77,13 @@ class ModalRuntime(EventStreamRuntime):
         plugins: list[PluginRequirement] | None = None,
         env_vars: dict[str, str] | None = None,
         status_message_callback: Callable | None = None,
+        attach_to_existing: bool = False,
     ):
         assert config.modal_api_token_id, 'Modal API token id is required'
         assert config.modal_api_token_secret, 'Modal API token secret is required'
 
         self.config = config
+        self.sandbox = None
 
         self.modal_client = modal.Client.from_credentials(
             config.modal_api_token_id, config.modal_api_token_secret
@@ -92,18 +102,11 @@ class ModalRuntime(EventStreamRuntime):
         self.container_port = 3000
 
         self.session = requests.Session()
-        self.instance_id = (
-            sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
-        )
         self.status_message_callback = status_message_callback
-
-        self.send_status_message('STATUS$STARTING_RUNTIME')
         self.base_container_image_id = self.config.sandbox.base_container_image
         self.runtime_container_image_id = self.config.sandbox.runtime_container_image
         self.action_semaphore = threading.Semaphore(1)  # Ensure one action at a time
 
-        logger.info(f'ModalRuntime `{self.instance_id}`')
-
         # Buffer for container logs
         self.log_buffer: LogBuffer | None = None
 
@@ -112,32 +115,60 @@ class ModalRuntime(EventStreamRuntime):
                 f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
             )
 
+        self.init_base_runtime(
+            config,
+            event_stream,
+            sid,
+            plugins,
+            env_vars,
+            status_message_callback,
+            attach_to_existing,
+        )
+
+    async def connect(self):
+        self.send_status_message('STATUS$STARTING_RUNTIME')
+
+        logger.info(f'ModalRuntime `{self.sid}`')
+
         self.image = self._get_image_definition(
             self.base_container_image_id,
             self.runtime_container_image_id,
             self.config.sandbox.runtime_extra_deps,
         )
 
-        self.sandbox = self._init_sandbox(
-            sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox,
-            plugins=plugins,
-        )
+        if self.attach_to_existing:
+            if self.sid in MODAL_RUNTIME_IDS:
+                sandbox_id = MODAL_RUNTIME_IDS[self.sid]
+                logger.info(f'Attaching to existing Modal sandbox: {sandbox_id}')
+                self.sandbox = modal.Sandbox.from_id(
+                    sandbox_id, client=self.modal_client
+                )
+        else:
+            self.send_status_message('STATUS$PREPARING_CONTAINER')
+            await call_sync_from_async(
+                self._init_sandbox,
+                sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox,
+                plugins=self.plugins,
+            )
 
-        # Will initialize both the event stream and the env vars
-        self.init_base_runtime(
-            config, event_stream, sid, plugins, env_vars, status_message_callback
-        )
+            self.send_status_message('STATUS$CONTAINER_STARTED')
+
+        self.log_buffer = ModalLogBuffer(self.sandbox)
+        if self.sandbox is None:
+            raise Exception('Sandbox not initialized')
+        tunnel = self.sandbox.tunnels()[self.container_port]
+        self.api_url = tunnel.url
+        logger.info(f'Container started. Server url: {self.api_url}')
 
-        logger.info('Waiting for client to become ready...')
-        self.send_status_message('STATUS$WAITING_FOR_CLIENT')
+        if not self.attach_to_existing:
+            logger.info('Waiting for client to become ready...')
+            self.send_status_message('STATUS$WAITING_FOR_CLIENT')
 
         self._wait_until_alive()
         self.setup_initial_env()
 
-        logger.info(
-            f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
-        )
-        self.send_status_message(' ')
+        if not self.attach_to_existing:
+            self.send_status_message(' ')
 
     def _get_image_definition(
         self,
@@ -185,10 +216,9 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
         self,
         sandbox_workspace_dir: str,
         plugins: list[PluginRequirement] | None = None,
-    ) -> modal.Sandbox:
+    ):
         try:
             logger.info('Preparing to start container...')
-            self.send_status_message('STATUS$PREPARING_CONTAINER')
             plugin_args = []
             if plugins is not None and len(plugins) > 0:
                 plugin_args.append('--plugins')
@@ -212,29 +242,16 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
             env_secret = modal.Secret.from_dict(environment)
 
             logger.debug(f'Sandbox workspace: {sandbox_workspace_dir}')
-            sandbox_start_cmd: list[str] = [
-                '/openhands/micromamba/bin/micromamba',
-                'run',
-                '-n',
-                'openhands',
-                'poetry',
-                'run',
-                'python',
-                '-u',
-                '-m',
-                'openhands.runtime.client.client',
-                str(self.container_port),
-                '--working-dir',
+            sandbox_start_cmd = get_remote_startup_command(
+                self.container_port,
                 sandbox_workspace_dir,
-                *plugin_args,
-                '--username',
                 'openhands' if self.config.run_as_openhands else 'root',
-                '--user-id',
-                str(self.config.sandbox.user_id),
-                *browsergym_args,
-            ]
-
-            sandbox = modal.Sandbox.create(
+                self.config.sandbox.user_id,
+                plugin_args,
+                browsergym_args,
+            )
+            logger.debug(f'Starting container with command: {sandbox_start_cmd}')
+            self.sandbox = modal.Sandbox.create(
                 *sandbox_start_cmd,
                 secrets=[env_secret],
                 workdir='/openhands/code',
@@ -244,18 +261,11 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
                 client=self.modal_client,
                 timeout=60 * 60,
             )
+            MODAL_RUNTIME_IDS[self.sid] = self.sandbox.object_id
+            logger.info('Container started')
 
-            tunnel = sandbox.tunnels()[self.container_port]
-            self.api_url = tunnel.url
-
-            self.log_buffer = ModalLogBuffer(sandbox)
-            logger.info(f'Container started. Server url: {self.api_url}')
-            self.send_status_message('STATUS$CONTAINER_STARTED')
-            return sandbox
         except Exception as e:
-            logger.error(
-                f'Error: Instance {self.instance_id} FAILED to start container!\n'
-            )
+            logger.error(f'Error: Instance {self.sid} FAILED to start container!\n')
             logger.exception(e)
             self.close()
             raise e
@@ -271,5 +281,5 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
         if self.session:
             self.session.close()
 
-        if self.sandbox:
+        if not self.attach_to_existing and self.sandbox:
             self.sandbox.terminate()

+ 25 - 28
openhands/runtime/impl/remote/remote_runtime.py

@@ -31,6 +31,7 @@ from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.runtime.base import Runtime
 from openhands.runtime.builder.remote import RemoteRuntimeBuilder
 from openhands.runtime.plugins import PluginRequirement
+from openhands.runtime.utils.command import get_remote_startup_command
 from openhands.runtime.utils.request import (
     is_404_error,
     is_503_error,
@@ -77,11 +78,6 @@ class RemoteRuntime(Runtime):
         self.runtime_id: str | None = None
         self.runtime_url: str | None = None
 
-        self.sid = sid
-
-        self._start_or_attach_to_runtime(plugins, attach_to_existing)
-
-        # Initialize the eventstream and env vars
         super().__init__(
             config,
             event_stream,
@@ -91,15 +87,17 @@ class RemoteRuntime(Runtime):
             status_message_callback,
             attach_to_existing,
         )
+
+    async def connect(self):
+        self._start_or_attach_to_runtime()
+        self._wait_until_alive()
         self.setup_initial_env()
 
-    def _start_or_attach_to_runtime(
-        self, plugins: list[PluginRequirement] | None, attach_to_existing: bool = False
-    ):
+    def _start_or_attach_to_runtime(self):
         existing_runtime = self._check_existing_runtime()
         if existing_runtime:
             logger.info(f'Using existing runtime with ID: {self.runtime_id}')
-        elif attach_to_existing:
+        elif self.attach_to_existing:
             raise RuntimeError('Could not find existing runtime to attach to.')
         else:
             self.send_status_message('STATUS$STARTING_CONTAINER')
@@ -113,7 +111,7 @@ class RemoteRuntime(Runtime):
                     f'Running remote runtime with image: {self.config.sandbox.runtime_container_image}'
                 )
                 self.container_image = self.config.sandbox.runtime_container_image
-            self._start_runtime(plugins)
+            self._start_runtime()
         assert (
             self.runtime_id is not None
         ), 'Runtime ID is not set. This should never happen.'
@@ -197,28 +195,27 @@ class RemoteRuntime(Runtime):
         if response.status_code != 200 or not response.json()['exists']:
             raise RuntimeError(f'Container image {self.container_image} does not exist')
 
-    def _start_runtime(self, plugins: list[PluginRequirement] | None):
+    def _start_runtime(self):
         # Prepare the request body for the /start endpoint
-        plugin_arg = ''
-        if plugins is not None and len(plugins) > 0:
-            plugin_arg = f'--plugins {" ".join([plugin.name for plugin in plugins])} '
-        browsergym_arg = (
-            f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
-            if self.config.sandbox.browsergym_eval_env is not None
-            else ''
+        plugin_args = []
+        if self.plugins is not None and len(self.plugins) > 0:
+            plugin_args = ['--plugins'] + [plugin.name for plugin in self.plugins]
+        browsergym_args = []
+        if self.config.sandbox.browsergym_eval_env is not None:
+            browsergym_args = [
+                '--browsergym-eval-env'
+            ] + self.config.sandbox.browsergym_eval_env.split(' ')
+        command = get_remote_startup_command(
+            self.port,
+            self.config.workspace_mount_path_in_sandbox,
+            'openhands' if self.config.run_as_openhands else 'root',
+            self.config.sandbox.user_id,
+            plugin_args,
+            browsergym_args,
         )
         start_request = {
             'image': self.container_image,
-            'command': (
-                f'/openhands/micromamba/bin/micromamba run -n openhands '
-                'poetry run '
-                f'python -u -m openhands.runtime.action_execution_server {self.port} '
-                f'--working-dir {self.config.workspace_mount_path_in_sandbox} '
-                f'{plugin_arg}'
-                f'--username {"openhands" if self.config.run_as_openhands else "root"} '
-                f'--user-id {self.config.sandbox.user_id} '
-                f'{browsergym_arg}'
-            ),
+            'command': command,
             'working_dir': '/openhands/code/',
             'environment': {'DEBUG': 'true'} if self.config.debug else {},
             'runtime_id': self.sid,

+ 29 - 0
openhands/runtime/utils/command.py

@@ -0,0 +1,29 @@
+def get_remote_startup_command(
+    port: int,
+    sandbox_workspace_dir: str,
+    username: str,
+    user_id: int,
+    plugin_args: list[str],
+    browsergym_args: list[str],
+):
+    return [
+        '/openhands/micromamba/bin/micromamba',
+        'run',
+        '-n',
+        'openhands',
+        'poetry',
+        'run',
+        'python',
+        '-u',
+        '-m',
+        'openhands.runtime.action_execution_server',
+        str(port),
+        '--working-dir',
+        sandbox_workspace_dir,
+        *plugin_args,
+        '--username',
+        username,
+        '--user-id',
+        str(user_id),
+        *browsergym_args,
+    ]

+ 2 - 2
openhands/server/listen.py

@@ -239,8 +239,8 @@ async def attach_session(request: Request, call_next):
             content={'error': 'Invalid token'},
         )
 
-    request.state.conversation = await call_sync_from_async(
-        session_manager.attach_to_conversation, request.state.sid
+    request.state.conversation = await session_manager.attach_to_conversation(
+        request.state.sid
     )
     if request.state.conversation is None:
         return JSONResponse(

+ 11 - 11
openhands/server/session/agent_session.py

@@ -14,7 +14,6 @@ from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
 from openhands.security import SecurityAnalyzer, options
 from openhands.storage.files import FileStore
-from openhands.utils.async_utils import call_sync_from_async
 
 
 class AgentSession:
@@ -88,6 +87,7 @@ class AgentSession:
         try:
             asyncio.run(self._start(*args), debug=True)
         except RuntimeError:
+            logger.error(f'Error starting session: {RuntimeError}', exc_info=True)
             logger.info('Session Finished')
 
     async def _start(
@@ -103,8 +103,7 @@ class AgentSession:
     ):
         self.loop = asyncio.get_running_loop()
         self._create_security_analyzer(config.security.security_analyzer)
-        await call_sync_from_async(
-            self._create_runtime,
+        await self._create_runtime(
             runtime_name=runtime_name,
             config=config,
             agent=agent,
@@ -157,7 +156,7 @@ class AgentSession:
                 security_analyzer, SecurityAnalyzer
             )(self.event_stream)
 
-    def _create_runtime(
+    async def _create_runtime(
         self,
         runtime_name: str,
         config: AppConfig,
@@ -177,15 +176,16 @@ class AgentSession:
 
         logger.info(f'Initializing runtime `{runtime_name}` now...')
         runtime_cls = get_runtime_cls(runtime_name)
+        self.runtime = runtime_cls(
+            config=config,
+            event_stream=self.event_stream,
+            sid=self.sid,
+            plugins=agent.sandbox_plugins,
+            status_message_callback=status_message_callback,
+        )
 
         try:
-            self.runtime = runtime_cls(
-                config=config,
-                event_stream=self.event_stream,
-                sid=self.sid,
-                plugins=agent.sandbox_plugins,
-                status_message_callback=status_message_callback,
-            )
+            await self.runtime.connect()
         except Exception as e:
             logger.error(f'Runtime initialization failed: {e}', exc_info=True)
             raise

+ 3 - 0
openhands/server/session/conversation.py

@@ -34,3 +34,6 @@ class Conversation:
             sid=self.sid,
             attach_to_existing=True,
         )
+
+    async def connect(self):
+        await self.runtime.connect()

+ 4 - 2
openhands/server/session/manager.py

@@ -46,10 +46,12 @@ class SessionManager:
             return None
         return self._sessions.get(sid)
 
-    def attach_to_conversation(self, sid: str) -> Conversation | None:
+    async def attach_to_conversation(self, sid: str) -> Conversation | None:
         if not session_exists(sid, self.file_store):
             return None
-        return Conversation(sid, file_store=self.file_store, config=self.config)
+        c = Conversation(sid, file_store=self.file_store, config=self.config)
+        await c.connect()
+        return c
 
     async def send(self, sid: str, data: dict[str, object]) -> bool:
         """Sends data to the client."""

+ 7 - 5
tests/runtime/conftest.py

@@ -16,6 +16,7 @@ from openhands.runtime.impl.eventstream.eventstream_runtime import EventStreamRu
 from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
 from openhands.runtime.plugins import AgentSkillsRequirement, JupyterRequirement
 from openhands.storage import get_file_store
+from openhands.utils.async_utils import call_async_from_sync
 
 TEST_IN_CI = os.getenv('TEST_IN_CI', 'False').lower() in ['true', '1', 'yes']
 TEST_RUNTIME = os.getenv('TEST_RUNTIME', 'eventstream').lower()
@@ -124,7 +125,7 @@ def temp_dir(tmp_path_factory: TempPathFactory, request) -> str:
 
 
 # Depending on TEST_RUNTIME, feed the appropriate box class(es) to the test.
-def get_box_classes():
+def get_runtime_classes():
     runtime = TEST_RUNTIME
     if runtime.lower() == 'eventstream':
         return [EventStreamRuntime]
@@ -161,8 +162,8 @@ def runtime_setup_session():
 
 # This assures that all tests run together per runtime, not alternating between them,
 # which cause errors (especially outside GitHub actions).
-@pytest.fixture(scope='module', params=get_box_classes())
-def box_class(request):
+@pytest.fixture(scope='module', params=get_runtime_classes())
+def runtime_cls(request):
     time.sleep(1)
     return request.param
 
@@ -202,7 +203,7 @@ def base_container_image(request):
 
 def _load_runtime(
     temp_dir,
-    box_class,
+    runtime_cls,
     run_as_openhands: bool = True,
     enable_auto_lint: bool = False,
     base_container_image: str | None = None,
@@ -252,12 +253,13 @@ def _load_runtime(
     file_store = get_file_store(config.file_store, config.file_store_path)
     event_stream = EventStream(sid, file_store)
 
-    runtime = box_class(
+    runtime = runtime_cls(
         config=config,
         event_stream=event_stream,
         sid=sid,
         plugins=plugins,
     )
+    call_async_from_sync(runtime.connect)
     time.sleep(2)
     return runtime
 

+ 46 - 46
tests/runtime/test_bash.py

@@ -29,8 +29,8 @@ def _run_cmd_action(runtime, custom_command: str, keep_prompt=True):
     return obs
 
 
-def test_bash_command_pexcept(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_bash_command_pexcept(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         # We set env var PS1="\u@\h:\w $"
         # and construct the PEXCEPT prompt base on it.
@@ -58,8 +58,8 @@ def test_bash_command_pexcept(temp_dir, box_class, run_as_openhands):
         _close_test_runtime(runtime)
 
 
-def test_bash_timeout_and_keyboard_interrupt(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_bash_timeout_and_keyboard_interrupt(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         action = CmdRunAction(command='python -c "import time; time.sleep(10)"')
         action.timeout = 1
@@ -103,8 +103,8 @@ def test_bash_timeout_and_keyboard_interrupt(temp_dir, box_class, run_as_openhan
         _close_test_runtime(runtime)
 
 
-def test_bash_pexcept_eof(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_bash_pexcept_eof(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         action = CmdRunAction(command='python3 -m http.server 8080')
         action.timeout = 1
@@ -144,8 +144,8 @@ def test_bash_pexcept_eof(temp_dir, box_class, run_as_openhands):
         _close_test_runtime(runtime)
 
 
-def test_process_resistant_to_one_sigint(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_process_resistant_to_one_sigint(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         # Create a bash script that ignores SIGINT up to 1 times
         script_content = """
@@ -197,8 +197,8 @@ done
         _close_test_runtime(runtime)
 
 
-def test_process_resistant_to_multiple_sigint(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_process_resistant_to_multiple_sigint(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         # Create a bash script that ignores SIGINT up to 2 times
         script_content = """
@@ -250,8 +250,8 @@ done
         _close_test_runtime(runtime)
 
 
-def test_multiline_commands(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_multiline_commands(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         # single multiline command
         obs = _run_cmd_action(runtime, 'echo \\\n -e "foo"')
@@ -271,7 +271,7 @@ def test_multiline_commands(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_multiple_multiline_commands(temp_dir, box_class, run_as_openhands):
+def test_multiple_multiline_commands(temp_dir, runtime_cls, run_as_openhands):
     cmds = [
         'ls -l',
         'echo -e "hello\nworld"',
@@ -301,7 +301,7 @@ world "
     ]
     joined_cmds = '\n'.join(cmds)
 
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         obs = _run_cmd_action(runtime, joined_cmds)
         assert obs.exit_code == 0, 'The exit code should be 0.'
@@ -316,9 +316,9 @@ world "
         _close_test_runtime(runtime)
 
 
-def test_no_ps2_in_output(temp_dir, box_class, run_as_openhands):
+def test_no_ps2_in_output(temp_dir, runtime_cls, run_as_openhands):
     """Test that the PS2 sign is not added to the output of a multiline command."""
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         obs = _run_cmd_action(runtime, 'echo -e "hello\nworld"')
         assert obs.exit_code == 0, 'The exit code should be 0.'
@@ -329,7 +329,7 @@ def test_no_ps2_in_output(temp_dir, box_class, run_as_openhands):
         _close_test_runtime(runtime)
 
 
-def test_multiline_command_loop(temp_dir, box_class):
+def test_multiline_command_loop(temp_dir, runtime_cls):
     # https://github.com/All-Hands-AI/OpenHands/issues/3143
     init_cmd = """
 mkdir -p _modules && \
@@ -347,7 +347,7 @@ for file in _modules/*.md; do
 done
 echo "success"
 """
-    runtime = _load_runtime(temp_dir, box_class)
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         obs = _run_cmd_action(runtime, init_cmd)
         assert obs.exit_code == 0, 'The exit code should be 0.'
@@ -360,8 +360,8 @@ echo "success"
         _close_test_runtime(runtime)
 
 
-def test_cmd_run(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_cmd_run(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         obs = _run_cmd_action(runtime, 'ls -l /openhands/workspace')
         assert obs.exit_code == 0
@@ -397,8 +397,8 @@ def test_cmd_run(temp_dir, box_class, run_as_openhands):
         _close_test_runtime(runtime)
 
 
-def test_run_as_user_correct_home_dir(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_run_as_user_correct_home_dir(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         obs = _run_cmd_action(runtime, 'cd ~ && pwd')
         assert obs.exit_code == 0
@@ -410,8 +410,8 @@ def test_run_as_user_correct_home_dir(temp_dir, box_class, run_as_openhands):
         _close_test_runtime(runtime)
 
 
-def test_multi_cmd_run_in_single_line(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_multi_cmd_run_in_single_line(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         obs = _run_cmd_action(runtime, 'pwd && ls -l')
         assert obs.exit_code == 0
@@ -421,8 +421,8 @@ def test_multi_cmd_run_in_single_line(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_stateful_cmd(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_stateful_cmd(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     sandbox_dir = _get_sandbox_folder(runtime)
     try:
         obs = _run_cmd_action(runtime, 'mkdir -p test')
@@ -438,8 +438,8 @@ def test_stateful_cmd(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_failed_cmd(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_failed_cmd(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         obs = _run_cmd_action(runtime, 'non_existing_command')
         assert obs.exit_code != 0, 'The exit code should not be 0 for a failed command.'
@@ -453,8 +453,8 @@ def _create_test_file(host_temp_dir):
         f.write('Hello, World!')
 
 
-def test_copy_single_file(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_copy_single_file(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         sandbox_dir = _get_sandbox_folder(runtime)
         sandbox_file = os.path.join(sandbox_dir, 'test_file.txt')
@@ -483,8 +483,8 @@ def _create_host_test_dir_with_files(test_dir):
         f.write('File 2 content')
 
 
-def test_copy_directory_recursively(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_copy_directory_recursively(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
 
     sandbox_dir = _get_sandbox_folder(runtime)
     try:
@@ -512,8 +512,8 @@ def test_copy_directory_recursively(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_copy_to_non_existent_directory(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_copy_to_non_existent_directory(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         sandbox_dir = _get_sandbox_folder(runtime)
         _create_test_file(temp_dir)
@@ -528,8 +528,8 @@ def test_copy_to_non_existent_directory(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_overwrite_existing_file(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_overwrite_existing_file(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         sandbox_dir = _get_sandbox_folder(runtime)
 
@@ -556,8 +556,8 @@ def test_overwrite_existing_file(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_copy_non_existent_file(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_copy_non_existent_file(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     try:
         sandbox_dir = _get_sandbox_folder(runtime)
         with pytest.raises(FileNotFoundError):
@@ -572,8 +572,8 @@ def test_copy_non_existent_file(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_copy_from_directory(temp_dir, box_class):
-    runtime: Runtime = _load_runtime(temp_dir, box_class)
+def test_copy_from_directory(temp_dir, runtime_cls):
+    runtime: Runtime = _load_runtime(temp_dir, runtime_cls)
     sandbox_dir = _get_sandbox_folder(runtime)
     try:
         temp_dir_copy = os.path.join(temp_dir, 'test_dir')
@@ -592,10 +592,10 @@ def test_copy_from_directory(temp_dir, box_class):
         _close_test_runtime(runtime)
 
 
-def test_keep_prompt(box_class, temp_dir):
+def test_keep_prompt(runtime_cls, temp_dir):
     runtime = _load_runtime(
         temp_dir,
-        box_class=box_class,
+        runtime_cls=runtime_cls,
         run_as_openhands=False,
     )
     try:
@@ -618,13 +618,13 @@ def test_keep_prompt(box_class, temp_dir):
     TEST_IN_CI != 'True',
     reason='This test is not working in WSL (file ownership)',
 )
-def test_git_operation(box_class):
+def test_git_operation(runtime_cls):
     # do not mount workspace, since workspace mount by tests will be owned by root
     # while the user_id we get via os.getuid() is different from root
     # which causes permission issues
     runtime = _load_runtime(
         temp_dir=None,
-        box_class=box_class,
+        runtime_cls=runtime_cls,
         # Need to use non-root user to expose issues
         run_as_openhands=True,
     )
@@ -670,8 +670,8 @@ def test_git_operation(box_class):
         _close_test_runtime(runtime)
 
 
-def test_python_version(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_python_version(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         obs = runtime.run_action(CmdRunAction(command='python --version'))
 

+ 4 - 4
tests/runtime/test_browsing.py

@@ -22,8 +22,8 @@ from openhands.events.observation import (
 PY3_FOR_TESTING = '/openhands/micromamba/bin/micromamba run -n openhands python3'
 
 
-def test_simple_browse(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_simple_browse(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
 
     # Test browse
     action_cmd = CmdRunAction(
@@ -68,10 +68,10 @@ def test_simple_browse(temp_dir, box_class, run_as_openhands):
     _close_test_runtime(runtime)
 
 
-def test_browsergym_eval_env(box_class, temp_dir):
+def test_browsergym_eval_env(runtime_cls, temp_dir):
     runtime = _load_runtime(
         temp_dir,
-        box_class=box_class,
+        runtime_cls=runtime_cls,
         run_as_openhands=False,  # need root permission to access file
         base_container_image='xingyaoww/od-eval-miniwob:v1.0',
         browsergym_eval_env='browsergym/miniwob.choose-list',

+ 6 - 6
tests/runtime/test_edit.py

@@ -31,8 +31,8 @@ if __name__ == '__main__':
     TEST_IN_CI != 'True',
     reason='This test requires LLM to run.',
 )
-def test_edit_from_scratch(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_edit_from_scratch(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         action = FileEditAction(
             content=ORGINAL,
@@ -71,8 +71,8 @@ def index():
     TEST_IN_CI != 'True',
     reason='This test requires LLM to run.',
 )
-def test_edit(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_edit(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         action = FileEditAction(
             content=ORGINAL,
@@ -130,8 +130,8 @@ This is line 101 + 10
     TEST_IN_CI != 'True',
     reason='This test requires LLM to run.',
 )
-def test_edit_long_file(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_edit_long_file(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     try:
         action = FileEditAction(
             content=ORIGINAL_LONG,

+ 7 - 7
tests/runtime/test_env_vars.py

@@ -13,9 +13,9 @@ from openhands.events.observation import CmdOutputObservation
 # ============================================================================================================================
 
 
-def test_env_vars_os_environ(temp_dir, box_class, run_as_openhands):
+def test_env_vars_os_environ(temp_dir, runtime_cls, run_as_openhands):
     with patch.dict(os.environ, {'SANDBOX_ENV_FOOBAR': 'BAZ'}):
-        runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+        runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
 
         obs: CmdOutputObservation = runtime.run_action(CmdRunAction(command='env'))
         print(obs)
@@ -27,13 +27,13 @@ def test_env_vars_os_environ(temp_dir, box_class, run_as_openhands):
         assert obs.exit_code == 0, 'The exit code should be 0.'
         assert (
             obs.content.strip().split('\n\r')[0].strip() == 'BAZ'
-        ), f'Output: [{obs.content}] for {box_class}'
+        ), f'Output: [{obs.content}] for {runtime_cls}'
 
         _close_test_runtime(runtime)
 
 
-def test_env_vars_runtime_operations(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_env_vars_runtime_operations(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
 
     # Test adding single env var
     runtime.add_env_vars({'QUUX': 'abc"def'})
@@ -67,10 +67,10 @@ def test_env_vars_runtime_operations(temp_dir, box_class):
     _close_test_runtime(runtime)
 
 
-def test_env_vars_added_by_config(temp_dir, box_class):
+def test_env_vars_added_by_config(temp_dir, runtime_cls):
     runtime = _load_runtime(
         temp_dir,
-        box_class,
+        runtime_cls,
         runtime_startup_env_vars={'ADDED_ENV_VAR': 'added_value'},
     )
 

+ 6 - 6
tests/runtime/test_images.py

@@ -11,7 +11,7 @@ from openhands.events.action import CmdRunAction
 # ============================================================================================================================
 
 
-def test_bash_python_version(temp_dir, box_class, base_container_image):
+def test_bash_python_version(temp_dir, runtime_cls, base_container_image):
     """Make sure Python is available in bash."""
     if base_container_image not in [
         'python:3.12-bookworm',
@@ -19,7 +19,7 @@ def test_bash_python_version(temp_dir, box_class, base_container_image):
         pytest.skip('This test is only for python-related images')
 
     runtime = _load_runtime(
-        temp_dir, box_class, base_container_image=base_container_image
+        temp_dir, runtime_cls, base_container_image=base_container_image
     )
 
     action = CmdRunAction(command='which python')
@@ -45,7 +45,7 @@ def test_bash_python_version(temp_dir, box_class, base_container_image):
     _close_test_runtime(runtime)
 
 
-def test_nodejs_22_version(temp_dir, box_class, base_container_image):
+def test_nodejs_22_version(temp_dir, runtime_cls, base_container_image):
     """Make sure Node.js is available in bash."""
     if base_container_image not in [
         'node:22-bookworm',
@@ -53,7 +53,7 @@ def test_nodejs_22_version(temp_dir, box_class, base_container_image):
         pytest.skip('This test is only for nodejs-related images')
 
     runtime = _load_runtime(
-        temp_dir, box_class, base_container_image=base_container_image
+        temp_dir, runtime_cls, base_container_image=base_container_image
     )
 
     action = CmdRunAction(command='node --version')
@@ -66,7 +66,7 @@ def test_nodejs_22_version(temp_dir, box_class, base_container_image):
     _close_test_runtime(runtime)
 
 
-def test_go_version(temp_dir, box_class, base_container_image):
+def test_go_version(temp_dir, runtime_cls, base_container_image):
     """Make sure Go is available in bash."""
     if base_container_image not in [
         'golang:1.23-bookworm',
@@ -74,7 +74,7 @@ def test_go_version(temp_dir, box_class, base_container_image):
         pytest.skip('This test is only for go-related images')
 
     runtime = _load_runtime(
-        temp_dir, box_class, base_container_image=base_container_image
+        temp_dir, runtime_cls, base_container_image=base_container_image
     )
 
     action = CmdRunAction(command='go version')

+ 8 - 8
tests/runtime/test_ipython.py

@@ -28,8 +28,8 @@ from openhands.events.observation import (
 # ============================================================================================================================
 
 
-def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
 
     sandbox_dir = _get_sandbox_folder(runtime)
 
@@ -102,8 +102,8 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, box_class, run_as_openhands):
     TEST_IN_CI != 'True',
     reason='This test is not working in WSL (file ownership)',
 )
-def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+def test_ipython_multi_user(temp_dir, runtime_cls, run_as_openhands):
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
 
     # Test run ipython
     # get username
@@ -174,8 +174,8 @@ def test_ipython_multi_user(temp_dir, box_class, run_as_openhands):
     _close_test_runtime(runtime)
 
 
-def test_ipython_simple(temp_dir, box_class):
-    runtime = _load_runtime(temp_dir, box_class)
+def test_ipython_simple(temp_dir, runtime_cls):
+    runtime = _load_runtime(temp_dir, runtime_cls)
     sandbox_dir = _get_sandbox_folder(runtime)
 
     # Test run ipython
@@ -198,9 +198,9 @@ def test_ipython_simple(temp_dir, box_class):
     _close_test_runtime(runtime)
 
 
-def test_ipython_package_install(temp_dir, box_class, run_as_openhands):
+def test_ipython_package_install(temp_dir, runtime_cls, run_as_openhands):
     """Make sure that cd in bash also update the current working directory in ipython."""
-    runtime = _load_runtime(temp_dir, box_class, run_as_openhands)
+    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
     sandbox_dir = _get_sandbox_folder(runtime)
 
     # It should error out since pymsgbox is not installed