Explorar o código

[eval] add git patch post-processing for SWE-Bench eval_infer (#3980)

Xingyao Wang hai 1 ano
pai
achega
b13ed017d8

+ 36 - 9
evaluation/swe_bench/eval_infer.py

@@ -3,7 +3,6 @@ import tempfile
 import time
 
 import pandas as pd
-from pydantic import BaseModel
 from swebench.harness.grading import get_eval_report
 from swebench.harness.run_evaluation import (
     APPLY_PATCH_FAIL,
@@ -35,6 +34,36 @@ DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xing
 logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
 
 
+def process_git_patch(patch):
+    if not isinstance(patch, str):
+        return ''
+
+    if not patch.strip():
+        # skip empty patches
+        return ''
+
+    patch = patch.replace('\r\n', '\n')
+    # There might be some weird characters at the beginning of the patch
+    # due to some OpenHands inference command outputs
+
+    # FOR EXAMPLE:
+    # git diff --no-color --cached 895f28f9cbed817c00ab68770433170d83132d90
+    # 0
+    # diff --git a/django/db/models/sql/.backup.query.py b/django/db/models/sql/.backup.query.py
+    # new file mode 100644
+    # index 0000000000..fc13db5948
+
+    # We "find" the first line that starts with "diff" and then we remove lines before it
+    lines = patch.split('\n')
+    for i, line in enumerate(lines):
+        if line.startswith('diff --git'):
+            patch = '\n'.join(lines[i:])
+            break
+
+    patch = patch.rstrip() + '\n'  # Make sure the last line ends with a newline
+    return patch
+
+
 def get_config(instance: pd.Series) -> AppConfig:
     # We use a different instance image for the each instance of swe-bench eval
     base_container_image = get_instance_docker_image(instance['instance_id'])
@@ -60,13 +89,6 @@ def get_config(instance: pd.Series) -> AppConfig:
     return config
 
 
-class SWEBenchEvalResult(BaseModel):
-    instance_id: str
-    apply_patch_output: str
-    test_output: str
-    resolved: bool
-
-
 def process_instance(
     instance: pd.Series,
     metadata: EvalMetadata | None = None,
@@ -94,6 +116,7 @@ def process_instance(
         'resolved': False,
         'failed_apply_patch': False,
         'error_eval': False,
+        'test_timeout': False,
     }
 
     if model_patch == '':
@@ -170,13 +193,14 @@ def process_instance(
 
                 # Poll for completion
                 start_time = time.time()
-                timeout = 900  # 15 minutes
+                timeout = 1800  # 30 minutes
                 while True:
                     seconds_elapsed = time.time() - start_time
                     if seconds_elapsed > timeout:
                         logger.info(
                             f'[{instance_id}] Evaluation timed out after {timeout} seconds'
                         )
+                        instance['test_result']['report']['test_timeout'] = True
                         break
                     check_action = CmdRunAction(
                         command=f'ps -p {pid} > /dev/null; echo $?', keep_prompt=False
@@ -315,6 +339,9 @@ if __name__ == '__main__':
         set(predictions.columns)
     ), 'Input file must contain instance_id and model_patch columns.'
 
+    # Process model_patch
+    predictions['model_patch'] = predictions['model_patch'].apply(process_git_patch)
+
     # Merge predictions with dataset
     predictions['instance'] = predictions['instance_id'].apply(
         lambda x: instance_id_to_instance[x]

+ 2 - 30
evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py

@@ -3,6 +3,8 @@ import os
 
 import pandas as pd
 
+from evaluation.swe_bench.eval_infer import process_git_patch
+
 parser = argparse.ArgumentParser()
 parser.add_argument('oh_output_file', type=str)
 args = parser.parse_args()
@@ -14,36 +16,6 @@ oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)
 model_name = os.path.basename(os.path.dirname(args.oh_output_file))
 
 
-def process_git_patch(patch):
-    if not isinstance(patch, str):
-        return ''
-
-    if not patch.strip():
-        # skip empty patches
-        return ''
-
-    patch = patch.replace('\r\n', '\n')
-    # There might be some weird characters at the beginning of the patch
-    # due to some OpenHands inference command outputs
-
-    # FOR EXAMPLE:
-    # git diff --no-color --cached 895f28f9cbed817c00ab68770433170d83132d90
-    # 0
-    # diff --git a/django/db/models/sql/.backup.query.py b/django/db/models/sql/.backup.query.py
-    # new file mode 100644
-    # index 0000000000..fc13db5948
-
-    # We "find" the first line that starts with "diff" and then we remove lines before it
-    lines = patch.split('\n')
-    for i, line in enumerate(lines):
-        if line.startswith('diff --git'):
-            patch = '\n'.join(lines[i:])
-            break
-
-    patch = patch.rstrip() + '\n'  # Make sure the last line ends with a newline
-    return patch
-
-
 def convert_row_to_swebench_format(row):
     if 'git_patch' in row:
         model_patch = row['git_patch']