run_infer.py

  1. """
  2. Implements evaluation of agents on ML-Bench, a benchmark for assessing the effectiveness of
  3. Large Language Models (LLMs) in leveraging existing functions in open-source libraries for
  4. machine learning tasks. The benchmark is introduced in the paper "ML-Bench: Evaluating Large
  5. Language Models for Code Generation in Repository-Level Machine Learning Tasks"
  6. (https://arxiv.org/abs/2311.09835).
  7. Please see https://ghcr.io/super-dainiu/ml_bench and https://huggingface.co/datasets/super-dainiu/ml-bench
  8. for more details on the dataset and docker image used in this evaluation script.
  9. TODOs:
  10. - Support additional evaluation settings, such as providing raw README content or using a
  11. retriever to extract relevant segments.
  12. - Clean up the code and docker image used for evaluation.
  13. """
import asyncio
import json
import logging
import multiprocessing as mp
import os
import pathlib
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor

from datasets import load_dataset
from tqdm import tqdm

from opendevin.controller.state.state import State
from opendevin.core.config import config, get_llm_config_arg, get_parser
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import main
from opendevin.events.action import MessageAction
from opendevin.events.serialization.event import event_to_dict
from opendevin.runtime.docker.ssh_box import DockerSSHBox
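

# Terminate any evaluation worker processes that are still alive; used when the run is
# interrupted (see the KeyboardInterrupt handler in the __main__ block below).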
def cleanup():
    logger.info('Cleaning up child processes...')
    for process in mp.active_children():
        logger.info(f'Terminating child process: {process.name}')
        process.terminate()
        process.join()


def codeact_user_response(state: State) -> str:
    msg = (
        'Please continue working on the task on whatever approach you think is suitable.\n'
        'If you think you have completed the task, please run the following command: <execute_bash> exit </execute_bash>.\n'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK.\n'
    )
    if state.history:
        user_msgs = [
            action
            for action, _ in state.history
            if isinstance(action, MessageAction) and action.source == 'user'
        ]
        if len(user_msgs) >= 2:
            # let the agent know that it can give up when it has tried 3 times
            return (
                msg
                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
            )
    return msg


def monologue_user_response(state: State) -> str:
    raise NotImplementedError('MonologueAgent should never ask for user responses.')
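

# Map each supported agent class to the canned "user" reply it receives whenever it asks
# for input, so evaluation can run fully unattended.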
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
    'MonologueAgent': monologue_user_response,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the task, please run the following command: <execute_bash> exit </execute_bash>.\n'
}
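
# Mapping from an ML-Bench `github_id` to the name of the conda environment expected to be
# available in the evaluation sandbox (the ml_bench docker image referenced in the module
# docstring).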
ID2CONDA = {
    1: 'dgl_DS',
    2: 'bert_DS',
    3: 'lavis_DS',
    4: 'if_DS',
    5: 'V2V_DS',
    6: 'esm_DS',
    7: 'OP_DS',
    8: 'TSL_DS',
    9: 'EAP_DS',
    10: 'PG_DS',
    11: 'PIM_DS',
    12: 'AD2_DS',
    13: 'L3_DS',
    14: 'MZ2_DS',
    15: 'GSA2_DS',
}


def process_instance(
    instance, agent_class, metadata, eval_output_dir, reset_logger: bool = True
):
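    """
    Run a single ML-Bench instance end-to-end: set up an isolated workspace and sandbox,
    let the agent work on the task, then execute the agent's `run.sh` and record the result.

    Returns a dict with the instruction, the agent history, the evaluation script and its
    output, and the computed metrics; the caller writes it to `output.jsonl`.
    """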
    old_workspace_mount_path = config.workspace_mount_path
    old_workspace_base = config.workspace_base
    sandbox = None
    try:
        workspace_mount_path = os.path.join(
            config.workspace_mount_path, '_eval_workspace'
        )
        # create a process-specific workspace dir
        # so that different agents don't interfere with each other
        workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
        pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
        # point the config at the process-specific workspace
        config.workspace_base = workspace_mount_path
        config.workspace_mount_path = workspace_mount_path
        # Set up the logger properly, so you can run multi-processing to parallelize the evaluation
        if reset_logger:
            # Set up logger
            log_file = os.path.join(
                eval_output_dir,
                'logs',
                f"instance_{instance['id']}_pid_{os.getpid()}.log",
            )
            # Remove all existing handlers from logger
            for handler in logger.handlers[:]:
                logger.removeHandler(handler)
            # add back the console handler to print ONE line
            logger.addHandler(get_console_handler())
            logger.info(
                f"Starting evaluation for instance {instance['id']}.\nLOG: tail -f {log_file}"
            )
            # Remove all existing handlers from logger
            for handler in logger.handlers[:]:
                logger.removeHandler(handler)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            logger.addHandler(file_handler)

        logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')

        # Create a sandbox, using the instance ID as the session ID to avoid conflicts
        sandbox = DockerSSHBox(sid=str(instance['id']) + '_' + str(os.getpid()))

        # Set up the task environment
        sandbox.execute(f'conda activate {ID2CONDA[instance["github_id"]]}')

        # Clone the task repo into the sandbox
        repo_url = instance['github']
        repo_name = repo_url.split('/')[-1]
        sandbox.execute(f'git clone {repo_url} /workspace/{repo_name}')
        sandbox.execute(f'chmod -R 777 /workspace/{repo_name}')

        # Navigate to the task's code path
        task_path = os.path.join('/workspace', repo_name, instance['path'][2:])
        sandbox.execute(f'cd {task_path}')

        # Prepare the task instruction
        instruction = (
            f'Please complete the Machine Learning task in the following repository: {repo_name}\n\n'
            f'The task is: {instance["task"]}\n\n'
            f'{instance["instruction"]}\n\n'
            'You should create a script named `run.sh` under the specified path in the repo to run the task.\n\n'
            f'You can find the task repo at: {task_path}\n\n'
            + (
                'Here is the prefix code for the task:\n'
                '```bash\n'
                f'{instance["prefix_code"]}\n'
                '```\n\n'
                if instance['prefix_code']
                else ''
            )
            + 'You should terminate the subprocess after running the task (e.g., call subprocess.Popen(args).wait()).'
        )
        instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')

        # Run the agent
        state: State = asyncio.run(
            main(
                instruction,
                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                    agent_class
                ),
                sandbox=sandbox,
            )
        )
        metrics = state.metrics.get() if state.metrics else {}

        # Evaluate the agent's script
        eval_script = os.path.join(task_path, 'run.sh')
        logger.info(f'Running evaluation script: {eval_script}')
        try:
            _, eval_script_content = sandbox.execute(f'cat {eval_script}')
        except Exception as e:
            logger.error(f'Error reading evaluation script: {e}')
            eval_script_content = ''

        try:
            exit_code, eval_output = sandbox.execute(
                f'timeout 120s conda run -n {ID2CONDA[instance["github_id"]]} bash {eval_script}',
                timeout=600,
            )
        except Exception as e:
            logger.error(f'Error running evaluation script: {e}')
            exit_code = -1
            eval_output = ''

        # 124 is the exit status of the `timeout 120s` wrapper; a timed-out run therefore
        # falls into the success branch, consistent with the ML-Bench convention noted below.
        if exit_code != 0 and exit_code != 124:
            logger.warning(f'Evaluation script failed with exit code {exit_code}')
            logger.warning(f'Output: {eval_output}')
            metrics['success'] = int(
                'KeyboardInterrupt' in eval_output
            )  # super-dainiu: assume ``KeyboardInterrupt`` is a success as is done in ML-Bench
        else:
            logger.info(f'Evaluation script succeeded with exit code {exit_code}')
            logger.info(f'Output: {eval_output}')
            metrics['success'] = 1

        # Save the output
        output = {
            'instance_id': instance['id'],
            'repo': repo_url,
            'instruction': instruction,
            'metadata': metadata,
            'history': [
                (event_to_dict(action), event_to_dict(obs))
                for action, obs in state.history
            ],
            'eval_script': eval_script_content,
            'eval_exit_code': exit_code,
            'eval_output': eval_output,
            'metrics': metrics,
        }
    except Exception as e:
        logger.error(f'Error processing instance {instance["id"]}: {e}')
        raise
    finally:
        config.workspace_mount_path = old_workspace_mount_path
        config.workspace_base = old_workspace_base
        # Shut down the sandbox if it was created
        if sandbox is not None:
            sandbox.close()
    return output


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '-s',
        '--eval-split',
        type=str,
        default='quarter',
        choices=['full', 'quarter'],
        help='data split to evaluate on, either full or quarter',
    )
    args, _ = parser.parse_known_args()

    data_split = args.eval_split
    agent_class = args.agent_cls
    num_workers = args.eval_num_workers

    # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
    # for details of how to set `llm_config`
    if args.llm_config:
        specified_llm_config = get_llm_config_arg(args.llm_config)
        if specified_llm_config:
            config.llm = specified_llm_config
    logger.info(f'Config for evaluation: {config}')

    # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
    # so we don't need to manage file uploading to OpenDevin's repo
    ml_bench = load_dataset('super-dainiu/ml-bench', split=data_split).to_pandas()

    # LIMIT EVALUATION
    eval_n_limit = args.eval_n_limit
    if eval_n_limit:
        ml_bench = ml_bench.head(eval_n_limit)
        logger.info(f'Limiting evaluation to {eval_n_limit} instances.')

    # TEST METADATA
    model_name = config.llm.model.split('/')[-1]
    max_iterations = args.max_iterations
    eval_note = ''
    if args.eval_note is not None:
        eval_note += '_N_' + args.eval_note
    eval_output_dir = os.path.join(
        args.eval_output_dir,
        'ml_bench',
        agent_class,
        model_name + '_maxiter_' + str(max_iterations) + eval_note,
    )
    os.makedirs(eval_output_dir, exist_ok=True)
    os.makedirs(os.path.join(eval_output_dir, 'logs'), exist_ok=True)
    logger.info(f'Using evaluation output directory: {eval_output_dir}')

    metadata = {
        'agent_class': agent_class,
        'model_name': model_name,
        'max_iterations': max_iterations,
        'eval_output_dir': eval_output_dir,
        'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
        # get the commit id of the current repo for reproducibility
        'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        .decode('utf-8')
        .strip(),
    }
    logger.info(f'Metadata: {metadata}')

    output_file = os.path.join(eval_output_dir, 'output.jsonl')
    logger.info(f'Evaluating on data split: {data_split}')
    logger.info(f'Using {num_workers} worker processes')
    logger.info(f'Writing evaluation output to {output_file}')
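
    # Resume support: if an output file already exists, load the IDs of finished instances
    # so they are skipped, and append new results to the same file.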
    finished_instance_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print(f'Error parsing line: {line}')
                    continue  # skip malformed lines instead of reusing stale data
                finished_instance_ids.add(data['instance_id'])
        logger.warning(
            f'Output file {output_file} already exists. Loaded {len(finished_instance_ids)} finished instances.'
        )
    output_fp = open(output_file, 'a')

    logger.info(
        f'Evaluation started with Agent {agent_class}, model {model_name}, data split {data_split}.'
    )

    # Filter out finished instances
    new_instances = [
        instance
        for _, instance in ml_bench.iterrows()
        if instance['id'] not in finished_instance_ids
    ]
    logger.info(
        f'Finished instances: {len(finished_instance_ids)}, Remaining instances: {len(new_instances)}'
    )

    pbar = tqdm(total=len(new_instances))

    # This function tracks the progress AND writes the output to a JSONL file
    def update_progress(future):
        pbar.update(1)
        output = future.result()
        pbar.set_description(f'Instance {output["instance_id"]}')
        pbar.set_postfix_str(f'Metrics: {output["metrics"]}')
        logger.info(
            f'Finished evaluation for instance {output["instance_id"]}: {output["metrics"]}'
        )
        output_fp.write(json.dumps(output) + '\n')
        output_fp.flush()

    # This sets up the multi-processing
    num_workers = args.eval_num_workers
    logger.info(f'Using {num_workers} workers for evaluation.')
    try:
        with ProcessPoolExecutor(num_workers) as executor:
            futures = []
            for instance in new_instances:
                future = executor.submit(
                    process_instance,
                    instance,
                    agent_class,
                    metadata,
                    eval_output_dir,
                    reset_logger=bool(num_workers > 1),
                )
                future.add_done_callback(update_progress)
                futures.append(future)
            # Wait for all submitted instances to finish
            for future in futures:
                output = future.result()
    except KeyboardInterrupt:
        print('KeyboardInterrupt received. Cleaning up...')
        cleanup()

    logger.info('Evaluation completed.')