eval_infer.py

import os
import tempfile
import time
from functools import partial

import pandas as pd
from swebench.harness.grading import get_eval_report
from swebench.harness.run_evaluation import (
    APPLY_PATCH_FAIL,
    APPLY_PATCH_PASS,
)
from swebench.harness.test_spec import SWEbenchInstance, TestSpec, make_test_spec
from swebench.harness.utils import load_swebench_dataset

from evaluation.benchmarks.swe_bench.run_infer import get_instance_docker_image
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation
from openhands.utils.async_utils import call_async_from_sync

# TODO: migrate all swe-bench docker to ghcr.io/openhands
DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')


def process_git_patch(patch):
    if not isinstance(patch, str):
        return ''

    if not patch.strip():
        # skip empty patches
        return ''

    patch = patch.replace('\r\n', '\n')
    # There might be some weird characters at the beginning of the patch
    # due to some OpenHands inference command outputs
    # FOR EXAMPLE:
    # git diff --no-color --cached 895f28f9cbed817c00ab68770433170d83132d90
    # 0
    # diff --git a/django/db/models/sql/.backup.query.py b/django/db/models/sql/.backup.query.py
    # new file mode 100644
    # index 0000000000..fc13db5948
    # We "find" the first line that starts with "diff" and then we remove lines before it
    lines = patch.split('\n')
    for i, line in enumerate(lines):
        if line.startswith('diff --git'):
            patch = '\n'.join(lines[i:])
            break

    patch = patch.rstrip() + '\n'  # Make sure the last line ends with a newline
    return patch
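
# A minimal illustration of what `process_git_patch` does (the input string below is
# made up for this example, not taken from a real inference run):
#
#   >>> process_git_patch(
#   ...     'git diff --no-color --cached abc123\n'
#   ...     '0\n'
#   ...     'diff --git a/foo.py b/foo.py\n'
#   ...     '+print("hi")\n'
#   ... )
#   'diff --git a/foo.py b/foo.py\n+print("hi")\n'
#
# Everything before the first 'diff --git' line is dropped, CRLF line endings are
# normalized, and the output is guaranteed to end with exactly one newline.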


def get_config(instance: pd.Series) -> AppConfig:
    # We use a different instance image for each instance of the SWE-bench eval
    base_container_image = get_instance_docker_image(instance['instance_id'])
    logger.info(
        f'Using instance container image: {base_container_image}. '
        f'Please make sure this image exists. '
        f'Submit an issue on https://github.com/All-Hands-AI/OpenHands if you run into any issues.'
    )
    config = AppConfig(
        run_as_openhands=False,
        runtime=os.environ.get('RUNTIME', 'eventstream'),
        sandbox=SandboxConfig(
            base_container_image=base_container_image,
            use_host_network=False,
            # large enough timeout, since some testcases take very long to run
            timeout=1800,
            api_key=os.environ.get('ALLHANDS_API_KEY', None),
            remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
            remote_runtime_init_timeout=3600,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    return config
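
# Environment variables consulted by this module:
#
#   EVAL_DOCKER_IMAGE_PREFIX         # prefix for per-instance images (default: docker.io/xingyaoww/)
#   RUNTIME                          # runtime backend (default: eventstream)
#   ALLHANDS_API_KEY                 # sandbox API key, if the runtime requires one
#   SANDBOX_REMOTE_RUNTIME_API_URL   # remote runtime API endpoint, if used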


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
    log_dir: str | None = None,
) -> EvalOutput:
    """
    Evaluate agent performance on a SWE-bench problem instance.

    Note that this signature differs from the expected input to `run_evaluation`. Use
    `functools.partial` to provide optional arguments before passing to the evaluation harness.

    Args:
        log_dir (str | None, default=None): Path to directory where log files will be written. Must
            be provided if `reset_logger` is set.

    Raises:
        AssertionError: if the `reset_logger` flag is set without a provided log directory.
    """
    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        assert (
            log_dir is not None
        ), "Can't reset logger without a provided log directory."
        os.makedirs(log_dir, exist_ok=True)
        reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.instance_id}.')

    config = get_config(instance)
    instance_id = instance.instance_id
    model_patch = instance['model_patch']
    test_spec: TestSpec = instance['test_spec']
    logger.info(f'Starting evaluation for instance {instance_id}.')

    if 'test_result' not in instance.keys():
        instance['test_result'] = {}
    instance['test_result']['report'] = {
        'empty_generation': False,
        'resolved': False,
        'failed_apply_patch': False,
        'error_eval': False,
        'test_timeout': False,
    }

    if model_patch == '':
        instance['test_result']['report']['empty_generation'] = True
        return EvalOutput(
            instance_id=instance_id,
            test_result=instance['test_result'],
            metadata=metadata,
        )

    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    # Get patch and save it to /tmp/patch.diff
    with tempfile.TemporaryDirectory() as temp_dir:
        # Patch file
        patch_file_path = os.path.join(temp_dir, 'patch.diff')
        with open(patch_file_path, 'w') as f:
            f.write(model_patch)
        runtime.copy_to(patch_file_path, '/tmp')
        # Eval script
        eval_script_path = os.path.join(temp_dir, 'eval.sh')
        with open(eval_script_path, 'w') as f:
            f.write(test_spec.eval_script)
        runtime.copy_to(eval_script_path, '/tmp')

    # Set +x
    action = CmdRunAction(command='chmod +x /tmp/eval.sh')
    action.timeout = 600
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    # Apply patch
    exec_command = (
        'cd /testbed && '
        "(git apply -v /tmp/patch.diff && echo 'APPLY_PATCH_PASS' || "
        "(echo 'Failed to apply patch with git apply, trying with patch command...' && "
        "(patch --batch --fuzz=5 -p1 -i /tmp/patch.diff && echo 'APPLY_PATCH_PASS' || "
        "echo 'APPLY_PATCH_FAIL')))"
    )
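    # The shell pipeline above follows the same strategy as the SWE-bench harness:
    # try `git apply` first, fall back to GNU `patch` with a generous fuzz factor,
    # and echo a sentinel (APPLY_PATCH_PASS / APPLY_PATCH_FAIL) so the result can
    # be detected from the command output below.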
    action = CmdRunAction(command=exec_command, keep_prompt=False)
    action.timeout = 600
    obs = runtime.run_action(action)
    assert isinstance(obs, CmdOutputObservation)
    apply_patch_output = obs.content
    assert isinstance(apply_patch_output, str)
    instance['test_result']['apply_patch_output'] = apply_patch_output

    try:
        if 'APPLY_PATCH_FAIL' in apply_patch_output:
            logger.info(f'[{instance_id}] {APPLY_PATCH_FAIL}:\n{apply_patch_output}')
            instance['test_result']['report']['failed_apply_patch'] = True
            return EvalOutput(
                instance_id=instance_id,
                test_result=instance['test_result'],
                metadata=metadata,
            )
        elif 'APPLY_PATCH_PASS' in apply_patch_output:
            logger.info(f'[{instance_id}] {APPLY_PATCH_PASS}:\n{apply_patch_output}')

            # Run eval script in background and save output to log file
            log_file = '/tmp/eval_output.log'
            action = CmdRunAction(
                command=f'/tmp/eval.sh > {log_file} 2>&1 & echo $!', keep_prompt=False
            )
            action.timeout = 60  # Short timeout just to get the process ID
            obs = runtime.run_action(action)

            if isinstance(obs, CmdOutputObservation) and obs.exit_code == 0:
                pid = obs.content.split()[-1].strip()
                logger.info(
                    f'[{instance_id}] Evaluation process started with PID: {pid}'
                )

                # Poll for completion
                start_time = time.time()
                timeout = 1800  # 30 minutes
                while True:
                    seconds_elapsed = time.time() - start_time
                    if seconds_elapsed > timeout:
                        logger.info(
                            f'[{instance_id}] Evaluation timed out after {timeout} seconds'
                        )
                        instance['test_result']['report']['test_timeout'] = True
                        break
                    check_action = CmdRunAction(
                        command=f'ps -p {pid} > /dev/null; echo $?', keep_prompt=False
                    )
                    check_action.timeout = 60
                    check_obs = runtime.run_action(check_action)
                    if (
                        isinstance(check_obs, CmdOutputObservation)
                        and check_obs.content.split()[-1].strip() == '1'
                    ):
                        logger.info(
                            f'[{instance_id}] Evaluation process completed after {seconds_elapsed} seconds'
                        )
                        break
                    logger.info(
                        f'[{instance_id}] [{seconds_elapsed:.0f}s] Evaluation still running, waiting...'
                    )
                    time.sleep(30)  # Wait for 30 seconds before checking again

                # Read the log file
                cat_action = CmdRunAction(command=f'cat {log_file}', keep_prompt=False)
                cat_action.timeout = 300
                cat_obs = runtime.run_action(cat_action)

                # Grade answer
                if isinstance(cat_obs, CmdOutputObservation) and cat_obs.exit_code == 0:
                    test_output = cat_obs.content
                    assert isinstance(test_output, str)
                    instance['test_result']['test_output'] = test_output

                    # Get report from test output
                    logger.info(f'[{instance_id}] Grading answer...')
                    with tempfile.TemporaryDirectory() as temp_dir:
                        # Create a directory structure that matches the expected format
                        # NOTE: this is a hack to make the eval report format consistent
                        # with the original SWE-Bench eval script
                        log_dir = os.path.join(temp_dir, 'logs', instance_id.lower())
                        os.makedirs(log_dir, exist_ok=True)
                        test_output_path = os.path.join(log_dir, 'test_output.txt')
                        with open(test_output_path, 'w') as f:
                            f.write(test_output)

                        try:
                            _report = get_eval_report(
                                test_spec=test_spec,
                                prediction={
                                    'model_patch': model_patch,
                                    'instance_id': instance_id,
                                },
                                log_path=test_output_path,
                                include_tests_status=True,
                            )
                            report = _report[instance_id]
                            logger.info(
                                f"[{instance_id}] report: {report}\nResult for {instance_id}: resolved: {report['resolved']}"
                            )
                            instance['test_result']['report']['resolved'] = report[
                                'resolved'
                            ]
                        except Exception as e:
                            logger.error(
                                f'[{instance_id}] Error when getting eval report: {e}'
                            )
                            instance['test_result']['report']['resolved'] = False
                            instance['test_result']['report']['error_eval'] = True
            else:
                logger.info(f'[{instance_id}] Error when starting eval:\n{obs.content}')
                instance['test_result']['report']['error_eval'] = True

            return EvalOutput(
                instance_id=instance_id,
                test_result=instance['test_result'],
                metadata=metadata,
            )
        else:
            logger.info(
                f'[{instance_id}] Unexpected output when applying patch:\n{apply_patch_output}'
            )
            raise RuntimeError(
                instance_id,
                f'Unexpected output when applying patch:\n{apply_patch_output}',
                logger,
            )
    finally:
        runtime.close()
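
# Note: `process_instance` always closes the runtime in its `finally` block, and the
# summary loop at the bottom of this file tallies the resolved / failed_apply_patch /
# error_eval / empty_generation fields it records under test_result['report'].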


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--input-file',
        type=str,
        help='Path to input predictions file',
        required=True,
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default='princeton-nlp/SWE-bench',
        help='dataset to evaluate on, e.g. the full SWE-bench test set or SWE-bench Lite',
    )
    parser.add_argument(
        '--split',
        type=str,
        default='test',
        help='split to evaluate on',
    )
    args, _ = parser.parse_known_args()
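
    # Example invocation (paths and worker count are illustrative; the
    # --eval-n-limit / --eval-num-workers flags come from `get_parser()` and are
    # assumed here from the attribute names used below):
    #
    #   python evaluation/benchmarks/swe_bench/eval_infer.py \
    #       --input-file evaluation/outputs/.../output.jsonl \
    #       --dataset princeton-nlp/SWE-bench_Lite \
    #       --split test \
    #       --eval-num-workers 4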

    # Load SWE-Bench dataset
    full_dataset: list[SWEbenchInstance] = load_swebench_dataset(
        args.dataset, args.split
    )
    instance_id_to_instance = {
        instance['instance_id']: instance for instance in full_dataset
    }
    logger.info(
        f'Loaded dataset {args.dataset} with split {args.split} to run evaluation on.'
    )

    # Load predictions
    assert args.input_file.endswith('.jsonl'), 'Input file must be a jsonl file.'
    predictions = pd.read_json(args.input_file, lines=True)
    assert (
        'instance_id' in predictions.columns
    ), 'Input file must contain instance_id column.'

    # The patch may live either in a top-level `model_patch` column or inside the
    # `test_result` dict (as `git_patch`, the format written by the OpenHands
    # inference script). Reject inputs that provide neither.
    if 'model_patch' not in predictions.columns and not (
        'test_result' in predictions.columns
        and 'git_patch' in predictions['test_result'].iloc[0]
    ):
        raise ValueError(
            'Input file must contain model_patch column OR test_result column with git_patch field.'
        )
    assert len(predictions['instance_id'].unique()) == len(
        predictions
    ), 'instance_id column must be unique.'

    if 'model_patch' not in predictions.columns:
        predictions['model_patch'] = predictions['test_result'].apply(
            lambda x: x.get('git_patch', '')
        )
    assert {'instance_id', 'model_patch'}.issubset(
        set(predictions.columns)
    ), 'Input file must contain instance_id and model_patch columns.'
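
    # Expected shape of each line in the input .jsonl (values are illustrative):
    #   {"instance_id": "django__django-11099", "model_patch": "diff --git a/..."}
    # or, for raw OpenHands inference output:
    #   {"instance_id": "django__django-11099", "test_result": {"git_patch": "diff --git a/..."}}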

    # Process model_patch
    predictions['model_patch'] = predictions['model_patch'].apply(process_git_patch)

    # Merge predictions with dataset
    predictions['instance'] = predictions['instance_id'].apply(
        lambda x: instance_id_to_instance[x]
    )
    predictions['test_spec'] = predictions['instance'].apply(make_test_spec)

    # Prepare dataset
    output_file = args.input_file.replace('.jsonl', '.swebench_eval.jsonl')
    instances = prepare_dataset(predictions, output_file, args.eval_n_limit)

    # If possible, load the relevant metadata to avoid issues with `run_evaluation`.
    metadata: EvalMetadata | None = None
    metadata_filepath = os.path.join(os.path.dirname(args.input_file), 'metadata.json')
    if os.path.exists(metadata_filepath):
        with open(metadata_filepath, 'r') as metadata_file:
            data = metadata_file.read()
            metadata = EvalMetadata.model_validate_json(data)

    # The evaluation harness constrains the signature of `process_instance_func` but we need to
    # pass extra information. Build a new function object to avoid issues with multiprocessing.
    process_instance_func = partial(
        process_instance, log_dir=output_file.replace('.jsonl', '.logs')
    )

    run_evaluation(
        instances,
        metadata=metadata,
        output_file=output_file,
        num_workers=args.eval_num_workers,
        process_instance_func=process_instance_func,
    )

    # Load evaluated predictions & print number of resolved predictions
    evaluated_predictions = pd.read_json(output_file, lines=True)
    fields = ['resolved', 'failed_apply_patch', 'error_eval', 'empty_generation']

    def count_report_field(row, field):
        return row['test_result']['report'][field]

    report = {}
    for field in fields:
        count = evaluated_predictions.apply(
            count_report_field, args=(field,), axis=1
        ).sum()
        report[field] = count
        logger.info(
            f'# {field}: {count} / {len(evaluated_predictions)}. ({count / len(evaluated_predictions):.2%})'
        )