run_infer.py

  1. """Implements evaluation of agents on ML-Bench, a benchmark for assessing the effectiveness of
  2. Large Language Models (LLMs) in leveraging existing functions in open-source libraries for
  3. machine learning tasks. The benchmark is introduced in the paper "ML-Bench: Evaluating Large
  4. Language Models for Code Generation in Repository-Level Machine Learning Tasks"
  5. (https://arxiv.org/abs/2311.09835).
  6. Please see https://ghcr.io/super-dainiu/ml_bench and https://huggingface.co/datasets/super-dainiu/ml-bench
  7. for more details on the dataset and docker image used in this evaluation script.
  8. TODOs:
  9. - Support additional evaluation settings, such as providing raw README content or using a
  10. retriever to extract relevant segments.
  11. - Clean up the code and docker image used for evaluation.
  12. """

import asyncio
import os
from typing import Any

import pandas as pd
from datasets import load_dataset

from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    codeact_user_response,
    compatibility_for_eval_history_pairs,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
    load_app_config,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
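
# Load the global OpenHands app config (typically read from config.toml and environment
# variables); get_config() below builds a dedicated AppConfig for each evaluation run.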
config = load_app_config()

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the task, please finish the interaction using the "finish" tool.\n'
}

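# Maps the dataset's `github_id` field to the conda environment used for that repository.
# These environments are expected to be pre-built in the ML-Bench docker image referenced
# in get_config() below.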
ID2CONDA = {
    1: 'dgl_DS',
    2: 'bert_DS',
    3: 'lavis_DS',
    4: 'if_DS',
    5: 'V2V_DS',
    6: 'esm_DS',
    7: 'OP_DS',
    8: 'TSL_DS',
    9: 'EAP_DS',
    10: 'PG_DS',
    11: 'PIM_DS',
    12: 'AD2_DS',
    13: 'L3_DS',
    14: 'MZ2_DS',
    15: 'GSA2_DS',
}


def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            base_container_image='public.ecr.aws/i5g0m1f6/ml-bench',
            enable_auto_lint=True,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config


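# The task repository is cloned inside the sandbox by initialize_runtime() below rather
# than mounted from the host, which is why workspace_base/workspace_mount_path are unset.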
def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
    obs: CmdOutputObservation

    # Create the workspace directory
    action = CmdRunAction(command='mkdir -p /workspace')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    # Set up the task environment
    action = CmdRunAction(command=f'conda activate {ID2CONDA[instance["github_id"]]}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    # Clone the task repository and make it writable
    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    action = CmdRunAction(command=f'git clone {repo_url} /workspace/{repo_name}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    action = CmdRunAction(command=f'chmod -R 777 /workspace/{repo_name}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    # Navigate to the task's code path
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
    action = CmdRunAction(command=f'cd {task_path}')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")


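# Note: initialize_runtime() assumes the runtime keeps a persistent shell session, so the
# `conda activate` and `cd` above carry over to the agent's subsequent commands. This is
# an assumption about the eventstream runtime's behavior, not something enforced here.
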
def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # used to locate the task path and conda environment
) -> dict[str, Any]:
    """Complete the runtime for the agent.

    This function is called after the agent has finished running.
    If you need to do something in the sandbox to get the correctness metric after
    the agent has run, modify this function.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Completion Fn {'-' * 50}")
    obs: CmdOutputObservation

    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])

    # Evaluate the agent's script
    eval_script = os.path.join(task_path, 'run.sh')
    logger.info(f'Running evaluation script: {eval_script}')
    action = CmdRunAction(command=f'cat {eval_script}', keep_prompt=False)
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    if obs.exit_code == 0:
        eval_script_content = obs.content
    else:
        logger.error(f'Error reading evaluation script: {obs.content}')
        eval_script_content = ''

    action = CmdRunAction(
        command=f'timeout 120s conda run -n {ID2CONDA[instance["github_id"]]} bash {eval_script}',
        timeout=600,
    )
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    if obs.exit_code == 0:
        eval_output = obs.content
    else:
        logger.error(f'Error running evaluation script: {obs.content}')
        eval_output = ''

    outputs = {
        'eval_script_content': eval_script_content,
        'eval_output': eval_output,
    }
    if obs.exit_code != 0 and obs.exit_code != 124:
        logger.warning(f'Evaluation script failed with exit code {obs.exit_code}')
        logger.warning(f'Output: {eval_output}')
        outputs['success'] = int(
            'KeyboardInterrupt' in eval_output
        )  # super-dainiu: assume ``KeyboardInterrupt`` is a success as is done in ML-Bench
    else:
        logger.info(f'Evaluation script succeeded with exit code {obs.exit_code}')
        logger.info(f'Output: {eval_output}')
        outputs['success'] = 1
    outputs['eval_exit_code'] = obs.exit_code

    logger.info(f"{'-' * 50} END Runtime Completion Fn {'-' * 50}")
    return outputs


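# process_instance() drives a single benchmark instance end to end: it builds the task
# instruction, runs the agent in the sandbox, evaluates the generated run.sh via
# complete_runtime(), and packages everything into an EvalOutput record.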
def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = True):
    config = get_config(metadata)

    # Set up the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance['instance_id'], log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance["instance_id"]}.')

    repo_url = instance['github']
    repo_name = repo_url.split('/')[-1]
    task_path = os.path.join('/workspace', repo_name, instance['path'][2:])

    # Prepare the task instruction
    instruction = (
        f'Please complete the Machine Learning task in the following repository: {repo_name}\n\n'
        f'{instance["instruction"]}\n\n'
        'You should create a script named `run.sh` under the specified path in the repo to run the task.\n\n'
        f'You can find the task repo at: {task_path}\n\n'
        + (
            'Here is the prefix code for the task:\n'
            '```bash\n'
            f'{instance["prefix_code"]}\n'
            '```\n\n'
            if instance['prefix_code']
            else ''
        )
        + 'You should terminate the subprocess after running the task (e.g., call subprocess.Popen(args).wait()).'
    )
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]

    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)
    initialize_runtime(runtime, instance)

    # Run the agent
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            initial_user_action=MessageAction(content=instruction),
            runtime=runtime,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                metadata.agent_class
            ),
        )
    )
    assert state is not None
    metrics = state.metrics.get() if state.metrics else {}
    test_result = complete_runtime(runtime, instance)
    # History is now available as a stream of events, rather than a list of
    # (Action, Observation) pairs. For compatibility with the existing output format,
    # we remake the pairs here. Remove when it becomes unnecessary.
    histories = compatibility_for_eval_history_pairs(state.history)

    # Save the output
    output = EvalOutput(
        instance_id=instance['instance_id'],
        instance=instance.to_dict(),
        instruction=instruction,
        metadata=metadata,
        history=histories,
        test_result=test_result,
        metrics=metrics,
    )
    return output


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '-s',
        '--eval-split',
        type=str,
        default='quarter',
        choices=['full', 'quarter'],
        help='data split to evaluate on, either full or quarter',
    )
    args, _ = parser.parse_known_args()

    data_split = args.eval_split
    ml_bench = load_dataset('super-dainiu/ml-bench', split=data_split).to_pandas()
    ml_bench.rename(columns={'id': 'instance_id'}, inplace=True)

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        f'ml-bench-{data_split}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(ml_bench, output_file, args.eval_n_limit)

    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )
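
# Example invocation (a sketch; the llm-config name and flag names other than
# --eval-split are assumptions based on OpenHands' get_parser() and may differ
# in the version you are using):
#
#   python run_infer.py \
#       --agent-cls CodeActAgent \
#       --llm-config <your_llm_config> \
#       --max-iterations 10 \
#       --eval-num-workers 4 \
#       --eval-split quarter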