瀏覽代碼

fix(eval): support setting hard timeout per evaluation instance (#5110)

Xingyao Wang 1 年之前
父節點
當前提交
a531413d86
共有 2 個文件被更改,包括 56 次插入、7 次刪除
  1. 1 1
      evaluation/swe_bench/run_infer.py
  2. 55 6
      evaluation/utils/shared.py

+ 1 - 1
evaluation/swe_bench/run_infer.py

@@ -145,7 +145,7 @@ def get_config(
             platform='linux/amd64',
             api_key=os.environ.get('ALLHANDS_API_KEY', None),
             remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
-            keep_remote_runtime_alive=False,
+            keep_runtime_alive=False,
             remote_runtime_init_timeout=3600,
         ),
         # do not mount workspace

+ 55 - 6
evaluation/utils/shared.py

@@ -3,9 +3,11 @@ import logging
 import multiprocessing as mp
 import os
 import pathlib
+import signal
 import subprocess
 import time
 import traceback
+from contextlib import contextmanager
 from typing import Any, Awaitable, Callable, TextIO
 
 import pandas as pd
@@ -92,6 +94,27 @@ class EvalException(Exception):
     pass
 
 
+class EvalTimeoutException(Exception):
+    pass
+
+
+@contextmanager
+def timeout(seconds: int):
+    def timeout_handler(signum, frame):
+        raise EvalTimeoutException(f'Function timed out after {seconds} seconds')
+
+    # Set up the signal handler
+    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
+    signal.alarm(seconds)
+
+    try:
+        yield
+    finally:
+        # Restore the original handler and disable the alarm
+        signal.alarm(0)
+        signal.signal(signal.SIGALRM, original_handler)
+
+
 def codeact_user_response(
     state: State,
     encapsulate_solution: bool = False,
@@ -280,15 +303,33 @@ def _process_instance_wrapper(
     metadata: EvalMetadata,
     use_mp: bool,
     max_retries: int = 5,
+    timeout_seconds: int | None = None,
 ) -> EvalOutput:
-    """Wrap the process_instance_func to handle retries and errors.
-
-    Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
-    """
+    """Wrap the process_instance_func to handle retries and errors."""
     for attempt in range(max_retries + 1):
         try:
-            result = process_instance_func(instance, metadata, use_mp)
+            if timeout_seconds is not None:
+                with timeout(timeout_seconds):
+                    result = process_instance_func(instance, metadata, use_mp)
+            else:
+                result = process_instance_func(instance, metadata, use_mp)
             return result
+        except EvalTimeoutException as e:
+            error = f'Timeout after {timeout_seconds} seconds'
+            stacktrace = traceback.format_exc()
+            msg = (
+                '-' * 10
+                + '\n'
+                + f'Timeout ({timeout_seconds} seconds) in instance [{instance.instance_id}], Stopped evaluation for this instance.'
+                + '\n'
+                + '-' * 10
+            )
+            logger.exception(e)
+            return EvalOutput(
+                instance_id=instance.instance_id,
+                test_result={},
+                error=error,
+            )
         except Exception as e:
             error = str(e)
             stacktrace = traceback.format_exc()
@@ -337,6 +378,7 @@ def run_evaluation(
         [pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
     ],
     max_retries: int = 5,  # number of retries for each instance
+    timeout_seconds: int | None = None,
 ):
     use_multiprocessing = num_workers > 1
 
@@ -357,7 +399,14 @@ def run_evaluation(
         if use_multiprocessing:
             with mp.Pool(num_workers) as pool:
                 args_iter = (
-                    (process_instance_func, instance, metadata, True, max_retries)
+                    (
+                        process_instance_func,
+                        instance,
+                        metadata,
+                        True,
+                        max_retries,
+                        timeout_seconds,
+                    )
                     for _, instance in dataset.iterrows()
                 )
                 results = pool.imap_unordered(_process_instance_wrapper_mp, args_iter)