run_infer.py

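"""Run an OpenDevin agent on the GAIA benchmark and score its answers.

For each GAIA instance the script prepares an isolated workspace (copying any
attached file into it), builds the instruction, runs the agent via
`opendevin.core.main.main`, extracts the answer from the final agent message
or command thought, scores it with `question_scorer`, and appends the result
to `output.jsonl` in the evaluation output directory. Instances are processed
in parallel with a `ProcessPoolExecutor`, and already-finished task ids are
skipped on re-runs.
"""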
import asyncio
import json
import logging
import multiprocessing as mp
import os
import pathlib
import re
import shutil
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor

import huggingface_hub
from datasets import load_dataset
from tqdm import tqdm

from evaluation.gaia.scorer import question_scorer
from opendevin.controller.state.state import State
from opendevin.core.config import config, get_llm_config_arg, get_parser
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import main
from opendevin.events.action import CmdRunAction, MessageAction
from opendevin.events.serialization.event import event_to_dict
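
# Local cache for the GAIA dataset snapshot; attached question files are later
# read from '<DATASET_CACHE_DIR>/2023/<data_split>/<file_name>'.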
DATASET_CACHE_DIR = '~/.cache/open-devin/evals/gaia'
DATASET_CACHE_DIR = os.path.expanduser(DATASET_CACHE_DIR)


def cleanup():
    logger.info('Cleaning up child processes...')
    for process in mp.active_children():
        logger.info(f'Terminating child process: {process.name}')
        process.terminate()
        process.join()


def codeact_user_response(state: State) -> str:
    msg = (
        'Please continue working on the task on whatever approach you think is suitable.\n'
        'If you think you have solved the task, please first send your answer to user through message and then <execute_bash> exit </execute_bash>.\n'
        'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>.\n'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
    )
    if state.history:
        user_msgs = [
            action
            for action, _ in state.history
            if isinstance(action, MessageAction) and action.source == 'user'
        ]
        if len(user_msgs) >= 2:
            # let the agent know that it can give up when it has tried 3 times
            return (
                msg
                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
            )
    return msg


def monologue_user_response(state: State) -> str:
    raise NotImplementedError('MonologueAgent should never ask for user responses.')
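

# Map agent class -> function used to fabricate the "user" reply whenever the
# agent asks for input, so the evaluation never blocks waiting for a human.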
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
    'MonologueAgent': monologue_user_response,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}


def process_instance(instance, agent_class, metadata, reset_logger: bool = True):
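    """Run a single GAIA instance in its own workspace and return the scored output record."""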
    # create a process-specific workspace dir for EACH process
    # so that different agents don't interfere with each other
    old_workspace_mount_path = config.workspace_mount_path
    try:
        workspace_mount_path = os.path.join(
            config.workspace_mount_path, '_eval_workspace'
        )
        workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
        pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
        config.workspace_mount_path = workspace_mount_path

        # Set up the logger properly, so you can run multiprocessing to parallelize the evaluation
        eval_output_dir = metadata['eval_output_dir']
        if reset_logger:
            # Set up logger
            log_file = os.path.join(
                eval_output_dir, 'logs', f'instance_{instance["task_id"]}.log'
            )
            # Remove all existing handlers from logger
            for handler in logger.handlers[:]:
                logger.removeHandler(handler)
            # add back the console handler to print ONE line
            logger.addHandler(get_console_handler())
            logger.info(
                f'Starting evaluation for instance {instance["task_id"]}.\nLOG: tail -f {log_file}'
            )
            # Remove all existing handlers from logger
            for handler in logger.handlers[:]:
                logger.removeHandler(handler)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            logger.addHandler(file_handler)

        logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')

        if instance['file_name'] != '':
            # if this question comes with a file, we need to save it to the workspace
            src_file = os.path.join(
                DATASET_CACHE_DIR, '2023', metadata['data_split'], instance['file_name']
            )
            extension_name = instance['file_name'].split('.')[-1]
            dest_file = os.path.join(workspace_mount_path, f'file.{extension_name}')
            shutil.copyfile(src_file, dest_file)
            logger.info(f'File copied to {dest_file}')
        else:
            dest_file = None

        # Prepare instruction
        instruction = f"{instance['Question']}\n"
        logger.info(f'Instruction: {instruction}')
        if dest_file:
            instruction += f"\n\nThe mentioned file is provided in the workspace at: {dest_file.split('/')[-1]}"

        instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
        instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
        instruction += (
            'For example: The answer to the question is <solution> 42 </solution>.\n'
        )
        # NOTE: You can actually set slightly different instruction for different agents
        instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
        logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})

        # Here's how you can run the agent (similar to the `main` function) and get the final task state
        state: State = asyncio.run(
            main(
                instruction,
                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                    agent_class
                ),
            )
        )

        # ======= Attempt to evaluate the agent's answer =======
        # If you are working on a simpler benchmark that only evaluates the final model output (e.g., in a MessageAction),
        # you can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
        if state is None:
            raise ValueError('State should not be None.')

        model_answer_raw = ''
        for act, _ in reversed(state.history):
            if isinstance(act, CmdRunAction) and act.source == 'agent':
                model_answer_raw = act.thought
                break
            elif isinstance(act, MessageAction) and act.source == 'agent':
                model_answer_raw = act.content
                break

        # attempt to parse model_answer
        model_answer = re.findall(r'<solution>(.*?)</solution>', model_answer_raw)
        if len(model_answer) == 0:
            logger.warning(f'Failed to parse model answer: {model_answer_raw}')
            model_answer = model_answer_raw
        else:
            model_answer = model_answer[0]

        logger.info(
            f'Final message: {model_answer} | Ground truth: {instance["Final answer"]}'
        )
        score = question_scorer(
            model_answer=model_answer, ground_truth=instance['Final answer']
        )
        test_result = {
            'score': score,
            'model_answer_raw': model_answer_raw,
            'model_answer': model_answer,
            'ground_truth': instance['Final answer'],
        }
        metrics = state.metrics.get() if state.metrics else None

        # Save the output
        output = {
            'instance_id': instance['task_id'],
            'instance': instance,
            'instruction': instance['Question'],
            'metadata': metadata,
            'history': [
                (event_to_dict(action), event_to_dict(obs))
                for action, obs in state.history
            ],
            'metrics': metrics,
            'error': state.error if state and state.error else None,
            'test_result': test_result,
        }
    except Exception:
        logger.error('Process instance failed')
        raise
    finally:
        config.workspace_mount_path = old_workspace_mount_path
    return output
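

# Example invocation (a sketch: `--level` and `--data-split` are defined below; the
# remaining flags come from `get_parser()` and their exact spellings are assumed here,
# as is the `eval_gpt4` llm-config name):
#
#   python evaluation/gaia/run_infer.py \
#       --level 2023_level1 \
#       --data-split validation \
#       --agent-cls CodeActAgent \
#       --llm-config eval_gpt4 \
#       --max-iterations 30 \
#       --eval-num-workers 2 \
#       --eval-n-limit 10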


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--level',
        type=str,
        help='gaia level to evaluate, eg. 2023_level1',
    )
    parser.add_argument(
        '--data-split',
        type=str,
        help='data split to evaluate, eg. validation',
    )
    args, _ = parser.parse_known_args()
    if args.directory:
        config.workspace_base = os.path.abspath(args.directory)
        logger.info(f'Setting workspace base to {config.workspace_base}')

    # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
    # so we don't need to manage file uploading to OpenDevin's repo
    level = args.level
    data_split = args.data_split
    dataset = load_dataset('gaia-benchmark/GAIA', level)
    huggingface_hub.snapshot_download(
        'gaia-benchmark/GAIA',
        repo_type='dataset',
        local_dir=DATASET_CACHE_DIR,
    )
    gaia_tests = dataset[data_split]
    logger.info(f'Evaluating GAIA-Benchmark {level} {data_split} split')

    # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
    # for details of how to set `llm_config`
    if args.llm_config:
        specified_llm_config = get_llm_config_arg(args.llm_config)
        if specified_llm_config:
            config.llm = specified_llm_config
    logger.info(f'Config for evaluation: {config}')

    # TEST METADATA
    agent_class = args.agent_cls
    assert (
        agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
    ), f'Unsupported agent class: {agent_class}'
    model_name = config.llm.model.split('/')[-1]
    max_iterations = args.max_iterations
    eval_note = ''
    if args.eval_note is not None:
        eval_note += '_N_' + args.eval_note
    eval_output_dir = os.path.join(
        args.eval_output_dir,
        'gaia',
        agent_class,
        model_name + '_maxiter_' + str(max_iterations) + eval_note,
    )
    pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
        parents=True, exist_ok=True
    )
    logger.info(f'Using evaluation output directory: {eval_output_dir}')

    metadata = {
        'gaia-level': level,
        'data_split': data_split,
        'agent_class': agent_class,
        'model_name': model_name,
        'max_iterations': max_iterations,
        'eval_output_dir': eval_output_dir,
        'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
        # get the commit id of current repo for reproducibility
        'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        .decode('utf-8')
        .strip(),
    }
    logger.info(f'Metadata: {metadata}')
    with open(os.path.join(eval_output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f)

    # LIMIT EVALUATION
    eval_n_limit = args.eval_n_limit
    if eval_n_limit:
        gaia_tests = gaia_tests.select(list(range(eval_n_limit)))
        logger.info(f'Limiting evaluation to first {eval_n_limit} instances.')

    # OUTPUT FILE
    output_file = os.path.join(eval_output_dir, 'output.jsonl')
    logger.info(f'Writing evaluation output to {output_file}')
    finished_task_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                finished_task_ids.add(data['instance_id'])
        logger.warning(
            f'Output file {output_file} already exists. Loaded {len(finished_task_ids)} finished instances.'
        )
    output_fp = open(output_file, 'a')

    logger.info(
        f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
    )

    # =============================================
    # filter out finished instances
    new_gaia_tests = []
    for instance in gaia_tests:
        if instance['task_id'] in finished_task_ids:
            logger.info(
                f'Skipping instance {instance["task_id"]} as it is already finished.'
            )
            continue
        new_gaia_tests.append(instance)

    gaia_tests = new_gaia_tests
    logger.info(
        f'Finished instances: {len(finished_task_ids)}, Remaining instances: {len(gaia_tests)}'
    )
    # =============================================

    pbar = tqdm(total=len(gaia_tests))

    # This function tracks progress AND writes the output to a JSONL file
    def update_progress(future):
        pbar.update(1)
        output = future.result()
        pbar.set_description(f'Instance {output["instance_id"]}')
        pbar.set_postfix_str(f'Test Result: {output["test_result"]["score"]}')
        logger.info(
            f'Finished evaluation for instance {output["instance_id"]}: {output["test_result"]}'
        )
        output_fp.write(json.dumps(output) + '\n')
        output_fp.flush()
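
    # NOTE: `update_progress` is attached as a done-callback below; callbacks added
    # from the main process run in the main process, so all writes to `output_fp`
    # happen here rather than in the worker processes.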

    # Set up multiprocessing
    num_workers = args.eval_num_workers
    logger.info(f'Using {num_workers} workers for evaluation.')

    try:
        with ProcessPoolExecutor(num_workers) as executor:
            futures = []
            # This is how we perform multiprocessing
            for instance in gaia_tests:
                future = executor.submit(
                    process_instance,
                    instance,
                    agent_class,
                    metadata,
                    reset_logger=bool(num_workers > 1),
                )
                future.add_done_callback(update_progress)
                futures.append(future)

            # Wait for all futures to complete
            for future in futures:
                future.result()
    except KeyboardInterrupt:
        logger.info('KeyboardInterrupt received. Cleaning up...')
        cleanup()

    output_fp.close()
    logger.info('Evaluation finished.')