# run_infer.py
  1. """Implements evaluation of agents on ML-Bench, a benchmark for assessing the effectiveness of
  2. Large Language Models (LLMs) in leveraging existing functions in open-source libraries for
  3. machine learning tasks. The benchmark is introduced in the paper "ML-Bench: Evaluating Large
  4. Language Models for Code Generation in Repository-Level Machine Learning Tasks"
  5. (https://arxiv.org/abs/2311.09835).
  6. Please see https://ghcr.io/super-dainiu/ml_bench and https://huggingface.co/datasets/super-dainiu/ml-bench
  7. for more details on the dataset and docker image used in this evaluation script.
  8. TODOs:
  9. - Support additional evaluation settings, such as providing raw README content or using a
  10. retriever to extract relevant segments.
  11. - Clean up the code and docker image used for evaluation.
  12. """
  13. import asyncio
  14. import logging
  15. import os
  16. import pathlib
  17. from typing import Any
  18. from datasets import load_dataset
  19. from evaluation.utils.shared import (
  20. EvalMetadata,
  21. codeact_user_response,
  22. make_metadata,
  23. monologue_user_response,
  24. prepare_dataset,
  25. run_evaluation,
  26. )
  27. from opendevin.controller.agent import Agent
  28. from opendevin.controller.state.state import State
  29. from opendevin.core.config import config, get_llm_config_arg, get_parser
  30. from opendevin.core.logger import get_console_handler
  31. from opendevin.core.logger import opendevin_logger as logger
  32. from opendevin.core.main import run_agent_controller
  33. from opendevin.llm.llm import LLM
  34. from opendevin.runtime.docker.ssh_box import DockerSSHBox
# Maps agent class name -> canned "fake user" reply function, so the agent
# loop can continue without a human in the loop during evaluation.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
    'MonologueAgent': monologue_user_response,
}

# Per-agent-class suffix appended to the task instruction, telling the agent
# how to signal that it has finished the task.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the task, please run the following command: <execute_bash> exit </execute_bash>.\n'
}

