Browse Source

refactor & improve retry for the reliability of `RemoteRuntime` & evaluation (#3846)

Xingyao Wang 1 năm trước cách đây
mục cha
commit
78c5f58adc

+ 2 - 0
evaluation/swe_bench/run_infer.py

@@ -176,6 +176,7 @@ def initialize_runtime(
 
         # inject the instance info
         action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
+        action.timeout = 600
         logger.info(action, extra={'msg_type': 'ACTION'})
         obs = runtime.run_action(action)
         logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -233,6 +234,7 @@ def initialize_runtime(
         ), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'
 
     action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
+    action.timeout = 600
     logger.info(action, extra={'msg_type': 'ACTION'})
     obs = runtime.run_action(action)
     logger.info(obs, extra={'msg_type': 'OBSERVATION'})

+ 73 - 29
evaluation/utils/shared.py

@@ -5,6 +5,7 @@ import os
 import pathlib
 import subprocess
 import time
+import traceback
 from concurrent.futures import ProcessPoolExecutor
 from typing import Any, Awaitable, Callable
 
@@ -77,6 +78,12 @@ class EvalOutput(BaseModel):
         return json.dumps(dumped_dict)
 
 
+class EvalError(BaseModel):  # Failure record that process_instance returns instead of an EvalOutput
+    instance_id: str  # id of the instance whose processing raised
+    error: str  # str() of the caught exception
+    stacktrace: str  # traceback.format_exc() output captured while handling the exception
+
+
 def codeact_user_response(
     state: State,
     encapsulate_solution: bool = False,
@@ -227,6 +234,20 @@ def prepare_dataset(
     return pd.DataFrame(new_dataset)
 
 
+def process_instance(  # Run process_instance_func on one instance; any Exception becomes an EvalError result
+    instance, metadata, use_multiprocessing, process_instance_func
+) -> EvalOutput | EvalError:
+    try:
+        return process_instance_func(instance, metadata, use_multiprocessing)
+    except Exception as e:  # returned as data (not re-raised) so run_evaluation can re-queue the instance
+        logger.error(f'Error processing instance [{instance.instance_id}]: {e}')
+        return EvalError(
+            instance_id=instance.instance_id,
+            error=str(e),
+            stacktrace=traceback.format_exc(),  # traceback of the exception currently being handled
+        )
+
+
 def run_evaluation(
     dataset: pd.DataFrame,
     metadata: EvalMetadata,
@@ -241,42 +262,65 @@ def run_evaluation(
         f'Evaluation started with Agent {metadata.agent_class}:\n'
         f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
     )
-    pbar = tqdm(total=len(dataset))
-    output_fp = open(output_file, 'a')
 
-    def update_progress(future):
-        pbar.update(1)
-        output: EvalOutput = future.result() if use_multiprocessing else future
+    instance_queue = mp.Queue()
+    for _, instance in dataset.iterrows():
+        instance_queue.put(instance)
 
-        pbar.set_description(f'Instance {output.instance_id}')
-        pbar.set_postfix_str(f'Test Result: {output.test_result}')
-        logger.info(
-            f'Finished evaluation for instance {output.instance_id}: {str(output.test_result)[:300]}...\n'
-        )
-        output_fp.write(json.dumps(output.model_dump()) + '\n')
-        output_fp.flush()
+    total_instances = instance_queue.qsize()
+    pbar = tqdm(total=total_instances, desc='Instances processed')
+    output_fp = open(output_file, 'a')
+
+    def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
+        if isinstance(result, EvalOutput):
+            pbar.update(1)
+            pbar.set_description(f'Instance {result.instance_id}')
+            pbar.set_postfix_str(f'Test Result: {result.test_result}')
+            logger.info(
+                f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
+            )
+            output_fp.write(json.dumps(result.model_dump()) + '\n')
+            output_fp.flush()
+        else:
+            logger.error(
+                f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
+                + '\n'
+                + '-' * 10
+                + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+                + '-' * 10
+                + '\n'
+            )
+            instance_queue.put(instance)
+            pbar.total += 1
+            pbar.refresh()
 
     try:
         if use_multiprocessing:
             with ProcessPoolExecutor(num_workers) as executor:
-                futures = []
-                for _, instance in dataset.iterrows():
-                    future = executor.submit(
-                        process_instance_func,
-                        instance,
-                        metadata,
-                        bool(num_workers > 1),
-                    )
-                    future.add_done_callback(update_progress)
-                    futures.append(future)
-                for future in futures:
-                    future.result()
-        # Use plain for loop for single process for easier debugging
+                while not instance_queue.empty():
+                    futures = []
+                    for _ in range(min(num_workers, instance_queue.qsize())):
+                        instance = instance_queue.get()
+                        future = executor.submit(
+                            process_instance,
+                            instance,
+                            metadata,
+                            True,
+                            process_instance_func,
+                        )
+                        future.add_done_callback(
+                            lambda f, inst=instance: update_progress(f.result(), inst)
+                        )
+                        futures.append(future)
+                    for future in futures:
+                        future.result()
         else:
-            assert num_workers == 1
-            for _, instance in dataset.iterrows():
-                output = process_instance_func(instance, metadata, False)
-                update_progress(output)
+            while not instance_queue.empty():
+                instance = instance_queue.get()
+                result = process_instance(
+                    instance, metadata, False, process_instance_func
+                )
+                update_progress(result, instance)
 
     except KeyboardInterrupt:
         print('\nKeyboardInterrupt received. Cleaning up...\n')

+ 13 - 8
openhands/runtime/builder/remote.py

@@ -7,6 +7,7 @@ import requests
 
 from openhands.core.logger import openhands_logger as logger
 from openhands.runtime.builder import RuntimeBuilder
+from openhands.runtime.utils.request import send_request
 
 
 class RemoteRuntimeBuilder(RuntimeBuilder):
@@ -15,6 +16,8 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
     def __init__(self, api_url: str, api_key: str):
         self.api_url = api_url
         self.api_key = api_key
+        self.session = requests.Session()
+        self.session.headers.update({'X-API-Key': self.api_key})
 
     def build(self, path: str, tags: list[str]) -> str:
         """Builds a Docker image using the Runtime API's /build endpoint."""
@@ -38,8 +41,9 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
             files.append(('tags', (None, tag)))
 
         # Send the POST request to /build
-        headers = {'X-API-Key': self.api_key}
-        response = requests.post(f'{self.api_url}/build', files=files, headers=headers)
+        response = send_request(
+            self.session, 'POST', f'{self.api_url}/build', files=files
+        )
 
         if response.status_code != 202:
             logger.error(f'Build initiation failed: {response.text}')
@@ -57,10 +61,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
                 logger.error('Build timed out after 30 minutes')
                 raise RuntimeError('Build timed out after 30 minutes')
 
-            status_response = requests.get(
+            status_response = send_request(
+                self.session,
+                'GET',
                 f'{self.api_url}/build_status',
                 params={'build_id': build_id},
-                headers=headers,
             )
 
             if status_response.status_code != 200:
@@ -90,14 +95,14 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
                 raise RuntimeError(error_message)
 
             # Wait before polling again
-            time.sleep(5)
+            time.sleep(30)
 
     def image_exists(self, image_name: str) -> bool:
         """Checks if an image exists in the remote registry using the /image_exists endpoint."""
         params = {'image': image_name}
-        session = requests.Session()
-        session.headers.update({'X-API-Key': self.api_key})
-        response = session.get(f'{self.api_url}/image_exists', params=params)
+        response = send_request(
+            self.session, 'GET', f'{self.api_url}/image_exists', params=params
+        )
 
         if response.status_code != 200:
             logger.error(f'Failed to check image existence: {response.text}')

+ 35 - 43
openhands/runtime/remote/runtime.py

@@ -1,13 +1,11 @@
 import os
-import ssl
 import tempfile
 import threading
 import uuid
-from typing import Any, Type
 from zipfile import ZipFile
 
 import requests
-from requests.exceptions import HTTPError, RequestException, Timeout
+from requests.exceptions import Timeout
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -37,15 +35,13 @@ from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.runtime.builder.remote import RemoteRuntimeBuilder
 from openhands.runtime.plugins import PluginRequirement
 from openhands.runtime.runtime import Runtime
+from openhands.runtime.utils.request import (
+    DEFAULT_RETRY_EXCEPTIONS,
+    is_404_error,
+    send_request,
+)
 from openhands.runtime.utils.runtime_build import build_runtime_image
 
-DEFAULT_RETRY_EXCEPTIONS = [
-    ssl.SSLCertVerificationError,
-    RequestException,
-    HTTPError,
-    Timeout,
-]
-
 
 class RemoteRuntime(Runtime):
     """This runtime will connect to a remote od-runtime-client."""
@@ -99,7 +95,7 @@ class RemoteRuntime(Runtime):
         self.container_image: str = self.config.sandbox.base_container_image
         self.container_name = 'od-remote-runtime-' + self.instance_id
         logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
-        response = self._send_request('GET', f'{self.api_url}/registry_prefix')
+        response = send_request(self.session, 'GET', f'{self.api_url}/registry_prefix')
         response_json = response.json()
         registry_prefix = response_json['registry_prefix']
         os.environ['OD_RUNTIME_RUNTIME_IMAGE_REPO'] = (
@@ -122,7 +118,8 @@ class RemoteRuntime(Runtime):
         )
 
         # Use the /image_exists endpoint to check if the image exists
-        response = self._send_request(
+        response = send_request(
+            self.session,
             'GET',
             f'{self.api_url}/image_exists',
             params={'image': self.container_image},
@@ -157,8 +154,8 @@ class RemoteRuntime(Runtime):
         }
 
         # Start the sandbox using the /start endpoint
-        response = self._send_request(
-            'POST', f'{self.api_url}/start', json=start_request
+        response = send_request(
+            self.session, 'POST', f'{self.api_url}/start', json=start_request
         )
         if response.status_code != 201:
             raise RuntimeError(f'Failed to start sandbox: {response.text}')
@@ -184,29 +181,6 @@ class RemoteRuntime(Runtime):
             self.runtime_url is not None
         ), 'Runtime URL is not set. This should never happen.'
 
-    def _send_request(
-        self,
-        method: str,
-        url: str,
-        retry_exceptions: list[Type[Exception]] | None = None,
-        **kwargs: Any,
-    ) -> requests.Response:
-        if retry_exceptions is None:
-            retry_exceptions = DEFAULT_RETRY_EXCEPTIONS
-
-        @retry(
-            stop=stop_after_attempt(30),
-            wait=wait_exponential(multiplier=1, min=4, max=60),
-            retry=retry_if_exception_type(tuple(retry_exceptions)),
-            reraise=True,
-        )
-        def _send_request_with_retry():
-            response = self.session.request(method, url, **kwargs)
-            response.raise_for_status()
-            return response
-
-        return _send_request_with_retry()
-
     @retry(
         stop=stop_after_attempt(10),
         wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -215,7 +189,15 @@ class RemoteRuntime(Runtime):
     )
     def _wait_until_alive(self):
         logger.info('Waiting for sandbox to be alive...')
-        response = self._send_request('GET', f'{self.runtime_url}/alive')
+        response = send_request(
+            self.session,
+            'GET',
+            f'{self.runtime_url}/alive',
+            # Retry 404 errors for the /alive endpoint
+            # because the runtime might just be starting up
+            # and have not registered the endpoint yet
+            retry_fns=[is_404_error],
+        )
         if response.status_code != 200:
             msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
             logger.warning(msg)
@@ -228,8 +210,11 @@ class RemoteRuntime(Runtime):
     def close(self):
         if self.runtime_id:
             try:
-                response = self._send_request(
-                    'POST', f'{self.api_url}/stop', json={'runtime_id': self.runtime_id}
+                response = send_request(
+                    self.session,
+                    'POST',
+                    f'{self.api_url}/stop',
+                    json={'runtime_id': self.runtime_id},
                 )
                 if response.status_code != 200:
                     logger.error(f'Failed to stop sandbox: {response.text}')
@@ -262,7 +247,8 @@ class RemoteRuntime(Runtime):
                 logger.info('Executing action')
                 request_body = {'action': event_to_dict(action)}
                 logger.debug(f'Request body: {request_body}')
-                response = self._send_request(
+                response = send_request(
+                    self.session,
                     'POST',
                     f'{self.runtime_url}/execute_action',
                     json=request_body,
@@ -270,6 +256,10 @@ class RemoteRuntime(Runtime):
                     retry_exceptions=list(
                         filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
                     ),
+                    # Retry 404 errors for the /execute_action endpoint
+                    # because the runtime might just be starting up
+                    # and have not registered the endpoint yet
+                    retry_fns=[is_404_error],
                 )
                 if response.status_code == 200:
                     output = response.json()
@@ -335,7 +325,8 @@ class RemoteRuntime(Runtime):
 
             params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
 
-            response = self._send_request(
+            response = send_request(
+                self.session,
                 'POST',
                 f'{self.runtime_url}/upload_file',
                 files=upload_data,
@@ -368,7 +359,8 @@ class RemoteRuntime(Runtime):
             if path is not None:
                 data['path'] = path
 
-            response = self._send_request(
+            response = send_request(
+                self.session,
                 'POST',
                 f'{self.runtime_url}/list_files',
                 json=data,

+ 62 - 0
openhands/runtime/utils/request.py

@@ -0,0 +1,62 @@
+from typing import Any, Callable, Type
+
+import requests
+from requests.exceptions import ConnectionError, Timeout
+from tenacity import (
+    retry,
+    retry_if_exception,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+
+def is_server_error(exception):  # Retry predicate: True for a requests.HTTPError whose status code is >= 500
+    return (
+        isinstance(exception, requests.HTTPError)
+        and exception.response.status_code >= 500
+    )
+
+
+def is_404_error(exception):  # Retry predicate: True for a requests.HTTPError whose status code is exactly 404
+    return (
+        isinstance(exception, requests.HTTPError)
+        and exception.response.status_code == 404
+    )
+
+
+DEFAULT_RETRY_EXCEPTIONS = [  # exception types send_request retries when the caller supplies none
+    ConnectionError,  # requests.exceptions.ConnectionError (per this module's import), not the builtin
+    Timeout,
+]
+
+
+def send_request(  # Issue session.request(method, url, **kwargs) and retry failures that match the retry condition
+    session: requests.Session,
+    method: str,
+    url: str,
+    retry_exceptions: list[Type[Exception]] | None = None,  # exception types to retry; None or an empty list falls back to DEFAULT_RETRY_EXCEPTIONS
+    retry_fns: list[Callable[[Exception], bool]] | None = None,  # extra predicates; an exception is also retried when any of them returns True
+    n_attempts: int = 15,  # attempt limit handed to stop_after_attempt
+    **kwargs: Any,  # forwarded unchanged to session.request
+) -> requests.Response:
+    exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS
+    retry_condition = retry_if_exception_type(
+        tuple(exceptions_to_catch)
+    ) | retry_if_exception(is_server_error)  # HTTP errors with status >= 500 are always retried, whatever the caller passes
+    if retry_fns is not None:
+        for fn in retry_fns:
+            retry_condition |= retry_if_exception(fn)
+
+    @retry(
+        stop=stop_after_attempt(n_attempts),
+        wait=wait_exponential(multiplier=1, min=4, max=60),  # exponential backoff with each wait clamped between 4 and 60 seconds
+        retry=retry_condition,
+        reraise=True,  # once attempts are exhausted, raise the last underlying exception instead of tenacity's RetryError
+    )
+    def _send_request_with_retry():
+        response = session.request(method, url, **kwargs)
+        response.raise_for_status()  # converts 4xx/5xx responses into HTTPError so the predicates above can inspect them
+        return response
+
+    return _send_request_with_retry()