# run_infer.py — GAIA benchmark inference runner for OpenHands.
  1. import asyncio
  2. import functools
  3. import os
  4. import re
  5. import huggingface_hub
  6. import pandas as pd
  7. from datasets import load_dataset
  8. from evaluation.gaia.scorer import question_scorer
  9. from evaluation.utils.shared import (
  10. EvalMetadata,
  11. EvalOutput,
  12. codeact_user_response,
  13. make_metadata,
  14. prepare_dataset,
  15. reset_logger_for_multiprocessing,
  16. run_evaluation,
  17. )
  18. from openhands.controller.state.state import State
  19. from openhands.core.config import (
  20. AppConfig,
  21. SandboxConfig,
  22. get_llm_config_arg,
  23. get_parser,
  24. )
  25. from openhands.core.logger import openhands_logger as logger
  26. from openhands.core.main import create_runtime, run_controller
  27. from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
  28. from openhands.events.observation import CmdOutputObservation
  29. from openhands.runtime.runtime import Runtime
# Local directory where the GAIA dataset snapshot (including attachment files)
# is cached by huggingface_hub.snapshot_download in the __main__ block.
DATASET_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'data')

# Maps agent class name -> callable that fakes the user's reply so the agent
# keeps working without human input; encapsulate_solution=True asks for the
# final answer inside <solution>...</solution> tags.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': functools.partial(codeact_user_response, encapsulate_solution=True),
}

# Per-agent suffix appended to the task instruction, telling the agent how to
# terminate once it believes it has the answer.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}
  37. def get_config(
  38. metadata: EvalMetadata,
  39. ) -> AppConfig:
  40. config = AppConfig(
  41. default_agent=metadata.agent_class,
  42. run_as_openhands=False,
  43. runtime='eventstream',
  44. max_iterations=metadata.max_iterations,
  45. sandbox=SandboxConfig(
  46. base_container_image='python:3.11-bookworm',
  47. enable_auto_lint=True,
  48. use_host_network=False,
  49. ),
  50. # do not mount workspace
  51. workspace_base=None,
  52. workspace_mount_path=None,
  53. )
  54. config.set_llm_config(metadata.llm_config)
  55. return config
  56. def initialize_runtime(
  57. runtime: Runtime,
  58. instance: pd.Series, # this argument is not required
  59. ):
  60. """Initialize the runtime for the agent.
  61. This function is called before the runtime is used to run the agent.
  62. """
  63. logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
  64. obs: CmdOutputObservation
  65. action = CmdRunAction(command='mkdir -p /workspace')
  66. logger.info(action, extra={'msg_type': 'ACTION'})
  67. obs = runtime.run_action(action)
  68. assert obs.exit_code == 0
  69. if instance['file_name'] != '':
  70. # if this question comes with a file, we need to save it to the workspace
  71. assert metadata.data_split is not None
  72. src_file = os.path.join(
  73. DATASET_CACHE_DIR, '2023', metadata.data_split, instance['file_name']
  74. )
  75. assert os.path.exists(src_file)
  76. dest_file = os.path.join('/workspace', instance['file_name'])
  77. runtime.copy_to(src_file, dest_file)
  78. # rename to file.extension_name
  79. extension_name = instance['file_name'].split('.')[-1]
  80. action = CmdRunAction(
  81. command=f'mv /workspace/{instance["file_name"]} /workspace/file.{extension_name}'
  82. )
  83. logger.info(action, extra={'msg_type': 'ACTION'})
  84. obs = runtime.run_action(action)
  85. assert obs.exit_code == 0
  86. action = CmdRunAction(command='cd /workspace')
  87. logger.info(action, extra={'msg_type': 'ACTION'})
  88. obs = runtime.run_action(action)
  89. assert obs.exit_code == 0
  90. logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
  91. def process_instance(
  92. instance: pd.Series,
  93. metadata: EvalMetadata,
  94. reset_logger: bool = True,
  95. ) -> EvalOutput:
  96. config = get_config(metadata)
  97. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  98. if reset_logger:
  99. log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
  100. reset_logger_for_multiprocessing(logger, instance['instance_id'], log_dir)
  101. else:
  102. logger.info(f'Starting evaluation for instance {instance["instance_id"]}.')
  103. if instance['file_name'] != '':
  104. extension_name = instance['file_name'].split('.')[-1]
  105. dest_file = os.path.join('/workspace', f'file.{extension_name}')
  106. else:
  107. dest_file = None
  108. # Prepare instruction
  109. instruction = f"{instance['Question']}\n"
  110. logger.info(f'Instruction: {instruction}')
  111. if dest_file:
  112. instruction += f"\n\nThe mentioned file is provided in the workspace at: {dest_file.split('/')[-1]}"
  113. instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  114. instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
  115. instruction += (
  116. 'For example: The answer to the question is <solution> 42 </solution>.\n'
  117. )
  118. # NOTE: You can actually set slightly different instruction for different agents
  119. instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
  120. logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
  121. runtime = create_runtime(config, sid=instance['instance_id'])
  122. initialize_runtime(runtime, instance)
  123. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  124. state: State | None = asyncio.run(
  125. run_controller(
  126. config=config,
  127. task_str=instruction,
  128. runtime=runtime,
  129. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
  130. metadata.agent_class
  131. ],
  132. )
  133. )
  134. # ======= Attempt to evaluate the agent's edits =======
  135. # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  136. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  137. if state is None:
  138. raise ValueError('State should not be None.')
  139. model_answer_raw = ''
  140. # get the last message or thought from the agent
  141. for event in state.history.get_events(reverse=True):
  142. if event.source == 'agent':
  143. if isinstance(event, AgentFinishAction):
  144. model_answer_raw = event.thought
  145. break
  146. elif isinstance(event, CmdRunAction):
  147. model_answer_raw = event.thought
  148. break
  149. elif isinstance(event, MessageAction):
  150. model_answer_raw = event.content
  151. break
  152. # attempt to parse model_answer
  153. model_answer = re.findall(r'<solution>(.*?)</solution>', model_answer_raw)
  154. if len(model_answer) == 0:
  155. logger.warning(f'Failed to parse model answer: {model_answer_raw}')
  156. model_answer = model_answer_raw
  157. else:
  158. model_answer = model_answer[0]
  159. logger.info(
  160. f'Final message: {model_answer} | Ground truth: {instance["Final answer"]}'
  161. )
  162. score = question_scorer(
  163. model_answer=model_answer, ground_truth=instance['Final answer']
  164. )
  165. test_result = {
  166. 'score': score,
  167. 'model_answer_raw': model_answer_raw,
  168. 'model_answer': model_answer,
  169. 'ground_truth': instance['Final answer'],
  170. }
  171. metrics = state.metrics.get() if state.metrics else None
  172. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  173. # for compatibility with the existing output format, we can remake the pairs here
  174. # remove when it becomes unnecessary
  175. histories = state.history.compatibility_for_eval_history_pairs()
  176. # Save the output
  177. output = EvalOutput(
  178. instance_id=instance['instance_id'],
  179. instance=instance.to_dict(),
  180. instruction=instance['Question'],
  181. metadata=metadata,
  182. history=histories,
  183. metrics=metrics,
  184. error=state.last_error if state and state.last_error else None,
  185. test_result=test_result,
  186. )
  187. return output