# Maps the ML-Bench `github_id` of each task repository to the name of the
# pre-built conda environment for it inside the evaluation docker image.
ID2CONDA = {
    1: 'dgl_DS',
    2: 'bert_DS',
    3: 'lavis_DS',
    4: 'if_DS',
    5: 'V2V_DS',
    6: 'esm_DS',
    7: 'OP_DS',
    8: 'TSL_DS',
    9: 'EAP_DS',
    10: 'PG_DS',
    11: 'PIM_DS',
    12: 'AD2_DS',
    13: 'L3_DS',
    14: 'MZ2_DS',
    15: 'GSA2_DS',
}
  59. def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = True):
  60. agent = Agent.get_cls(metadata.agent_class)(llm=LLM(llm_config=metadata.llm_config))
  61. old_workspace_mount_path = config.workspace_mount_path
  62. old_workspace_base = config.workspace_base
  63. try:
  64. workspace_mount_path = os.path.join(
  65. config.workspace_mount_path, '_eval_workspace'
  66. )
  67. # create process-specific workspace dir
  68. # so that different agent don't interfere with each other.
  69. workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
  70. pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
  71. # reset workspace to config
  72. config.workspace_base = workspace_mount_path
  73. config.workspace_mount_path = workspace_mount_path
  74. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  75. if reset_logger:
  76. # Set up logger
  77. log_file = os.path.join(
  78. metadata.eval_output_dir,
  79. 'logs',
  80. f"instance_{instance['id']}_pid_{os.getpid()}.log",
  81. )
  82. # Remove all existing handlers from logger
  83. for handler in logger.handlers[:]:
  84. logger.removeHandler(handler)
  85. # add back the console handler to print ONE line
  86. logger.addHandler(get_console_handler())
  87. logger.info(
  88. f"Starting evaluation for instance {instance['id']}.\nLOG: tail -f {log_file}"
  89. )
  90. # Remove all existing handlers from logger
  91. for handler in logger.handlers[:]:
  92. logger.removeHandler(handler)
  93. file_handler = logging.FileHandler(log_file)
  94. file_handler.setFormatter(
  95. logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
  96. )
  97. logger.addHandler(file_handler)
  98. logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')
  99. # Create a sandbox, using the instance ID and PID as the session ID to avoid conflicts
  100. sid = str(instance['id']) + '_' + str(os.getpid())
  101. sandbox = DockerSSHBox(sid=sid)
  102. # Set up the task environment
  103. sandbox.execute(f'conda activate {ID2CONDA[instance["github_id"]]}')
  104. # Clone the task repo into the sandbox
  105. repo_url = instance['github']
  106. repo_name = repo_url.split('/')[-1]
  107. sandbox.execute(f'git clone {repo_url} /workspace/{repo_name}')
  108. sandbox.execute(f'chmod -R 777 /workspace/{repo_name}')
  109. # Navigate to the task's code path
  110. task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
  111. sandbox.execute(f'cd {task_path}')
  112. # Prepare the task instruction
  113. instruction = (
  114. f'Please complete the Machine Learning task in the following repository: {repo_name}\n\n'
  115. f'The task is: {instance["task"]}\n\n'
  116. f'{instance["instruction"]}\n\n'
  117. 'You should create a script named `run.sh` under the specified path in the repo to run the task.\n\n'
  118. f'You can find the task repo at: {task_path}\n\n'
  119. + (
  120. 'Here is the prefix code for the task:\n'
  121. '```bash\n'
  122. f'{instance["prefix_code"]}\n'
  123. '```\n\n'
  124. if instance['prefix_code']
  125. else ''
  126. )
  127. + 'You should terminate the subprocess after running the task (e.g., call subprocess.Popen(args).wait()).'
  128. )
  129. instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
  130. # Run the agent
  131. state: State | None = asyncio.run(
  132. run_agent_controller(
  133. agent,
  134. instruction,
  135. max_iterations=metadata.max_iterations,
  136. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
  137. agent.__class__.__name__
  138. ),
  139. sandbox=sandbox,
  140. sid=sid,
  141. )
  142. )
  143. assert state is not None
  144. metrics = state.metrics.get() if state.metrics else {}
  145. # Evaluate the agent's script
  146. eval_script = os.path.join(task_path, 'run.sh')
  147. logger.info(f'Running evaluation script: {eval_script}')
  148. try:
  149. _, eval_script_content = sandbox.execute(f'cat {eval_script}')
  150. except Exception as e:
  151. logger.error(f'Error reading evaluation script: {e}')
  152. eval_script_content = ''
  153. try:
  154. exit_code, eval_output = sandbox.execute(
  155. f'timeout 120s conda run -n {ID2CONDA[instance["github_id"]]} bash {eval_script}',
  156. timeout=600,
  157. )
  158. except Exception as e:
  159. logger.error(f'Error running evaluation script: {e}')
  160. exit_code = -1
  161. eval_output = ''
  162. if exit_code != 0 and exit_code != 124:
  163. logger.warning(f'Evaluation script failed with exit code {exit_code}')
  164. logger.warning(f'Output: {eval_output}')
  165. metrics['success'] = int(
  166. 'KeyboardInterrupt' in eval_output
  167. ) # super-dainiu: assume ``KeyboardInterrupt`` is a success as is done in ML-Bench
  168. else:
  169. logger.info(f'Evaluation script succeeded with exit code {exit_code}')
  170. logger.info(f'Output: {eval_output}')
  171. metrics['success'] = 1
  172. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  173. # for compatibility with the existing output format, we can remake the pairs here
  174. # remove when it becomes unnecessary
  175. histories = state.history.compatibility_for_eval_history_pairs()
  176. # Save the output
  177. output = {
  178. 'instance_id': instance['id'],
  179. 'repo': repo_url,
  180. 'instruction': instruction,
  181. 'metadata': metadata.model_dump(),
  182. 'history': histories,
  183. 'eval_script': eval_script_content,
  184. 'eval_exit_code': exit_code,
  185. 'eval_output': eval_output,
  186. 'metrics': metrics,
  187. }
  188. except Exception as e:
  189. logger.error(f'Error processing instance {instance["id"]}: {e}')
  190. raise
  191. finally:
  192. config.workspace_mount_path = old_workspace_mount_path
  193. config.workspace_base = old_workspace_base
  194. # Shutdown the sandbox
  195. sandbox.close()
  196. return output
  197. if __name__ == '__main__':
  198. parser = get_parser()
  199. parser.add_argument(
  200. '-s',
  201. '--eval-split',
  202. type=str,
  203. default='quarter',
  204. choices=['full', 'quarter'],
  205. help='data split to evaluate on, either full or quarter',
  206. )
  207. args, _ = parser.parse_known_args()
  208. data_split = args.eval_split
  209. # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
  210. # so we don't need to manage file uploading to OpenDevin's repo
  211. ml_bench = load_dataset('super-dainiu/ml-bench', split=data_split).to_pandas()
  212. id_column = 'instance_id'
  213. llm_config = get_llm_config_arg(args.llm_config) if args.llm_config else config.llm
  214. logger.info(f'Config for evaluation: {config}')
  215. metadata = make_metadata(
  216. llm_config,
  217. args.dataset_name,
  218. args.agent_cls,
  219. args.max_iterations,
  220. args.eval_note,
  221. args.eval_output_dir,
  222. )
  223. output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
  224. instances = prepare_dataset(ml_bench, output_file, args.eval_n_limit, id_column)
  225. run_evaluation(
  226. instances,
  227. metadata,
  228. output_file,
  229. args.eval_num_workers,
  230. process_instance,
  231. id_column,
  232. )