Browse Source

[eval] Fix multi-processing bug (again^3) & allow set EXP_NAME for each `run_infer` (#3907)

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
Xingyao Wang 1 year ago
parent
commit
f996b31d64

+ 1 - 1
evaluation/swe_bench/run_infer.py

@@ -218,7 +218,7 @@ def initialize_runtime(
         assert obs.exit_code == 0
 
         action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
-        action.timeout = 1800
+        action.timeout = 3600
         logger.info(action, extra={'msg_type': 'ACTION'})
         obs = runtime.run_action(action)
         logger.info(obs, extra={'msg_type': 'OBSERVATION'})

+ 5 - 0
evaluation/swe_bench/scripts/run_infer.sh

@@ -66,6 +66,11 @@ if [ "$USE_HINT_TEXT" = false ]; then
   EVAL_NOTE="$EVAL_NOTE-no-hint"
 fi
 
+if [ -n "$EXP_NAME" ]; then
+  EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
+fi
+echo "EVAL_NOTE: $EVAL_NOTE"
+
 unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
 
 COMMAND="poetry run python evaluation/swe_bench/run_infer.py \

+ 70 - 109
evaluation/utils/shared.py

@@ -6,7 +6,7 @@ import pathlib
 import subprocess
 import time
 import traceback
-from concurrent.futures import Future, ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from typing import Any, Awaitable, Callable, TextIO
 
 import pandas as pd
@@ -78,12 +78,6 @@ class EvalOutput(BaseModel):
         return json.dumps(dumped_dict)
 
 
-class EvalError(BaseModel):
-    instance_id: str
-    error: str
-    stacktrace: str
-
-
 def codeact_user_response(
     state: State,
     encapsulate_solution: bool = False,
@@ -235,65 +229,58 @@ def prepare_dataset(
 
 
 def update_progress(
-    result_or_future: Future | EvalOutput | EvalError,
-    instance: pd.Series,
+    result: EvalOutput,
     pbar: tqdm,
     output_fp: TextIO,
-    instance_queue: mp.Queue,
 ):
     """Update the progress bar and write the result to the output file."""
-    try:
-        if isinstance(result_or_future, Future):
-            result = result_or_future.result()
-        else:
-            result = result_or_future
-    except Exception as e:
-        # Handle the error
-        # Exception may be raised in the process_instance_func and will
-        # be raised here when we try to access the .result() of the future
-        handle_error(
-            EvalError(
-                instance_id=instance.instance_id,
-                error=str(e),
-                stacktrace=traceback.format_exc(),
-            ),
-            instance,
-            pbar,
-            instance_queue,
-        )
-        return
+    pbar.update(1)
+    pbar.set_description(f'Instance {result.instance_id}')
+    pbar.set_postfix_str(f'Test Result: {result.test_result}')
+    logger.info(
+        f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
+    )
+    output_fp.write(json.dumps(result.model_dump()) + '\n')
+    output_fp.flush()
 
-    # Update the progress bar and write the result to the output file
-    if isinstance(result, EvalOutput):
-        pbar.update(1)
-        pbar.set_description(f'Instance {result.instance_id}')
-        pbar.set_postfix_str(f'Test Result: {result.test_result}')
-        logger.info(
-            f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
-        )
-        output_fp.write(json.dumps(result.model_dump()) + '\n')
-        output_fp.flush()
-    elif isinstance(result, EvalError):
-        handle_error(result, instance, pbar, instance_queue)
-    else:
-        raise ValueError(f'Unexpected result type: {type(result)}')
 
+def _process_instance_wrapper(
+    process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
+    instance: pd.Series,
+    metadata: EvalMetadata,
+    use_mp: bool,
+    max_retries: int = 5,
+) -> EvalOutput:
+    """Wrap the process_instance_func to handle retries and errors.
 
-def handle_error(
-    error: EvalError, instance: pd.Series, pbar: tqdm, instance_queue: mp.Queue
-):
-    """Handle an error that occurred during evaluation."""
-    logger.error(
-        f'Retrying instance [{instance.instance_id}] due to error: {error.error}. Stacktrace:\n{error.stacktrace}'
-        + '\n'
-        + '-' * 10
-        + '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
-        + '-' * 10
-        + '\n'
-    )
-    instance_queue.put(instance)
-    pbar.total += 1
-    pbar.refresh()
+    Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
+    """
+    for attempt in range(max_retries + 1):
+        try:
+            result = process_instance_func(instance, metadata, use_mp)
+            return result
+        except Exception as e:
+            if attempt == max_retries:
+                # Raise an error after all retries & stop the evaluation
+                raise RuntimeError(
+                    f'Maximum error retries reached for instance {instance.instance_id}'
+                ) from e
+            error = str(e)
+            stacktrace = traceback.format_exc()
+            msg = (
+                '-' * 10
+                + '\n'
+                + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
+                + '\n'
+                + '-' * 10
+                + '[This error occurred after maximum retries]'
+                + '-' * 10
+                + '\n'
+            )
+            logger.error(msg)
+            if use_mp:
+                print(msg)  # use print to directly print to console
+            time.sleep(1)  # Add a small delay before retrying
 
 
 def run_evaluation(
@@ -304,6 +291,7 @@ def run_evaluation(
     process_instance_func: Callable[
         [pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
     ],
+    max_retries: int = 5,  # number of retries for each instance
 ):
     use_multiprocessing = num_workers > 1
     logger.info(
@@ -311,10 +299,6 @@ def run_evaluation(
         f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
     )
 
-    instance_queue = mp.Queue()
-    for _, instance in dataset.iterrows():
-        instance_queue.put(instance)
-
     total_instances = len(dataset)
     pbar = tqdm(total=total_instances, desc='Instances processed')
     output_fp = open(output_file, 'a')
@@ -322,53 +306,30 @@ def run_evaluation(
     try:
         if use_multiprocessing:
             with ProcessPoolExecutor(num_workers) as executor:
-                batch_futures = []
-
-                # Loop until there are *no more instances to be processed* and *all (in-progress) futures are done*
-                # since a running future may add new instances to the queue when error occurs
-                while not instance_queue.empty() or batch_futures:
-                    # Submit new tasks if there are instances to be processed and available workers
-                    while (
-                        not instance_queue.empty() and len(batch_futures) < num_workers
-                    ):
-                        try:
-                            instance = instance_queue.get(block=False)
-                            future = executor.submit(
-                                process_instance_func, instance, metadata, True
-                            )
-                            future.instance = (
-                                instance  # Attach the instance to the future
-                            )
-                            batch_futures.append(future)
-                        except mp.queues.Empty:
-                            logger.warning(
-                                'Queue is empty - This should not happen. This is a bug.'
-                            )
-                            break  # Queue is empty, stop submitting new tasks
-
-                    # Continue to wait for the futures to be done & remove completed futures
-                    new_batch_futures = []
-                    for future in batch_futures:
-                        if future.done():
-                            update_progress(
-                                future, future.instance, pbar, output_fp, instance_queue
-                            )
-                        else:
-                            new_batch_futures.append(future)
-                    batch_futures = new_batch_futures
-
-                    # Short sleep to prevent busy-waiting
-                    time.sleep(1)
-
-                assert instance_queue.empty(), 'instance_queue should be empty after all futures are done. This is a bug.'
-                assert (
-                    len(batch_futures) == 0
-                ), 'batch_futures should be empty after all futures are done. This is a bug.'
+                futures = [
+                    executor.submit(
+                        _process_instance_wrapper,
+                        process_instance_func=process_instance_func,
+                        instance=instance,
+                        metadata=metadata,
+                        use_mp=True,
+                        max_retries=max_retries,
+                    )
+                    for _, instance in dataset.iterrows()
+                ]
+                for future in as_completed(futures):
+                    result = future.result()
+                    update_progress(result, pbar, output_fp)
         else:
-            while not instance_queue.empty():
-                instance = instance_queue.get()
-                result = process_instance_func(instance, metadata, False)
-                update_progress(result, instance, pbar, output_fp, instance_queue)
+            for _, instance in dataset.iterrows():
+                result = _process_instance_wrapper(
+                    process_instance_func=process_instance_func,
+                    instance=instance,
+                    metadata=metadata,
+                    use_mp=False,
+                    max_retries=max_retries,
+                )
+                update_progress(result, pbar, output_fp)
 
     except KeyboardInterrupt:
         print('\nKeyboardInterrupt received. Cleaning up...\n')