if __name__ == '__main__':
    # GAIA-specific CLI flags on top of the shared evaluation argument parser.
    parser = get_parser()
    parser.add_argument(
        '--level',
        type=str,
        help='gaia level to evaluate, eg. 2023_level1',
    )
    parser.add_argument(
        '--data-split',
        type=str,
        help='data split to evaluate, eg. test',
        default='validation',
    )
    args, _ = parser.parse_known_args()

    # Resolve the named LLM config; fail fast if the name is unknown.
    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
        if llm_config is None:
            raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    # NOTE(review): `metadata` is also read as an implicit module-level global
    # by initialize_runtime(); do not rename this variable.
    metadata = make_metadata(
        llm_config=llm_config,
        dataset_name='gaia',
        agent_class=args.agent_cls,
        max_iterations=args.max_iterations,
        eval_note=args.eval_note,
        eval_output_dir=args.eval_output_dir,
        data_split=args.data_split,
        details={'gaia-level': args.level},
    )

    # Load the question table for the requested level, and snapshot the whole
    # dataset repo locally so attachment files exist under DATASET_CACHE_DIR
    # before any instance is processed.
    dataset = load_dataset('gaia-benchmark/GAIA', args.level)
    huggingface_hub.snapshot_download(
        'gaia-benchmark/GAIA',
        repo_type='dataset',
        local_dir=DATASET_CACHE_DIR,
    )
    gaia_tests = dataset[metadata.data_split].to_pandas()
    # Normalize the id column to the name the eval harness expects.
    gaia_tests.rename(columns={'task_id': 'instance_id'}, inplace=True)

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    # prepare_dataset skips already-finished instances and applies eval_n_limit.
    prepared_dataset = prepare_dataset(gaia_tests, output_file, args.eval_n_limit)

    run_evaluation(
        dataset=prepared_dataset,
        metadata=metadata,
        output_file=output_file,
        num_workers=args.eval_num_workers,
        process_instance_func=process_instance,
    )