Просмотр исходного кода

fix(evaluation): SWE-bench evaluation script supports multiprocessing (#4943)

Calvin Smith 1 год назад
Родитель
Сommit
50e7da9c3d
2 измененных файлов с 39 добавлено и 5 удалено
  1. 38 5
      evaluation/swe_bench/eval_infer.py
  2. 1 0
      evaluation/utils/shared.py

+ 38 - 5
evaluation/swe_bench/eval_infer.py

@@ -1,6 +1,7 @@
 import os
 import tempfile
 import time
+from functools import partial
 
 import pandas as pd
 from swebench.harness.grading import get_eval_report
@@ -94,13 +95,28 @@ def get_config(instance: pd.Series) -> AppConfig:
 
 def process_instance(
     instance: pd.Series,
-    metadata: EvalMetadata | None = None,
+    metadata: EvalMetadata,
     reset_logger: bool = True,
+    log_dir: str | None = None,
 ) -> EvalOutput:
+    """
+    Evaluate agent performance on a SWE-bench problem instance.
+
+    Note that this signature differs from the expected input to `run_evaluation`. Use
+    `functools.partial` to provide optional arguments before passing to the evaluation harness.
+
+    Args:
+        log_dir (str | None, default=None): Path to directory where log files will be written. Must
+        be provided if `reset_logger` is set.
+
+    Raises:
+        AssertionError: if the `reset_logger` flag is set without a provided log directory.
+    """
     # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
     if reset_logger:
-        global output_file
-        log_dir = output_file.replace('.jsonl', '.logs')
+        assert (
+            log_dir is not None
+        ), "Can't reset logger without a provided log directory."
         os.makedirs(log_dir, exist_ok=True)
         reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
     else:
@@ -127,6 +143,7 @@ def process_instance(
         return EvalOutput(
             instance_id=instance_id,
             test_result=instance['test_result'],
+            metadata=metadata,
         )
 
     runtime = create_runtime(config)
@@ -176,6 +193,7 @@ def process_instance(
             return EvalOutput(
                 instance_id=instance_id,
                 test_result=instance['test_result'],
+                metadata=metadata,
             )
         elif 'APPLY_PATCH_PASS' in apply_patch_output:
             logger.info(f'[{instance_id}] {APPLY_PATCH_PASS}:\n{apply_patch_output}')
@@ -269,6 +287,7 @@ def process_instance(
             return EvalOutput(
                 instance_id=instance_id,
                 test_result=instance['test_result'],
+                metadata=metadata,
             )
         else:
             logger.info(
@@ -355,12 +374,26 @@ if __name__ == '__main__':
     output_file = args.input_file.replace('.jsonl', '.swebench_eval.jsonl')
     instances = prepare_dataset(predictions, output_file, args.eval_n_limit)
 
+    # If possible, load the relevant metadata to avoid issues with `run_evaluation`.
+    metadata: EvalMetadata | None = None
+    metadata_filepath = os.path.join(os.path.dirname(args.input_file), 'metadata.json')
+    if os.path.exists(metadata_filepath):
+        with open(metadata_filepath, 'r') as metadata_file:
+            data = metadata_file.read()
+            metadata = EvalMetadata.model_validate_json(data)
+
+    # The evaluation harness constrains the signature of `process_instance_func` but we need to
+    # pass extra information. Build a new function object to avoid issues with multiprocessing.
+    process_instance_func = partial(
+        process_instance, log_dir=output_file.replace('.jsonl', '.logs')
+    )
+
     run_evaluation(
         instances,
-        metadata=None,
+        metadata=metadata,
         output_file=output_file,
         num_workers=args.eval_num_workers,
-        process_instance_func=process_instance,
+        process_instance_func=process_instance_func,
     )
 
     # Load evaluated predictions & print number of resolved predictions

+ 1 - 0
evaluation/utils/shared.py

@@ -346,6 +346,7 @@ def run_evaluation(
             f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
         )
     else:
+        logger.warning('Running evaluation without metadata.')
         logger.info(f'Evaluation started with {num_workers} workers.')
 
     total_instances = len(dataset)