run_infer.py

import asyncio
import json
import logging
import multiprocessing as mp
import os
import pathlib
import re
import shutil
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor

import docker
from datasets import load_dataset
from tqdm import tqdm

from evaluation.agent_bench.helper import (
    compare_results,
    create_sh_file,
    try_parse_answer,
)
from opendevin.controller.state.state import State
from opendevin.core.config import args, config, get_llm_config_arg
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import main
from opendevin.events.action import CmdRunAction, MessageAction
from opendevin.events.serialization.event import event_to_dict
from opendevin.runtime.docker.ssh_box import DockerSSHBox
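
# Evaluation harness for the AgentBench OS subset: each instance runs inside a DockerSSHBox
# sandbox driven by an OpenDevin agent, the agent's <solution> answer is compared against the
# ground truth, and per-instance results are appended to output.jsonl (optionally in parallel).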

def cleanup():
    print('Cleaning up child processes...')
    for process in mp.active_children():
        print(f'Terminating child process: {process.name}')
        process.terminate()
        process.join()

def codeact_user_response(state: State) -> str:
    msg = (
        'Please continue working on the task on whatever approach you think is suitable.\n'
        'If you think you have solved the task, please first send your answer to user through '
        'message and then <execute_bash> exit </execute_bash>.\n'
        'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>.\n'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
    )
    if state.history:
        # If the last action already contains a parsable answer, return '/exit' for an early exit.
        last_action, _ = state.history[-1]
        ans = try_parse_answer(last_action)
        if ans is not None:
            return '/exit'
        user_msgs = [
            action
            for action, _ in state.history
            if isinstance(action, MessageAction) and action.source == 'user'
        ]
        if len(user_msgs) >= 2:
            # Let the agent know that it can give up once it has tried three times.
            return (
                msg
                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
            )
    return msg


AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}
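
# Per-agent suffix appended to the task instruction; for CodeActAgent it reiterates that the
# agent should message its final answer and then exit.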
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, '
    'please first send your answer to user through message and then exit.\n'
}

def process_instance(
    instance,
    agent_class,
    metadata,
    eval_output_dir,
    reset_logger: bool = True,
):
    # =============================================
    # preparation
    # =============================================
    inst_id = instance.instance_id
    question = instance.description

    # Create a directory for the instance's workspace.
    instance_workspace = str(os.path.join(config.workspace_base, inst_id))
    container_inst_workspace = str(
        os.path.join(config.workspace_mount_path_in_sandbox, inst_id)
    )
    if os.path.exists(instance_workspace):
        shutil.rmtree(instance_workspace)
    os.makedirs(instance_workspace, exist_ok=True)

    # Set up the logger properly so the evaluation can be parallelized with multiprocessing.
    if reset_logger:
        # Set up a per-instance log file.
        log_file = os.path.join(eval_output_dir, 'logs', f'instance_{inst_id}.log')
        # Remove all existing handlers from the logger.
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
        # Add back the console handler to print ONE line.
        logger.addHandler(get_console_handler())
        logger.info(
            f'Starting evaluation for instance {inst_id}.\nHint: run "tail -f {log_file}" to see live logs in a separate shell'
        )
        # Remove all existing handlers from the logger.
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(file_handler)
    # =============================================
    # build instruction
    # =============================================
    # Prepare the instruction given to the agent.
    instruction = (
        f'Please fix the following issue.\n'
        'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
        'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>.\n'
        '# Problem \n'
        f'{question}\n\n'
    )
    instruction += (
        'IMPORTANT: You should ONLY interact with the environment provided '
        'to you AND NEVER ASK FOR HUMAN HELP.\n'
    )
    # NOTE: You can actually set slightly different instructions for different agents.
    instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
    # =============================================
    # create sandbox and run the agent
    # =============================================
    sandbox = DockerSSHBox()
    sandbox.execute(f'cd {inst_id}')

    # Run the instance's init script inside the sandbox, if one is provided.
    init_cmd = instance.init
    if init_cmd is not None:
        scpt_name = f'{instance.instance_id}_init.sh'
        scpt_path = os.path.join(container_inst_workspace, scpt_name)
        host_scpt_path = os.path.join(instance_workspace, scpt_name)
        create_sh_file(host_scpt_path, init_cmd)
        logger.info(f'Running init script: {scpt_path}')
        _, init_res = sandbox.execute(scpt_path)
        logger.info(f'Init script result: {init_res}')

    # Run the agent (similar to the `main` function) and get the final task state.
    state: State = asyncio.run(
        main(
            instruction,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(agent_class),
            sandbox=sandbox,
        )
    )
    if state is None:
        raise ValueError('State should not be None.')

    # get the ground truth
    # OSBenchSSHBox.get_ground_truth(instance, state)
    # =============================================
    # result evaluation
    # =============================================
    agent_answer = ''
    get_agent_result_cmd = instance.get_agent_result
    if get_agent_result_cmd is not None:
        # Retrieve the agent's answer by running the instance's result script in the sandbox.
        scpt_name = f'{instance.instance_id}_get_agent_result.sh'
        scpt_path = os.path.join(container_inst_workspace, scpt_name)
        host_scpt_path = os.path.join(instance_workspace, scpt_name)
        create_sh_file(host_scpt_path, get_agent_result_cmd)
        logger.info(f'Running get agent result cmd: {scpt_path}')
        _, agent_answer = sandbox.execute(scpt_path)
    else:
        logger.info('Retrieving agent answer from history.')
        raw_ans = ''
        # Walk the history backwards and take the most recent agent message or command thought.
        for act, _ in reversed(state.history):
            if isinstance(act, MessageAction) and act.source == 'agent':
                raw_ans = act.content
                break
            if isinstance(act, CmdRunAction) and act.source == 'agent':
                raw_ans = act.thought
                break
        # Extract the answer from the <solution>...</solution> tags.
        agent_answer = re.findall(r'<solution>(.*?)</solution>', raw_ans)
        if len(agent_answer) == 0:
            logger.warning(f'Failed to parse model answer: {raw_ans}')
            agent_answer = raw_ans
        else:
            agent_answer = agent_answer[0]

    # Determine the ground truth, either directly from the instance or by running its script.
    final_ans = ''
    if instance.ground_truth is not None:
        final_ans = instance.ground_truth
    else:
        get_ground_truth_cmd = instance.get_ground_truth
        if get_ground_truth_cmd is not None:
            scpt_name = f'{instance.instance_id}_get_ground_truth.sh'
            scpt_path = os.path.join(container_inst_workspace, scpt_name)
            host_scpt_path = os.path.join(instance_workspace, scpt_name)
            create_sh_file(host_scpt_path, get_ground_truth_cmd)
            logger.info(f'Running get ground truth cmd: {scpt_path}')
            sandbox.execute(f'cd {container_inst_workspace}')
            _, final_ans = sandbox.execute(scpt_path)

    comparison_method = instance.comparison_method
    logger.info(
        f'Final message: {agent_answer} | Ground truth: {final_ans} | Comparison method: {comparison_method}'
    )
    test_result = compare_results(comparison_method, agent_answer, final_ans)

    histories = [
        (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
    ]
    metrics = state.metrics.get() if state.metrics else None

    # Save the output.
    output = {
        'instance_id': inst_id,
        'instance': instance.to_dict(),
        'instruction': instruction,
        'metadata': metadata,
        'history': histories,
        'metrics': metrics,
        'error': state.error if state and state.error else None,
        'test_result': {
            'agent_answer': agent_answer,
            'final_answer': final_ans,
            'check_method': comparison_method,
            'result': test_result,
        },
    }

    # Clean up the instance workspace.
    if os.path.exists(instance_workspace):
        shutil.rmtree(instance_workspace)

    # Close the sandbox.
    try:
        sandbox.close()
    except docker.errors.NotFound as e:
        logger.error(f'Failed to close sandbox: {e}')

    return output
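
# Driver: load the AgentBench OS split, prepare the output and metadata directories, resume from
# any existing output.jsonl, and dispatch the remaining instances to worker processes.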
if __name__ == '__main__':
    # =============================================
    # load datasets
    # =============================================
    dataset = load_dataset('iFurySt/AgentBench')
    agent_bench_tests = dataset['osbench'].to_pandas()
    logger.info(f'Loaded {len(agent_bench_tests)} tests.')

    # =============================================
    # handle arguments and prepare for evaluation
    # =============================================
    if args.llm_config:
        specified_llm_config = get_llm_config_arg(args.llm_config)
        if specified_llm_config:
            config.llm = specified_llm_config
    logger.info(f'Config for evaluation: {config}')
    # TEST METADATA
    agent_cls = args.agent_cls
    assert (
        agent_cls in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
    ), f'Unsupported agent class: {agent_cls}'
    model_name = config.llm.model.split('/')[-1]
    max_iterations = args.max_iterations
    eval_note = ''
    if args.eval_note is not None:
        eval_note += '_N_' + args.eval_note
    eval_op_dir = str(
        os.path.join(
            args.eval_output_dir,
            'agent_bench',
            agent_cls,
            model_name + '_maxiter_' + str(max_iterations) + eval_note,
        )
    )
    pathlib.Path(eval_op_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(str(os.path.join(eval_op_dir, 'logs'))).mkdir(
        parents=True, exist_ok=True
    )
    logger.info(f'Using evaluation output directory: {eval_op_dir}')

    meta = {
        'agent_class': agent_cls,
        'model_name': model_name,
        'max_iterations': max_iterations,
        'eval_output_dir': eval_op_dir,
        'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
        # get the commit id of the current repo for reproducibility
        'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        .decode('utf-8')
        .strip(),
    }
    logger.info(f'Metadata: {meta}')
    with open(os.path.join(eval_op_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f)
    # LIMIT EVALUATION
    eval_n_limit = args.eval_n_limit
    if eval_n_limit:
        agent_bench_tests = agent_bench_tests[:eval_n_limit]
        logger.info(f'Limiting evaluation to first {eval_n_limit} instances.')

    # OUTPUT FILE
    output_file = os.path.join(eval_op_dir, 'output.jsonl')
    logger.info(f'Writing evaluation output to {output_file}')
    finished_instance_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                finished_instance_ids.add(data['instance_id'])
        logger.warning(
            f'Output file {output_file} already exists. Loaded {len(finished_instance_ids)} finished instances.'
        )
    output_fp = open(output_file, 'a')
    logger.info(
        f'Evaluation started with Agent {agent_cls}, model {model_name}, max iterations {max_iterations}.'
    )
    # =============================================
    # filter out finished instances
    # =============================================
    new_agent_bench_tests = []
    for idx, inst in agent_bench_tests.iterrows():
        if inst.instance_id in finished_instance_ids:
            logger.info(
                f'Skipping instance {inst.instance_id} as it is already finished.'
            )
            continue
        new_agent_bench_tests.append(inst)
    agent_bench_tests = new_agent_bench_tests
    logger.info(
        f'Finished instances: {len(finished_instance_ids)}, Remaining instances: {len(agent_bench_tests)}'
    )
    # =============================================
    # start task
    # =============================================
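    # NOTE: `update_progress` is registered with `Future.add_done_callback`, so it runs in the
    # main process (in an executor-managed thread) as each worker finishes, which is why it can
    # safely update the progress bar and append to the shared JSONL file handle.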
    pbar = tqdm(total=len(agent_bench_tests))

    # This function tracks the progress AND writes the output to a JSONL file.
    def update_progress(fut):
        pbar.update(1)
        output = fut.result()
        pbar.set_description(f'Instance {output["instance_id"]}')
        pbar.set_postfix_str(f'Test Result: {output["test_result"]["result"]}')
        logger.info(
            f'Finished evaluation for instance {output["instance_id"]}: {output["test_result"]["result"]}'
        )
        output_fp.write(json.dumps(output) + '\n')
        output_fp.flush()
    # This sets up the multiprocessing pool.
    num_workers = args.eval_num_workers
    logger.info(f'Using {num_workers} workers for evaluation.')
    try:
        with ProcessPoolExecutor(num_workers) as executor:
            futures = []
            # Submit one process_instance task per remaining instance.
            for inst in agent_bench_tests:
                future = executor.submit(
                    process_instance,
                    inst,
                    agent_cls,
                    meta,
                    eval_op_dir,
                    reset_logger=bool(num_workers > 1),
                )
                future.add_done_callback(update_progress)
                futures.append(future)
            # Wait for all futures to complete.
            for future in futures:
                future.result()
    except KeyboardInterrupt:
        print('KeyboardInterrupt received. Cleaning up...')
        cleanup()

    output_fp.close()
    logger.info('Evaluation finished.')
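
# Example invocation (a sketch; the exact flag names come from opendevin's shared argument parser
# and should be verified against `opendevin.core.config`, and <llm-config-name> is a placeholder):
#   python evaluation/agent_bench/run_infer.py --agent-cls CodeActAgent \
#       --llm-config <llm-config-name> --max-iterations 30 --eval-num-workers 4 --eval-n-limit 10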