Parcourir la source

feat(eval): increase resource factor for remote runtime when previous run failed due to resource (#5709)

Xingyao Wang il y a 1 an
Parent
commit
581d5ec7a8

+ 10 - 0
evaluation/benchmarks/swe_bench/run_infer.py

@@ -370,6 +370,7 @@ def process_instance(
     instance: pd.Series,
     metadata: EvalMetadata,
     reset_logger: bool = True,
+    runtime_failure_count: int = 0,
 ) -> EvalOutput:
     config = get_config(instance, metadata)
 
@@ -380,6 +381,15 @@ def process_instance(
     else:
         logger.info(f'Starting evaluation for instance {instance.instance_id}.')
 
+    # Increase resource_factor with increasing attempt_id
+    if runtime_failure_count > 0:
+        config.sandbox.remote_runtime_resource_factor = min(
+            config.sandbox.remote_runtime_resource_factor * (2**runtime_failure_count),
+            2,  # hardcode maximum resource factor to 2
+        )
+        logger.warning(
+            f'This is the second attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
+        )
     runtime = create_runtime(config)
     call_async_from_sync(runtime.connect)
 

+ 15 - 4
evaluation/utils/shared.py

@@ -8,6 +8,7 @@ import subprocess
 import time
 import traceback
 from contextlib import contextmanager
+from inspect import signature
 from typing import Any, Awaitable, Callable, TextIO
 
 import pandas as pd
@@ -24,7 +25,6 @@ from openhands.core.exceptions import (
     AgentRuntimeNotReadyError,
     AgentRuntimeTimeoutError,
     AgentRuntimeUnavailableError,
-    AgentStuckInLoopError,
 )
 from openhands.core.logger import get_console_handler
 from openhands.core.logger import openhands_logger as logger
@@ -316,13 +316,20 @@ def _process_instance_wrapper(
     timeout_seconds: int | None = None,
 ) -> EvalOutput:
     """Wrap the process_instance_func to handle retries and errors."""
+    runtime_failure_count = 0
     for attempt in range(max_retries + 1):
         try:
+            kwargs = {}
+            # check if process_instance_func accepts timeout_seconds parameter
+            sig = signature(process_instance_func)
+            if 'runtime_failure_count' in sig.parameters:
+                kwargs['runtime_failure_count'] = runtime_failure_count
+
             if timeout_seconds is not None:
                 with timeout(timeout_seconds):
-                    result = process_instance_func(instance, metadata, use_mp)
+                    result = process_instance_func(instance, metadata, use_mp, **kwargs)
             else:
-                result = process_instance_func(instance, metadata, use_mp)
+                result = process_instance_func(instance, metadata, use_mp, **kwargs)
             return result
         except EvalTimeoutException as e:
             error = f'Timeout after {timeout_seconds} seconds'
@@ -368,6 +375,11 @@ def _process_instance_wrapper(
                 + '-' * 10
                 + '\n'
             )
+            if isinstance(
+                e, (AgentRuntimeDisconnectedError, AgentRuntimeUnavailableError)
+            ):
+                runtime_failure_count += 1
+                msg += f'Runtime disconnected error detected for instance {instance.instance_id}, runtime failure count: {runtime_failure_count}'
             logger.error(msg)
             if use_mp:
                 print(msg)  # use print to directly print to console
@@ -527,7 +539,6 @@ def is_fatal_evaluation_error(error: str | None) -> bool:
         AgentRuntimeNotReadyError,
         AgentRuntimeDisconnectedError,
         AgentRuntimeNotFoundError,
-        AgentStuckInLoopError,
     ]
 
     if any(exception.__name__ in error for exception in FATAL_EXCEPTIONS):

+ 3 - 0
openhands/core/config/sandbox_config.py

@@ -32,6 +32,8 @@ class SandboxConfig:
         browsergym_eval_env: The BrowserGym environment to use for evaluation.
             Default is None for general purpose browsing. Check evaluation/miniwob and evaluation/webarena for examples.
         platform: The platform on which the image should be built. Default is None.
+        remote_runtime_resource_factor: Factor to scale the resource allocation for remote runtime.
+            Must be one of [1, 2, 4, 8]. Will only be used if the runtime is remote.
     """
 
     remote_runtime_api_url: str = 'http://localhost:8000'
@@ -56,6 +58,7 @@ class SandboxConfig:
     browsergym_eval_env: str | None = None
     platform: str | None = None
     close_delay: int = 15
+    remote_runtime_resource_factor: int = 1
 
     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

+ 9 - 5
openhands/runtime/impl/remote/remote_runtime.py

@@ -41,6 +41,7 @@ from openhands.runtime.builder.remote import RemoteRuntimeBuilder
 from openhands.runtime.plugins import PluginRequirement
 from openhands.runtime.utils.command import get_remote_startup_command
 from openhands.runtime.utils.request import (
+    RequestHTTPError,
     send_request,
 )
 from openhands.runtime.utils.runtime_build import build_runtime_image
@@ -246,6 +247,7 @@ class RemoteRuntime(Runtime):
             'working_dir': '/openhands/code/',
             'environment': {'DEBUG': 'true'} if self.config.debug else {},
             'session_id': self.sid,
+            'resource_factor': self.config.sandbox.remote_runtime_resource_factor,
         }
 
         # Start the sandbox using the /start endpoint
@@ -451,11 +453,11 @@ class RemoteRuntime(Runtime):
         except requests.Timeout:
             self.log('error', 'No response received within the timeout period.')
             raise
-        except requests.HTTPError as e:
-            if is_runtime_request and e.response.status_code == 404:
+        except RequestHTTPError as e:
+            if is_runtime_request and e.response.status_code in (404, 502):
                 raise AgentRuntimeDisconnectedError(
-                    f'404 error while connecting to {self.runtime_url}'
-                )
+                    f'{e.response.status_code} error while connecting to {self.runtime_url}'
+                ) from e
             elif is_runtime_request and e.response.status_code == 503:
                 if not is_retry:
                     self.log('warning', 'Runtime appears to be paused. Resuming...')
@@ -463,7 +465,9 @@ class RemoteRuntime(Runtime):
                     self._wait_until_alive()
                     return self._send_request(method, url, True, **kwargs)
                 else:
-                    raise e
+                    raise AgentRuntimeUnavailableError(
+                        f'{e.response.status_code} error while connecting to {self.runtime_url}'
+                    ) from e
 
             else:
                 raise e

+ 6 - 3
openhands/runtime/utils/request.py

@@ -1,3 +1,4 @@
+import json
 from typing import Any
 
 import requests
@@ -30,9 +31,11 @@ def send_request(
     except requests.HTTPError as e:
         try:
             _json = response.json()
-        except requests.JSONDecodeError:
-            raise e
+        except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError):
+            _json = None
         raise RequestHTTPError(
-            e, response=e.response, detail=_json.get('detail')
+            e,
+            response=e.response,
+            detail=_json.get('detail') if _json is not None else None,
         ) from e
     return response