# download_gold_patch.py
  1. import argparse
  2. import pandas as pd
  3. from datasets import load_dataset
  4. parser = argparse.ArgumentParser()
  5. parser.add_argument('output_filepath', type=str, help='Path to save the output file')
  6. parser.add_argument(
  7. '--dataset_name',
  8. type=str,
  9. help='Name of the dataset to download',
  10. default='princeton-nlp/SWE-bench_Lite',
  11. )
  12. parser.add_argument('--split', type=str, help='Split to download', default='test')
  13. args = parser.parse_args()
  14. dataset = load_dataset(args.dataset_name, split=args.split)
  15. output_filepath = args.output_filepath
  16. print(
  17. f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
  18. )
  19. patches = [
  20. {'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset
  21. ]
  22. print(f'{len(patches)} gold patches loaded')
  23. pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')
  24. print(f'Patches saved to {output_filepath}')