verify_costs.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import argparse
  2. import pandas as pd
  3. from openhands.core.logger import openhands_logger as logger
  4. def verify_instance_costs(row: pd.Series) -> float:
  5. """
  6. Verifies that the accumulated_cost matches the sum of individual costs in metrics.
  7. Also checks for duplicate consecutive costs which might indicate buggy counting.
  8. If the consecutive costs are identical, the file is affected by this bug:
  9. https://github.com/All-Hands-AI/OpenHands/issues/5383
  10. Args:
  11. row: DataFrame row containing instance data with metrics
  12. Returns:
  13. float: The verified total cost for this instance (corrected if needed)
  14. """
  15. try:
  16. metrics = row.get('metrics')
  17. if not metrics:
  18. logger.warning(f"Instance {row['instance_id']}: No metrics found")
  19. return 0.0
  20. accumulated = metrics.get('accumulated_cost')
  21. costs = metrics.get('costs', [])
  22. if accumulated is None:
  23. logger.warning(
  24. f"Instance {row['instance_id']}: No accumulated_cost in metrics"
  25. )
  26. return 0.0
  27. # Check for duplicate consecutive costs and systematic even-odd pairs
  28. has_duplicate = False
  29. all_pairs_match = True
  30. # Check each even-odd pair (0-1, 2-3, etc.)
  31. for i in range(0, len(costs) - 1, 2):
  32. if abs(costs[i]['cost'] - costs[i + 1]['cost']) < 1e-6:
  33. has_duplicate = True
  34. logger.debug(
  35. f"Instance {row['instance_id']}: Possible buggy double-counting detected! "
  36. f"Steps {i} and {i+1} have identical costs: {costs[i]['cost']:.2f}"
  37. )
  38. else:
  39. all_pairs_match = False
  40. break
  41. # Calculate total cost, accounting for buggy double counting if detected
  42. if len(costs) >= 2 and has_duplicate and all_pairs_match:
  43. paired_steps_cost = sum(
  44. cost_entry['cost']
  45. for cost_entry in costs[: -1 if len(costs) % 2 else None]
  46. )
  47. real_paired_cost = paired_steps_cost / 2
  48. unpaired_cost = costs[-1]['cost'] if len(costs) % 2 else 0
  49. total_cost = real_paired_cost + unpaired_cost
  50. else:
  51. total_cost = sum(cost_entry['cost'] for cost_entry in costs)
  52. if not abs(total_cost - accumulated) < 1e-6:
  53. logger.warning(
  54. f"Instance {row['instance_id']}: Cost mismatch: "
  55. f"accumulated: {accumulated:.2f}, sum of costs: {total_cost:.2f}, "
  56. )
  57. return total_cost
  58. except Exception as e:
  59. logger.error(
  60. f"Error verifying costs for instance {row.get('instance_id', 'UNKNOWN')}: {e}"
  61. )
  62. return 0.0
  63. def main():
  64. parser = argparse.ArgumentParser(
  65. description='Verify costs in SWE-bench output file'
  66. )
  67. parser.add_argument(
  68. 'input_filepath', type=str, help='Path to the output.jsonl file'
  69. )
  70. args = parser.parse_args()
  71. try:
  72. # Load and verify the JSONL file
  73. df = pd.read_json(args.input_filepath, lines=True)
  74. logger.info(f'Loaded {len(df)} instances from {args.input_filepath}')
  75. # Verify costs for each instance and sum up total
  76. total_cost = df.apply(verify_instance_costs, axis=1).sum()
  77. logger.info(f'Total verified cost across all instances: ${total_cost:.2f}')
  78. except Exception as e:
  79. logger.error(f'Failed to process file: {e}')
  80. raise
  81. if __name__ == '__main__':
  82. main()