run_infer.py

import asyncio
import functools
import os
import re

import huggingface_hub
import pandas as pd
from datasets import load_dataset

from evaluation.gaia.scorer import question_scorer
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    codeact_user_response,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from opendevin.controller.state.state import State
from opendevin.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
)
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import create_runtime, run_controller
from opendevin.events.action import AgentFinishAction, CmdRunAction, MessageAction
from opendevin.events.observation import CmdOutputObservation
from opendevin.runtime.runtime import Runtime
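
# Local cache directory for the GAIA dataset snapshot (questions and their attached files).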
DATASET_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'data')
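
# Maps each supported agent class to the function that plays the simulated "user" turn
# (passed to run_controller as fake_user_response_fn).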
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': functools.partial(codeact_user_response, encapsulate_solution=True),
}
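
# Per-agent suffix appended to the instruction, telling the agent how to report its answer and exit.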
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}


def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
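    """Build the AppConfig for a single GAIA evaluation run.

    The agent runs in a python:3.11-bookworm sandbox with auto-lint enabled,
    host networking disabled, and no host workspace mounted, so every instance
    starts from a clean /workspace inside the container.
    """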
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_devin=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            container_image='python:3.11-bookworm',
            enable_auto_lint=True,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config


async def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,
    metadata: EvalMetadata,
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
    obs: CmdOutputObservation

    action = CmdRunAction(command='mkdir -p /workspace')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    assert obs.exit_code == 0

    if instance['file_name'] != '':
        # if this question comes with a file, we need to save it to the workspace
        assert metadata.data_split is not None
        src_file = os.path.join(
            DATASET_CACHE_DIR, '2023', metadata.data_split, instance['file_name']
        )
        assert os.path.exists(src_file)
        dest_file = os.path.join('/workspace', instance['file_name'])
        await runtime.copy_to(src_file, dest_file)

        # rename to file.extension_name
        extension_name = instance['file_name'].split('.')[-1]
        action = CmdRunAction(
            command=f'mv /workspace/{instance["file_name"]} /workspace/file.{extension_name}'
        )
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = await runtime.run_action(action)
        assert obs.exit_code == 0

    action = CmdRunAction(command='cd /workspace')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = await runtime.run_action(action)
    assert obs.exit_code == 0
    logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")


async def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
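    """Run the agent on one GAIA instance and score its final answer.

    Builds the instruction (pointing at any attached file), runs the controller
    loop in a fresh runtime, extracts the answer from the <solution>...</solution>
    tags, and scores it against the ground truth with question_scorer.
    """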
    config = get_config(metadata)

    # Set up the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance['instance_id'], log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance["instance_id"]}.')

    if instance['file_name'] != '':
        extension_name = instance['file_name'].split('.')[-1]
        dest_file = os.path.join('/workspace', f'file.{extension_name}')
    else:
        dest_file = None

    # Prepare instruction
    instruction = f"{instance['Question']}\n"
    logger.info(f'Instruction: {instruction}')
    if dest_file:
        instruction += f"\n\nThe mentioned file is provided in the workspace at: {dest_file.split('/')[-1]}"
    instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
    instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
    instruction += (
        'For example: The answer to the question is <solution> 42 </solution>.\n'
    )
    # NOTE: You can actually set slightly different instructions for different agents
    instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
    logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})

    runtime = await create_runtime(config, sid=instance['instance_id'])
    await initialize_runtime(runtime, instance, metadata)

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    state: State | None = await run_controller(
        config=config,
        task_str=instruction,
        runtime=runtime,
        fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[metadata.agent_class],
    )

    # ======= Attempt to evaluate the agent's edits =======
    # If you are working on a simpler benchmark that only evaluates the final model output (e.g., in a MessageAction),
    # you can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    model_answer_raw = ''
    # get the last message or thought from the agent
    for event in state.history.get_events(reverse=True):
        if event.source == 'agent':
            if isinstance(event, AgentFinishAction):
                model_answer_raw = event.thought
                break
            elif isinstance(event, CmdRunAction):
                model_answer_raw = event.thought
                break
            elif isinstance(event, MessageAction):
                model_answer_raw = event.content
                break

    # attempt to parse model_answer
    model_answer = re.findall(r'<solution>(.*?)</solution>', model_answer_raw)
    if len(model_answer) == 0:
        logger.warning(f'Failed to parse model answer: {model_answer_raw}')
        model_answer = model_answer_raw
    else:
        model_answer = model_answer[0]

    logger.info(
        f'Final message: {model_answer} | Ground truth: {instance["Final answer"]}'
    )
    score = question_scorer(
        model_answer=model_answer, ground_truth=instance['Final answer']
    )
    test_result = {
        'score': score,
        'model_answer_raw': model_answer_raw,
        'model_answer': model_answer,
        'ground_truth': instance['Final answer'],
    }
    metrics = state.metrics.get() if state.metrics else None

    # history is now available as a stream of events, rather than a list of pairs of (Action, Observation)
    # for compatibility with the existing output format, we can remake the pairs here
    # remove when it becomes unnecessary
    histories = state.history.compatibility_for_eval_history_pairs()

    # Save the output
    output = EvalOutput(
        instance_id=instance['instance_id'],
        instance=instance.to_dict(),
        instruction=instance['Question'],
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result=test_result,
    )
    return output


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--level',
        type=str,
        help='GAIA level to evaluate, e.g. 2023_level1',
    )
    parser.add_argument(
        '--data-split',
        type=str,
        help='data split to evaluate, e.g. test',
        default='validation',
    )
    args, _ = parser.parse_known_args()

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config=llm_config,
        dataset_name='gaia',
        agent_class=args.agent_cls,
        max_iterations=args.max_iterations,
        eval_note=args.eval_note,
        eval_output_dir=args.eval_output_dir,
        data_split=args.data_split,
        details={'gaia-level': args.level},
    )
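
    # Download the GAIA metadata and its attached files into the local cache so that
    # initialize_runtime can copy per-question files into the sandbox workspace.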
    dataset = load_dataset('gaia-benchmark/GAIA', args.level)
    huggingface_hub.snapshot_download(
        'gaia-benchmark/GAIA',
        repo_type='dataset',
        local_dir=DATASET_CACHE_DIR,
    )
    gaia_tests = dataset[metadata.data_split].to_pandas()
    gaia_tests.rename(columns={'task_id': 'instance_id'}, inplace=True)

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    prepared_dataset = prepare_dataset(gaia_tests, output_file, args.eval_n_limit)
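
    # Evaluate every prepared instance (optionally across multiple workers), writing
    # one result per instance to output.jsonl via run_evaluation.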
    asyncio.run(
        run_evaluation(
            dataset=prepared_dataset,
            metadata=metadata,
            output_file=output_file,
            num_workers=args.eval_num_workers,
            process_instance_func=process_instance,
        )
    )
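
# Example invocation (a sketch; the agent and iteration flags come from the shared
# get_parser() and may be spelled differently in your checkout):
#   python evaluation/gaia/run_infer.py --llm_config <your-llm-config> \
#       --level 2023_level1 --data-split validation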