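"""Run SWE-bench Lite inference with OpenHands agents.

For each benchmark instance this script builds an agent instruction from the
issue text, starts a sandboxed runtime (optionally one container image per
instance), lets the agent edit the repository, and records the resulting git
patch plus the event history as an EvalOutput line in output.jsonl.
"""
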
import asyncio
import json
import os
import tempfile
from typing import Any

import pandas as pd
import toml
from datasets import load_dataset

import agenthub
from evaluation.swe_bench.prompt import CODEACT_SWE_PROMPT
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    codeact_user_response,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    parse_arguments,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.runtime.runtime import Runtime

USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
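
# Both agent classes run headlessly: whenever the agent would ask a human for input,
# the scripted codeact_user_response is returned instead, and the instruction suffix
# below tells the agent to run `exit` once it believes the issue is fixed.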
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
    'CodeActSWEAgent': codeact_user_response,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n',
    'CodeActSWEAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n',
}
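

# Workspace dir name, e.g. repo 'astropy/astropy' at version '3.0' -> 'astropy__astropy__3.0'.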
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
    return f'{instance.repo}__{instance.version}'.replace('/', '__')


def get_instruction(instance: pd.Series, metadata: EvalMetadata):
    workspace_dir_name = _get_swebench_workspace_dir_name(instance)
    # Prepare instruction
    if metadata.agent_class == 'CodeActSWEAgent':
        instruction = (
            'We are currently solving the following issue within our repository. Here is the issue text:\n'
            '--- BEGIN ISSUE ---\n'
            f'{instance.problem_statement}\n'
            '--- END ISSUE ---\n\n'
        )
        if USE_HINT_TEXT and instance.hints_text:
            instruction += (
                f'--- BEGIN HINTS ---\n{instance.hints_text}\n--- END HINTS ---\n'
            )
        instruction += CODEACT_SWE_PROMPT.format(workspace_dir_name=workspace_dir_name)
    else:
        # Testing general agents
        instruction = (
            f'Please fix the following issue for the repository in /workspace/{workspace_dir_name}.\n'
            'Environment has been set up for you to start working. You may assume all necessary tools are installed.\n\n'
            '# Problem Statement\n'
            f'{instance.problem_statement}\n\n'
        )
        if USE_HINT_TEXT and instance.hints_text:
            instruction += f'# Hints\n{instance.hints_text}\n\n'
        instruction += (
            'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
            'You should NOT modify any existing test case files. If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
            'You SHOULD INCLUDE PROPER INDENTATION in your edit commands.\n'
        )
    # NOTE: You can actually set slightly different instruction for different agents
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
    return instruction
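

# Build the per-instance OpenHands AppConfig: which sandbox image to run in,
# plus the iteration and budget limits for the agent.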
def get_config(
    instance: pd.Series,
    metadata: EvalMetadata,
) -> AppConfig:
    SWE_BENCH_CONTAINER_IMAGE = 'ghcr.io/opendevin/eval-swe-bench:full-v1.2.1'
    if USE_INSTANCE_IMAGE:
        # We use a different instance image for each instance of the swe-bench eval
        container_image = 'sweb.eval.x86_64.' + instance['instance_id']
    else:
        container_image = SWE_BENCH_CONTAINER_IMAGE
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_budget_per_task=4,
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            container_image=container_image,
            enable_auto_lint=True,
            use_host_network=False,
            # large enough timeout, since some testcases take very long to run
            timeout=300,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config
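

# The initialization below runs inside the sandbox before the agent starts: it
# exports SWE_INSTANCE_ID, sources the SWE-bench setup script, cds into the
# target repository, resets it to a clean state, and removes its git remotes.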
async def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info('-' * 30)
    logger.info('BEGIN Runtime Initialization Fn')
    logger.info('-' * 30)
    workspace_dir_name = _get_swebench_workspace_dir_name(instance)
    obs: CmdOutputObservation

    # Set instance id
    action = CmdRunAction(
        command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc"""
    )
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    if USE_INSTANCE_IMAGE:
        # inject the init script
        script_dir = os.path.dirname(__file__)

        # inject the instance info
        action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert (
            obs.exit_code == 0
        ), f'Failed to create /swe_util/eval_data/instances: {obs.content}'

        swe_instance_json_name = 'swe-bench-instance.json'
        with tempfile.TemporaryDirectory() as temp_dir:
            # Construct the full path for the desired file name within the temporary directory
            temp_file_path = os.path.join(temp_dir, swe_instance_json_name)
            # Write to the file with the desired name within the temporary directory
            with open(temp_file_path, 'w') as f:
                if not isinstance(instance, dict):
                    json.dump([instance.to_dict()], f)
                else:
                    json.dump([instance], f)

            # Copy the file to the desired location
            await runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')

        # inject the instance swe entry
        await runtime.copy_to(
            str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
            '/swe_util/',
        )
        action = CmdRunAction(command='cat ~/.bashrc')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert obs.exit_code == 0

        action = CmdRunAction(command='source ~/.bashrc')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert obs.exit_code == 0

        action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert obs.exit_code == 0
    else:
        action = CmdRunAction(command='source /swe_util/swe_entry.sh')
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert (
            obs.exit_code == 0
        ), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'

    action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    action = CmdRunAction(command='git reset --hard')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    action = CmdRunAction(
        command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
    )
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    logger.info('-' * 30)
    logger.info('END Runtime Initialization Fn')
    logger.info('-' * 30)


async def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required, but it is used to get the workspace_dir_name
) -> dict[str, Any]:
    """Complete the runtime for the agent.

    This function is called after the agent has finished running.
    If you need to do something in the sandbox to get the correctness metric after
    the agent has run, modify this function.
    """
    logger.info('-' * 30)
    logger.info('BEGIN Runtime Completion Fn')
    logger.info('-' * 30)
    obs: CmdOutputObservation
    workspace_dir_name = _get_swebench_workspace_dir_name(instance)

    action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    action = CmdRunAction(command='git config --global core.pager ""')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0

    action = CmdRunAction(command='git add -A')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert obs.exit_code == 0
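
    # The agent's edits are already staged (git add -A above); diff them against the
    # base commit to obtain the candidate patch. Large diffs can be slow, so retry a
    # few times with a timeout that grows on each attempt.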
    n_retries = 0
    git_patch = None
    while n_retries < 5:
        action = CmdRunAction(
            command=f'git diff --no-color --cached {instance["base_commit"]}',
            keep_prompt=False,
        )
        action.timeout = 600 + 100 * n_retries
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        n_retries += 1
        if isinstance(obs, CmdOutputObservation):
            if obs.exit_code == 0:
                git_patch = obs.content.strip()
                break
            else:
                logger.info('Failed to get git diff, retrying...')
                await asyncio.sleep(10)
        elif isinstance(obs, ErrorObservation):
            logger.error(f'Error occurred: {obs.content}. Retrying...')
            await asyncio.sleep(10)
        else:
            raise ValueError(f'Unexpected observation type: {type(obs)}')

    logger.info('-' * 30)
    logger.info('END Runtime Completion Fn')
    logger.info('-' * 30)
    return {'git_patch': git_patch}


async def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    config = get_config(instance, metadata)

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.instance_id}.')

    runtime = await create_runtime(config, sid=instance.instance_id)
    await initialize_runtime(runtime, instance)

    instruction = get_instruction(instance, metadata)

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    state: State | None = await run_controller(
        config=config,
        task_str=instruction,
        runtime=runtime,
        fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
    )

    # ======= THIS IS SWE-Bench specific =======
    # Get git patch
    return_val = await complete_runtime(runtime, instance)
    git_patch = return_val['git_patch']
    logger.info(
        f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
    )
    # ==========================================

    # ======= Attempt to evaluate the agent's edits =======
    # we use eval_infer.sh to evaluate the agent's edits, not here
    # because the agent may alter the environment / testcases
    test_result = {
        'git_patch': git_patch,
    }

    # If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
    # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
    # for compatibility with the existing output format, we can remake the pairs here
    # remove when it becomes unnecessary
    histories = state.history.compatibility_for_eval_history_pairs()
    metrics = state.metrics.get() if state.metrics else None

    # Save the output
    output = EvalOutput(
        instance_id=instance.instance_id,
        instruction=instruction,
        instance=instance.to_dict(),  # SWE Bench specific
        test_result=test_result,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
    )
    return output
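

# An optional config.toml next to this script may provide a `selected_ids` list;
# when present, only those instance_ids are evaluated.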
def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
    file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.toml')
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            data = toml.load(file)
            if 'selected_ids' in data:
                selected_ids = data['selected_ids']
                logger.info(
                    f'Filtering {len(selected_ids)} tasks from "selected_ids"...'
                )
                subset = dataset[dataset[filter_column].isin(selected_ids)]
                logger.info(f'Retained {subset.shape[0]} tasks after filtering')
                return subset
    return dataset
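

# Example invocation (flag spellings follow the shared evaluation argument parser
# and are illustrative; adjust the LLM config name to your local setup):
#   python evaluation/swe_bench/run_infer.py \
#       --agent-cls CodeActAgent \
#       --llm-config eval_gpt4 \
#       --max-iterations 30 \
#       --eval-n-limit 10 \
#       --eval-num-workers 4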
if __name__ == '__main__':
    args = parse_arguments()

    # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
    # so we don't need to manage file uploading to OpenHands's repo
    dataset = load_dataset('princeton-nlp/SWE-bench_Lite')
    swe_bench_tests = filter_dataset(dataset['test'].to_pandas(), 'instance_id')

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    details = {}
    _agent_cls = agenthub.Agent.get_cls(args.agent_cls)
    if hasattr(_agent_cls, 'system_message'):
        details['system_message'] = _agent_cls.system_message
    if hasattr(_agent_cls, 'in_context_example'):
        details['in_context_example'] = _agent_cls.in_context_example

    metadata = make_metadata(
        llm_config,
        'swe-bench-lite',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        details=details,
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)

    asyncio.run(
        run_evaluation(
            instances, metadata, output_file, args.eval_num_workers, process_instance
        )
    )