# run_infer.py
  1. """Implements evaluation of agents on ML-Bench, a benchmark for assessing the effectiveness of
  2. Large Language Models (LLMs) in leveraging existing functions in open-source libraries for
  3. machine learning tasks. The benchmark is introduced in the paper "ML-Bench: Evaluating Large
  4. Language Models for Code Generation in Repository-Level Machine Learning Tasks"
  5. (https://arxiv.org/abs/2311.09835).
  6. Please see https://ghcr.io/super-dainiu/ml_bench and https://huggingface.co/datasets/super-dainiu/ml-bench
  7. for more details on the dataset and docker image used in this evaluation script.
  8. TODOs:
  9. - Support additional evaluation settings, such as providing raw README content or using a
  10. retriever to extract relevant segments.
  11. - Clean up the code and docker image used for evaluation.
  12. """
  13. import os
  14. from typing import Any
  15. import pandas as pd
  16. from datasets import load_dataset
  17. from evaluation.utils.shared import (
  18. EvalMetadata,
  19. EvalOutput,
  20. codeact_user_response,
  21. make_metadata,
  22. prepare_dataset,
  23. reset_logger_for_multiprocessing,
  24. run_evaluation,
  25. )
  26. from openhands.controller.state.state import State
  27. from openhands.core.config import (
  28. AppConfig,
  29. SandboxConfig,
  30. get_llm_config_arg,
  31. get_parser,
  32. load_app_config,
  33. )
  34. from openhands.core.logger import openhands_logger as logger
  35. from openhands.core.main import create_runtime, run_controller
  36. from openhands.events.action import CmdRunAction
  37. from openhands.events.observation import CmdOutputObservation
  38. from openhands.runtime.runtime import Runtime
# Load the application-level OpenHands config once at import time.
config = load_app_config()

# Maps an agent class name to the function that fabricates a user reply when
# the agent asks for input, keeping evaluation runs fully automated.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

# Per-agent suffix appended to the task instruction; tells the agent how to
# signal that it has finished the task.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the task, please run the following command: <execute_bash> exit </execute_bash>.\n'
}

# Maps each ML-Bench task's `github_id` to the name of the pre-built conda
# environment inside the evaluation docker image (see module docstring).
ID2CONDA = {
    1: 'dgl_DS',
    2: 'bert_DS',
    3: 'lavis_DS',
    4: 'if_DS',
    5: 'V2V_DS',
    6: 'esm_DS',
    7: 'OP_DS',
    8: 'TSL_DS',
    9: 'EAP_DS',
    10: 'PG_DS',
    11: 'PIM_DS',
    12: 'AD2_DS',
    13: 'L3_DS',
    14: 'MZ2_DS',
    15: 'GSA2_DS',
}
  63. def get_config(
  64. metadata: EvalMetadata,
  65. ) -> AppConfig:
  66. config = AppConfig(
  67. default_agent=metadata.agent_class,
  68. run_as_openhands=False,
  69. runtime='eventstream',
  70. max_iterations=metadata.max_iterations,
  71. sandbox=SandboxConfig(
  72. container_image='public.ecr.aws/i5g0m1f6/ml-bench',
  73. enable_auto_lint=True,
  74. use_host_network=False,
  75. ),
  76. # do not mount workspace
  77. workspace_base=None,
  78. workspace_mount_path=None,
  79. )
  80. config.set_llm_config(metadata.llm_config)
  81. return config
  82. async def initialize_runtime(
  83. runtime: Runtime,
  84. instance: pd.Series, # this argument is not required
  85. ):
  86. """Initialize the runtime for the agent.
  87. This function is called before the runtime is used to run the agent.
  88. """
  89. logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
  90. obs: CmdOutputObservation
  91. # Set instance id
  92. action = CmdRunAction(command='mkdir -p /workspace')
  93. logger.info(action, extra={'msg_type': 'ACTION'})
  94. obs = await runtime.run_action(action)
  95. assert obs.exit_code == 0
  96. # Set up the task environment
  97. action = CmdRunAction(command=f'conda activate {ID2CONDA[instance["github_id"]]}')
  98. logger.info(action, extra={'msg_type': 'ACTION'})
  99. obs = await runtime.run_action(action)
  100. assert obs.exit_code == 0
  101. repo_url = instance['github']
  102. repo_name = repo_url.split('/')[-1]
  103. action = CmdRunAction(command=f'git clone {repo_url} /workspace/{repo_name}')
  104. logger.info(action, extra={'msg_type': 'ACTION'})
  105. obs = await runtime.run_action(action)
  106. assert obs.exit_code == 0
  107. action = CmdRunAction(command=f'chmod -R 777 /workspace/{repo_name}')
  108. logger.info(action, extra={'msg_type': 'ACTION'})
  109. obs = await runtime.run_action(action)
  110. assert obs.exit_code == 0
  111. # Navigate to the task's code path
  112. task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
  113. action = CmdRunAction(command=f'cd {task_path}')
  114. logger.info(action, extra={'msg_type': 'ACTION'})
  115. obs = await runtime.run_action(action)
  116. assert obs.exit_code == 0
  117. logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
