run_infer.py

  1. """Implements evaluation of agents on ML-Bench, a benchmark for assessing the effectiveness of
  2. Large Language Models (LLMs) in leveraging existing functions in open-source libraries for
  3. machine learning tasks. The benchmark is introduced in the paper "ML-Bench: Evaluating Large
  4. Language Models for Code Generation in Repository-Level Machine Learning Tasks"
  5. (https://arxiv.org/abs/2311.09835).
  6. Please see https://ghcr.io/super-dainiu/ml_bench and https://huggingface.co/datasets/super-dainiu/ml-bench
  7. for more details on the dataset and docker image used in this evaluation script.
  8. TODOs:
  9. - Support additional evaluation settings, such as providing raw README content or using a
  10. retriever to extract relevant segments.
  11. - Clean up the code and docker image used for evaluation.
  12. """

import asyncio
import os
from typing import Any

import pandas as pd
from datasets import load_dataset

from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    codeact_user_response,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
    load_app_config,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.runtime import Runtime

config = load_app_config()

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the task, please run the following command: <execute_bash> exit </execute_bash>.\n'
}
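
# Maps the ML-Bench `github_id` of each task repository to the conda environment used
# for it; these environments are expected to be pre-installed in the `ml-bench`
# container image configured in `get_config` below.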
ID2CONDA = {
    1: 'dgl_DS',
    2: 'bert_DS',
    3: 'lavis_DS',
    4: 'if_DS',
    5: 'V2V_DS',
    6: 'esm_DS',
    7: 'OP_DS',
    8: 'TSL_DS',
    9: 'EAP_DS',
    10: 'PG_DS',
    11: 'PIM_DS',
    12: 'AD2_DS',
    13: 'L3_DS',
    14: 'MZ2_DS',
    15: 'GSA2_DS',
}


def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            base_container_image='public.ecr.aws/i5g0m1f6/ml-bench',
            enable_auto_lint=True,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config


def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # used to pick the conda environment and clone the task repo
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
    obs: CmdOutputObservation

    # Create the workspace directory
    action = CmdRunAction(command='mkdir -p /workspace')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    # Set up the task environment
    action = CmdRunAction(command=f'conda activate {ID2CONDA[instance["github_id"]]}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    action = CmdRunAction(command=f'git clone {repo_url} /workspace/{repo_name}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    action = CmdRunAction(command=f'chmod -R 777 /workspace/{repo_name}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    # Navigate to the task's code path ([2:] drops the leading './' from the dataset path)
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
    action = CmdRunAction(command=f'cd {task_path}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")


def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # used to locate the task repo and its run.sh
) -> dict[str, Any]:
    """Complete the runtime for the agent.

    This function is called after the agent has finished running.
    If you need to do something in the sandbox to get the correctness metric after
    the agent has run, modify this function.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Completion Fn {'-' * 50}")
    obs: CmdOutputObservation

    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])

    # Evaluate the agent's script
    eval_script = os.path.join(task_path, 'run.sh')
    logger.info(f'Running evaluation script: {eval_script}')

    action = CmdRunAction(command=f'cat {eval_script}', keep_prompt=False)
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    if obs.exit_code == 0:
        eval_script_content = obs.content
    else:
        logger.error(f'Error reading evaluation script: {obs.content}')
        eval_script_content = ''

    action = CmdRunAction(
        command=f'timeout 120s conda run -n {ID2CONDA[instance["github_id"]]} bash {eval_script}',
        timeout=600,
    )
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    if obs.exit_code == 0:
        eval_output = obs.content
    else:
        logger.error(f'Error running evaluation script: {obs.content}')
        eval_output = ''

    outputs = {
        'eval_script_content': eval_script_content,
        'eval_output': eval_output,
    }
    # `timeout` returns 124 when the 120s limit is hit; that case is counted as a success below.
    if obs.exit_code != 0 and obs.exit_code != 124:
        logger.warning(f'Evaluation script failed with exit code {obs.exit_code}')
        logger.warning(f'Output: {eval_output}')
        outputs['success'] = int(
            'KeyboardInterrupt' in obs.content
        )  # super-dainiu: assume ``KeyboardInterrupt`` is a success as is done in ML-Bench
    else:
        logger.info(f'Evaluation script succeeded with exit code {obs.exit_code}')
        logger.info(f'Output: {eval_output}')
        outputs['success'] = 1
    outputs['eval_exit_code'] = obs.exit_code

    logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
    return outputs
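

# Runs one ML-Bench instance end to end: builds the task instruction, starts a sandboxed
# runtime, lets the agent work on the cloned repository, then scores the `run.sh` it
# produced via `complete_runtime`.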
def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = True):
    config = get_config(metadata)

    # Set up the logger properly so you can run multiprocessing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance['instance_id'], log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance["instance_id"]}.')

    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])

    # Prepare the task instruction
    instruction = (
        f'Please complete the Machine Learning task in the following repository: {repo_name}\n\n'
        f'{instance["instruction"]}\n\n'
        'You should create a script named `run.sh` under the specified path in the repo to run the task.\n\n'
        f'You can find the task repo at: {task_path}\n\n'
        + (
            'Here is the prefix code for the task:\n'
            '```bash\n'
            f'{instance["prefix_code"]}\n'
            '```\n\n'
            if instance['prefix_code']
            else ''
        )
        + 'You should terminate the subprocess after running the task (e.g., call subprocess.Popen(args).wait()).'
    )
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]

    runtime = create_runtime(config)
    initialize_runtime(runtime, instance)

    # Run the agent
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            initial_user_action=MessageAction(content=instruction),
            runtime=runtime,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                metadata.agent_class
            ),
        )
    )
    assert state is not None
    metrics = state.metrics.get() if state.metrics else {}

    test_result = complete_runtime(runtime, instance)

    # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
    # for compatibility with the existing output format, we can remake the pairs here
    # remove when it becomes unnecessary
    histories = state.history.compatibility_for_eval_history_pairs()

    # Save the output
    output = EvalOutput(
        instance_id=instance['instance_id'],
        instance=instance.to_dict(),
        instruction=instruction,
        metadata=metadata,
        history=histories,
        test_result=test_result,
        metrics=metrics,
    )
    return output


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '-s',
        '--eval-split',
        type=str,
        default='quarter',
        choices=['full', 'quarter'],
        help='data split to evaluate on, either full or quarter',
    )
    args, _ = parser.parse_known_args()

    data_split = args.eval_split
    ml_bench = load_dataset('super-dainiu/ml-bench', split=data_split).to_pandas()
    ml_bench.rename(columns={'id': 'instance_id'}, inplace=True)

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        f'ml-bench-{data_split}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(ml_bench, output_file, args.eval_n_limit)

    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )
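
    # Note (assumption): `run_evaluation` from `evaluation.utils.shared` is expected to append
    # one JSON record per instance to `output_file`, carrying the `EvalOutput` fields set in
    # `process_instance` plus the `test_result` dict from `complete_runtime`
    # (`eval_script_content`, `eval_output`, `success`, `eval_exit_code`).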