summarize_results.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import json
  2. import os
  3. import sys
  4. import numpy as np
  5. import pandas as pd
  6. # Try to import visualization libraries
  7. visualization_available = False
  8. try:
  9. import matplotlib.pyplot as plt
  10. import seaborn as sns
  11. visualization_available = True
  12. except ImportError:
  13. print(
  14. '\n*** WARNING: libraries matplotlib and/or seaborn are not installed.\n*** Visualization will not be available!\n'
  15. )
  16. def show_usage():
  17. print(
  18. 'Usage: poetry run python summarize_results.py <path_to_output_jsonl_file> <model_name>'
  19. )
  20. print(
  21. 'Example:\npoetry run python summarize_results.py evaluation/evaluation_outputs/outputs/AiderBench/CodeActAgent/claude-3-5-sonnet@20240620_maxiter_30_N_v1.9/output.jsonl claude-3-5-sonnet@20240620\n'
  22. )
  23. def print_error(message: str):
  24. print(f'\n***\n*** ERROR: {message}\n***\n')
  25. show_usage()
  26. def extract_test_results(res_file_path: str) -> tuple[list[str], list[str]]:
  27. passed = []
  28. failed = []
  29. with open(res_file_path, 'r') as file:
  30. for line in file:
  31. data = json.loads(line.strip())
  32. instance_id = data['instance_id']
  33. resolved = False
  34. if 'test_result' in data and 'exit_code' in data['test_result']:
  35. resolved = data['test_result']['exit_code'] == 0
  36. if resolved:
  37. passed.append(instance_id)
  38. else:
  39. failed.append(instance_id)
  40. return passed, failed
  41. def visualize_results(json_file_path: str, model: str, output_dir: str):
  42. # based on a Colab notebook by RajMaheshwari
  43. with open(json_file_path, 'r') as f:
  44. data = [json.loads(line) for line in f]
  45. df = pd.DataFrame.from_records(data)
  46. df1 = pd.DataFrame()
  47. df1['cost'] = df['metrics'].apply(pd.Series)['accumulated_cost']
  48. df1['result'] = (
  49. df['test_result'].apply(pd.Series)['exit_code'].map({0: 'Pass', 1: 'Fail'})
  50. )
  51. df1['actions'] = pd.Series([len(a) - 1 for a in df['history']])
  52. passed = np.sum(df1['result'] == 'Pass')
  53. total = df.shape[0]
  54. resolve_rate = round((passed / total) * 100, 2)
  55. print('Number of passed tests:', f'{passed}/{total}')
  56. if not visualization_available:
  57. return resolve_rate
  58. # Cost histogram
  59. plt.figure(figsize=(10, 6))
  60. bins = 10
  61. mx = pd.Series.max(df1['cost'])
  62. g = sns.histplot(df1, x='cost', bins=bins, hue='result', multiple='stack')
  63. x_ticks = np.around(np.linspace(0, mx, bins + 1), 3)
  64. g.set_xticks(x_ticks)
  65. g.set_xlabel('Cost in $')
  66. g.set_title(f'MODEL: {model}, RESOLVE_RATE: {resolve_rate}%', size=9)
  67. plt.tight_layout()
  68. plt.savefig(os.path.join(output_dir, 'cost_histogram.png'))
  69. plt.close()
  70. # Actions histogram
  71. plt.figure(figsize=(10, 6))
  72. bins = np.arange(0, 31, 2)
  73. g = sns.histplot(df1, x='actions', bins=bins, hue='result', multiple='stack')
  74. g.set_xticks(bins)
  75. g.set_xlabel('# of actions')
  76. g.set_title(f'MODEL: {model}, RESOLVE_RATE: {resolve_rate}%', size=9)
  77. plt.tight_layout()
  78. plt.savefig(os.path.join(output_dir, 'actions_histogram.png'))
  79. plt.close()
  80. return resolve_rate
  81. if __name__ == '__main__':
  82. if len(sys.argv) != 3:
  83. print_error('Argument(s) missing!')
  84. sys.exit(1)
  85. json_file_path = sys.argv[1]
  86. model_name = sys.argv[2]
  87. if not os.path.exists(json_file_path):
  88. print_error('Output file does not exist!')
  89. sys.exit(1)
  90. if not os.path.isfile(json_file_path):
  91. print_error('Path-to-output-file is not a file!')
  92. sys.exit(1)
  93. output_dir = os.path.dirname(json_file_path)
  94. if not os.access(output_dir, os.W_OK):
  95. print_error('Output folder is not writable!')
  96. sys.exit(1)
  97. passed_tests, failed_tests = extract_test_results(json_file_path)
  98. resolve_rate = visualize_results(json_file_path, model_name, output_dir)
  99. print(
  100. f'\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, resolve rate = {resolve_rate:.2f}%'
  101. )
  102. print('PASSED TESTS:')
  103. print(passed_tests)
  104. print('FAILED TESTS:')
  105. print(failed_tests)
  106. print(
  107. '\nVisualization results were saved as cost_histogram.png and actions_histogram.png'
  108. )
  109. print('in folder: ', output_dir)