async def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required, but it is used to get the workspace_dir_name
) -> dict[str, Any]:
    """Collect the correctness metric from the sandbox after the agent has run.

    Reads the agent-written `run.sh`, executes it under the task's conda
    environment with a 120 s timeout, and returns a dict containing the script
    content, its output, a `success` flag (0/1), and the script's exit code.
    If you need to do something in the sandbox to get the correctness metric after
    the agent has run, modify this function.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Completion Fn {'-' * 50}")
    obs: CmdOutputObservation
    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
    # Evaluate the agent's script
    eval_script = os.path.join(task_path, 'run.sh')
    logger.info(f'Running evaluation script: {eval_script}')
    action = CmdRunAction(command=f'cat {eval_script}', keep_prompt=False)
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    if obs.exit_code == 0:
        eval_script_content = obs.content
    else:
        # Missing/unreadable run.sh is recorded as empty content, not a crash.
        logger.error(f'Error reading evaluation script: {obs.content}')
        eval_script_content = ''
    # `timeout 120s` bounds the script itself; the action-level timeout (600 s)
    # bounds the whole command including conda startup.
    action = CmdRunAction(
        command=f'timeout 120s conda run -n {ID2CONDA[instance["github_id"]]} bash {eval_script}',
        timeout=600,
    )
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    if obs.exit_code == 0:
        eval_output = obs.content
    else:
        logger.error(f'Error running evaluation script: {obs.content}')
        eval_output = ''
    outputs = {
        'eval_script_content': eval_script_content,
        'eval_output': eval_output,
    }
    # Exit code 124 is `timeout` killing the script; treated as success below.
    if obs.exit_code != 0 and obs.exit_code != 124:
        logger.warning(f'Evaluation script failed with exit code {obs.exit_code}')
        logger.warning(f'Output: {eval_output}')
        outputs['success'] = int(
            'KeyboardInterrupt' in eval_output
        )  # super-dainiu: assume ``KeyboardInterrupt`` is a success as is done in ML-Bench
    else:
        logger.info(f'Evaluation script succeeded with exit code {obs.exit_code}')
        logger.info(f'Output: {eval_output}')
        outputs['success'] = 1
    outputs['eval_exit_code'] = obs.exit_code
    logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
    return outputs
  171. async def process_instance(
  172. instance: Any, metadata: EvalMetadata, reset_logger: bool = True
  173. ):
  174. config = get_config(metadata)
  175. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  176. if reset_logger:
  177. log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
  178. reset_logger_for_multiprocessing(logger, instance['instance_id'], log_dir)
  179. else:
  180. logger.info(f'Starting evaluation for instance {instance["instance_id"]}.')
  181. # Create a sandbox, using the instance ID and PID as the session ID to avoid conflicts
  182. sid = str(instance['instance_id'])
  183. repo_url = instance['github']
  184. repo_name = repo_url.split('/')[-1]
  185. task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
  186. # Prepare the task instruction
  187. instruction = (
  188. f'Please complete the Machine Learning task in the following repository: {repo_name}\n\n'
  189. f'{instance["instruction"]}\n\n'
  190. 'You should create a script named `run.sh` under the specified path in the repo to run the task.\n\n'
  191. f'You can find the task repo at: {task_path}\n\n'
  192. + (
  193. 'Here is the prefix code for the task:\n'
  194. '```bash\n'
  195. f'{instance["prefix_code"]}\n'
  196. '```\n\n'
  197. if instance['prefix_code']
  198. else ''
  199. )
  200. + 'You should terminate the subprocess after running the task (e.g., call subprocess.Popen(args).wait()).'
  201. )
  202. instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
  203. runtime = await create_runtime(config, sid=sid)
  204. await initialize_runtime(runtime, instance)
  205. # Run the agent
  206. state: State | None = await run_controller(
  207. config=config,
  208. task_str=instruction,
  209. runtime=runtime,
  210. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
  211. metadata.agent_class
  212. ),
  213. )
  214. assert state is not None
  215. metrics = state.metrics.get() if state.metrics else {}
  216. test_result = await complete_runtime(runtime)
  217. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  218. # for compatibility with the existing output format, we can remake the pairs here
  219. # remove when it becomes unnecessary
  220. histories = state.history.compatibility_for_eval_history_pairs()
  221. # Save the output
  222. output = EvalOutput(
  223. instance_id=instance['instance_id'],
  224. instance=instance.to_dict(),
  225. instruction=instruction,
  226. metadata=metadata,
  227. history=histories,
  228. test_result=test_result,
  229. metrics=metrics,
  230. )
  231. return output
if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '-s',
        '--eval-split',
        type=str,
        default='quarter',
        choices=['full', 'quarter'],
        help='data split to evaluate on, either full or quarter',
    )
    args, _ = parser.parse_known_args()

    data_split = args.eval_split
    # Load the chosen benchmark split and rename its `id` column to
    # `instance_id`, which the shared evaluation utilities expect.
    ml_bench = load_dataset('super-dainiu/ml-bench', split=data_split).to_pandas()
    ml_bench.rename(columns={'id': 'instance_id'}, inplace=True)

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
        if llm_config is None:
            raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        f'ml-bench-{data_split}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    # NOTE(review): presumably prepare_dataset also skips instances already
    # recorded in output.jsonl (resume support) — verify in evaluation.utils.shared.
    instances = prepare_dataset(ml_bench, output_file, args.eval_n_limit)

    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )