run_infer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. import asyncio
  2. import json
  3. import os
  4. import tempfile
  5. from typing import Any
  6. import pandas as pd
  7. import toml
  8. from datasets import load_dataset
  9. import openhands.agenthub
  10. from evaluation.swe_bench.prompt import CODEACT_SWE_PROMPT
  11. from evaluation.utils.shared import (
  12. EvalException,
  13. EvalMetadata,
  14. EvalOutput,
  15. assert_and_raise,
  16. codeact_user_response,
  17. make_metadata,
  18. prepare_dataset,
  19. reset_logger_for_multiprocessing,
  20. run_evaluation,
  21. )
  22. from openhands.controller.state.state import State
  23. from openhands.core.config import (
  24. AppConfig,
  25. SandboxConfig,
  26. get_llm_config_arg,
  27. get_parser,
  28. )
  29. from openhands.core.logger import openhands_logger as logger
  30. from openhands.core.main import create_runtime, run_controller
  31. from openhands.events.action import CmdRunAction, MessageAction
  32. from openhands.events.observation import CmdOutputObservation, ErrorObservation
  33. from openhands.events.serialization.event import event_to_dict
  34. from openhands.runtime.base import Runtime
  35. from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue
  36. USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
  37. USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
  38. AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
  39. 'CodeActAgent': codeact_user_response,
  40. 'CodeActSWEAgent': codeact_user_response,
  41. }
  42. AGENT_CLS_TO_INST_SUFFIX = {
  43. 'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n',
  44. 'CodeActSWEAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n',
  45. }
  46. def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
  47. return f'{instance.repo}__{instance.version}'.replace('/', '__')
  48. def get_instruction(instance: pd.Series, metadata: EvalMetadata):
  49. workspace_dir_name = _get_swebench_workspace_dir_name(instance)
  50. # Prepare instruction
  51. if metadata.agent_class == 'CodeActSWEAgent':
  52. instruction = (
  53. 'We are currently solving the following issue within our repository. Here is the issue text:\n'
  54. '--- BEGIN ISSUE ---\n'
  55. f'{instance.problem_statement}\n'
  56. '--- END ISSUE ---\n\n'
  57. )
  58. if USE_HINT_TEXT and instance.hints_text:
  59. instruction += (
  60. f'--- BEGIN HINTS ---\n{instance.hints_text}\n--- END HINTS ---\n'
  61. )
  62. instruction += CODEACT_SWE_PROMPT.format(workspace_dir_name=workspace_dir_name)
  63. else:
  64. # Testing general agents
  65. instruction = (
  66. f'Please fix the following issue for the repository in /workspace/{workspace_dir_name}.\n'
  67. 'Environment has been set up for you to start working. You may assume all necessary tools are installed.\n\n'
  68. '# Problem Statement\n'
  69. f'{instance.problem_statement}\n\n'
  70. )
  71. if USE_HINT_TEXT and instance.hints_text:
  72. instruction += f'# Hints\n{instance.hints_text}\n\n'
  73. instruction += (
  74. 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  75. 'You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.\n'
  76. 'You should verify that the issue is resolved and any new tests you create pass successfully.\n'
  77. 'You should NEVER use web browsing or any other web-based tools.\n'
  78. 'You should ALWAYS use the default Python interpreter available in the <execute_bash> environment to run code related to the provided issue and/or repository.\n'
  79. )
  80. # NOTE: You can actually set slightly different instruction for different agents
  81. instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
  82. return instruction
  83. # TODO: migrate all swe-bench docker to ghcr.io/openhands
  84. DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/')
  85. logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
  86. def get_instance_docker_image(instance_id: str) -> str:
  87. image_name = 'sweb.eval.x86_64.' + instance_id
  88. image_name = image_name.replace(
  89. '__', '_s_'
  90. ) # to comply with docker image naming convention
  91. return DOCKER_IMAGE_PREFIX.rstrip('/') + '/' + image_name
  92. def get_config(
  93. instance: pd.Series,
  94. metadata: EvalMetadata,
  95. ) -> AppConfig:
  96. SWE_BENCH_CONTAINER_IMAGE = 'ghcr.io/opendevin/eval-swe-bench:full-v1.2.1'
  97. if USE_INSTANCE_IMAGE:
  98. # We use a different instance image for the each instance of swe-bench eval
  99. base_container_image = get_instance_docker_image(instance['instance_id'])
  100. logger.info(
  101. f'Using instance container image: {base_container_image}. '
  102. f'Please make sure this image exists. '
  103. f'Submit an issue on https://github.com/All-Hands-AI/OpenHands if you run into any issues.'
  104. )
  105. else:
  106. base_container_image = SWE_BENCH_CONTAINER_IMAGE
  107. logger.info(f'Using swe-bench container image: {base_container_image}')
  108. config = AppConfig(
  109. default_agent=metadata.agent_class,
  110. run_as_openhands=False,
  111. max_iterations=metadata.max_iterations,
  112. runtime=os.environ.get('RUNTIME', 'eventstream'),
  113. sandbox=SandboxConfig(
  114. base_container_image=base_container_image,
  115. enable_auto_lint=True,
  116. use_host_network=False,
  117. # large enough timeout, since some testcases take very long to run
  118. timeout=300,
  119. # Add platform to the sandbox config to solve issue 4401
  120. platform='linux/amd64',
  121. api_key=os.environ.get('ALLHANDS_API_KEY', None),
  122. remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
  123. keep_remote_runtime_alive=False,
  124. ),
  125. # do not mount workspace
  126. workspace_base=None,
  127. workspace_mount_path=None,
  128. )
  129. config.set_llm_config(metadata.llm_config)
  130. return config
  131. def initialize_runtime(
  132. runtime: Runtime,
  133. instance: pd.Series, # this argument is not required
  134. ):
  135. """Initialize the runtime for the agent.
  136. This function is called before the runtime is used to run the agent.
  137. """
  138. logger.info('-' * 30)
  139. logger.info('BEGIN Runtime Initialization Fn')
  140. logger.info('-' * 30)
  141. workspace_dir_name = _get_swebench_workspace_dir_name(instance)
  142. obs: CmdOutputObservation
  143. # Set instance id
  144. action = CmdRunAction(
  145. command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc"""
  146. )
  147. action.timeout = 600
  148. logger.info(action, extra={'msg_type': 'ACTION'})
  149. obs = runtime.run_action(action)
  150. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  151. assert_and_raise(
  152. obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {str(obs)}'
  153. )
  154. action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
  155. action.timeout = 600
  156. logger.info(action, extra={'msg_type': 'ACTION'})
  157. obs = runtime.run_action(action)
  158. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  159. assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')
  160. if USE_INSTANCE_IMAGE:
  161. # inject the init script
  162. script_dir = os.path.dirname(__file__)
  163. # inject the instance info
  164. action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
  165. action.timeout = 600
  166. logger.info(action, extra={'msg_type': 'ACTION'})
  167. obs = runtime.run_action(action)
  168. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  169. assert_and_raise(
  170. obs.exit_code == 0,
  171. f'Failed to create /swe_util/eval_data/instances: {str(obs)}',
  172. )
  173. swe_instance_json_name = 'swe-bench-instance.json'
  174. with tempfile.TemporaryDirectory() as temp_dir:
  175. # Construct the full path for the desired file name within the temporary directory
  176. temp_file_path = os.path.join(temp_dir, swe_instance_json_name)
  177. # Write to the file with the desired name within the temporary directory
  178. with open(temp_file_path, 'w') as f:
  179. if not isinstance(instance, dict):
  180. json.dump([instance.to_dict()], f)
  181. else:
  182. json.dump([instance], f)
  183. # Copy the file to the desired location
  184. runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')
  185. # inject the instance swe entry
  186. runtime.copy_to(
  187. str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
  188. '/swe_util/',
  189. )
  190. action = CmdRunAction(command='cat ~/.bashrc')
  191. action.timeout = 600
  192. logger.info(action, extra={'msg_type': 'ACTION'})
  193. obs = runtime.run_action(action)
  194. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  195. assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')
  196. action = CmdRunAction(command='source ~/.bashrc')
  197. action.timeout = 600
  198. logger.info(action, extra={'msg_type': 'ACTION'})
  199. obs = runtime.run_action(action)
  200. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  201. if isinstance(obs, ErrorObservation):
  202. logger.error(f'Failed to source ~/.bashrc: {str(obs)}')
  203. assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')
  204. action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
  205. action.timeout = 3600
  206. logger.info(action, extra={'msg_type': 'ACTION'})
  207. obs = runtime.run_action(action)
  208. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  209. assert_and_raise(
  210. obs.exit_code == 0,
  211. f'Failed to source /swe_util/instance_swe_entry.sh: {str(obs)}',
  212. )
  213. else:
  214. action = CmdRunAction(command='source /swe_util/swe_entry.sh')
  215. action.timeout = 1800
  216. logger.info(action, extra={'msg_type': 'ACTION'})
  217. obs = runtime.run_action(action)
  218. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  219. assert_and_raise(
  220. obs.exit_code == 0,
  221. f'Failed to source /swe_util/swe_entry.sh: {str(obs)}',
  222. )
  223. action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
  224. action.timeout = 600
  225. logger.info(action, extra={'msg_type': 'ACTION'})
  226. obs = runtime.run_action(action)
  227. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  228. assert_and_raise(
  229. obs.exit_code == 0,
  230. f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
  231. )
  232. action = CmdRunAction(command='git reset --hard')
  233. action.timeout = 600
  234. logger.info(action, extra={'msg_type': 'ACTION'})
  235. obs = runtime.run_action(action)
  236. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  237. assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {str(obs)}')
  238. action = CmdRunAction(
  239. command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
  240. )
  241. action.timeout = 600
  242. logger.info(action, extra={'msg_type': 'ACTION'})
  243. obs = runtime.run_action(action)
  244. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  245. assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {str(obs)}')
  246. logger.info('-' * 30)
  247. logger.info('END Runtime Initialization Fn')
  248. logger.info('-' * 30)
  249. def complete_runtime(
  250. runtime: Runtime,
  251. instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
  252. ) -> dict[str, Any]:
  253. """Complete the runtime for the agent.
  254. This function is called before the runtime is used to run the agent.
  255. If you need to do something in the sandbox to get the correctness metric after
  256. the agent has run, modify this function.
  257. """
  258. logger.info('-' * 30)
  259. logger.info('BEGIN Runtime Completion Fn')
  260. logger.info('-' * 30)
  261. obs: CmdOutputObservation
  262. workspace_dir_name = _get_swebench_workspace_dir_name(instance)
  263. action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
  264. action.timeout = 600
  265. logger.info(action, extra={'msg_type': 'ACTION'})
  266. obs = runtime.run_action(action)
  267. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  268. assert_and_raise(
  269. obs.exit_code == 0,
  270. f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
  271. )
  272. action = CmdRunAction(command='git config --global core.pager ""')
  273. action.timeout = 600
  274. logger.info(action, extra={'msg_type': 'ACTION'})
  275. obs = runtime.run_action(action)
  276. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  277. assert_and_raise(
  278. obs.exit_code == 0,
  279. f'Failed to git config --global core.pager "": {str(obs)}',
  280. )
  281. action = CmdRunAction(command='git add -A')
  282. action.timeout = 600
  283. logger.info(action, extra={'msg_type': 'ACTION'})
  284. obs = runtime.run_action(action)
  285. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  286. assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {str(obs)}')
  287. n_retries = 0
  288. git_patch = None
  289. while n_retries < 5:
  290. action = CmdRunAction(
  291. command=f'git diff --no-color --cached {instance["base_commit"]}',
  292. keep_prompt=False,
  293. )
  294. action.timeout = 600 + 100 * n_retries
  295. logger.info(action, extra={'msg_type': 'ACTION'})
  296. obs = runtime.run_action(action)
  297. logger.info(obs, extra={'msg_type': 'OBSERVATION'})
  298. n_retries += 1
  299. if isinstance(obs, CmdOutputObservation):
  300. if obs.exit_code == 0:
  301. git_patch = obs.content.strip()
  302. break
  303. else:
  304. logger.info('Failed to get git diff, retrying...')
  305. sleep_if_should_continue(10)
  306. elif isinstance(obs, ErrorObservation):
  307. logger.error(f'Error occurred: {obs.content}. Retrying...')
  308. sleep_if_should_continue(10)
  309. else:
  310. assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
  311. assert_and_raise(git_patch is not None, 'Failed to get git diff (None)')
  312. logger.info('-' * 30)
  313. logger.info('END Runtime Completion Fn')
  314. logger.info('-' * 30)
  315. return {'git_patch': git_patch}
  316. def process_instance(
  317. instance: pd.Series,
  318. metadata: EvalMetadata,
  319. reset_logger: bool = True,
  320. ) -> EvalOutput:
  321. config = get_config(instance, metadata)
  322. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  323. if reset_logger:
  324. log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
  325. reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
  326. else:
  327. logger.info(f'Starting evaluation for instance {instance.instance_id}.')
  328. runtime = create_runtime(config)
  329. try:
  330. initialize_runtime(runtime, instance)
  331. instruction = get_instruction(instance, metadata)
  332. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  333. state: State | None = asyncio.run(
  334. run_controller(
  335. config=config,
  336. initial_user_action=MessageAction(content=instruction),
  337. runtime=runtime,
  338. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
  339. metadata.agent_class
  340. ],
  341. )
  342. )
  343. # if fatal error, throw EvalError to trigger re-run
  344. if (
  345. state.last_error
  346. and 'fatal error during agent execution' in state.last_error
  347. ):
  348. raise EvalException('Fatal error detected: ' + state.last_error)
  349. # ======= THIS IS SWE-Bench specific =======
  350. # Get git patch
  351. return_val = complete_runtime(runtime, instance)
  352. git_patch = return_val['git_patch']
  353. logger.info(
  354. f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
  355. )
  356. finally:
  357. runtime.close()
  358. # ==========================================
  359. # ======= Attempt to evaluate the agent's edits =======
  360. # we use eval_infer.sh to evaluate the agent's edits, not here
  361. # because the agent may alter the environment / testcases
  362. test_result = {
  363. 'git_patch': git_patch,
  364. }
  365. # If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  366. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  367. if state is None:
  368. raise ValueError('State should not be None.')
  369. histories = [event_to_dict(event) for event in state.history.get_events()]
  370. metrics = state.metrics.get() if state.metrics else None
  371. # Save the output
  372. output = EvalOutput(
  373. instance_id=instance.instance_id,
  374. instruction=instruction,
  375. instance=instance.to_dict(), # SWE Bench specific
  376. test_result=test_result,
  377. metadata=metadata,
  378. history=histories,
  379. metrics=metrics,
  380. llm_completions=state.extra_data.get('llm_completions', []),
  381. error=state.last_error if state and state.last_error else None,
  382. )
  383. return output
  384. def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
  385. file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.toml')
  386. if os.path.exists(file_path):
  387. with open(file_path, 'r') as file:
  388. data = toml.load(file)
  389. if 'selected_ids' in data:
  390. selected_ids = data['selected_ids']
  391. logger.info(
  392. f'Filtering {len(selected_ids)} tasks from "selected_ids"...'
  393. )
  394. subset = dataset[dataset[filter_column].isin(selected_ids)]
  395. logger.info(f'Retained {subset.shape[0]} tasks after filtering')
  396. return subset
  397. return dataset
  398. if __name__ == '__main__':
  399. parser = get_parser()
  400. parser.add_argument(
  401. '--dataset',
  402. type=str,
  403. default='princeton-nlp/SWE-bench',
  404. help='data set to evaluate on, either full-test or lite-test',
  405. )
  406. parser.add_argument(
  407. '--split',
  408. type=str,
  409. default='test',
  410. help='split to evaluate on',
  411. )
  412. args, _ = parser.parse_known_args()
  413. # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
  414. # so we don't need to manage file uploading to OpenHands's repo
  415. dataset = load_dataset(args.dataset, split=args.split)
  416. logger.info(f'Loaded dataset {args.dataset} with split {args.split}')
  417. swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
  418. llm_config = None
  419. if args.llm_config:
  420. llm_config = get_llm_config_arg(args.llm_config)
  421. if llm_config is None:
  422. raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
  423. details = {}
  424. _agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)
  425. dataset_descrption = (
  426. args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
  427. )
  428. metadata = make_metadata(
  429. llm_config,
  430. dataset_descrption,
  431. args.agent_cls,
  432. args.max_iterations,
  433. args.eval_note,
  434. args.eval_output_dir,
  435. details=details,
  436. )
  437. output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
  438. instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
  439. if len(instances) > 0 and not isinstance(
  440. instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
  441. ):
  442. for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
  443. instances[col] = instances[col].apply(lambda x: str(x))
  444. run_evaluation(
  445. instances, metadata, output_file, args.eval_num_workers, process_instance
  446. )