# run_infer.py

import asyncio
import functools
import os
import re

import huggingface_hub
import pandas as pd
from datasets import load_dataset

from evaluation.benchmarks.gaia.scorer import question_scorer
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    codeact_user_response,
    compatibility_for_eval_history_pairs,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync

DATASET_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'data')

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': functools.partial(codeact_user_response, encapsulate_solution=True),
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}


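# Build the per-run OpenHands AppConfig: an isolated python:3.12 sandbox with
# auto-lint enabled, no host networking, and no mounted workspace, using the
# LLM settings recorded in the evaluation metadata.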
def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            base_container_image='python:3.12-bookworm',
            enable_auto_lint=True,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config


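# Prepare the sandbox before the agent starts: create /workspace and, if the
# GAIA instance ships with an attachment, copy it in and rename it to
# file.<extension> so the instruction can refer to it by a stable name.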
def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
    obs: CmdOutputObservation

    action = CmdRunAction(command='mkdir -p /workspace')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    if instance['file_name'] != '':
        # if this question comes with a file, we need to save it to the workspace
        # NOTE: `metadata` here is the module-level variable set in the `__main__` block below
        assert metadata.data_split is not None
        src_file = os.path.join(
            DATASET_CACHE_DIR, '2023', metadata.data_split, instance['file_name']
        )
        assert os.path.exists(src_file)
        dest_file = os.path.join('/workspace', instance['file_name'])
        runtime.copy_to(src_file, dest_file)

        # rename to file.extension_name
        extension_name = instance['file_name'].split('.')[-1]
        action = CmdRunAction(
            command=f'mv /workspace/{instance["file_name"]} /workspace/file.{extension_name}'
        )
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        assert obs.exit_code == 0

    action = CmdRunAction(command='cd /workspace')
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    assert obs.exit_code == 0

    logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")


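# Run one GAIA instance end to end: build the instruction, let the agent work
# inside the sandbox, extract the <solution>...</solution> answer from its
# final message, and score it against the ground truth.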
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    config = get_config(metadata)

    # Set up the logger properly, so you can run multiprocessing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance['instance_id'], log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance["instance_id"]}.')

    if instance['file_name'] != '':
        extension_name = instance['file_name'].split('.')[-1]
        dest_file = os.path.join('/workspace', f'file.{extension_name}')
    else:
        dest_file = None

    # Prepare instruction
    instruction = f"{instance['Question']}\n"
    logger.info(f'Instruction: {instruction}')
    if dest_file:
        instruction += f"\n\nThe mentioned file is provided in the workspace at: {dest_file.split('/')[-1]}"

    instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
    instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
    instruction += (
        'For example: The answer to the question is <solution> 42 </solution>.\n'
    )
    # NOTE: You can actually set slightly different instruction for different agents
    instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
    logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
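
    # Create and connect the sandbox runtime, then stage the instance's
    # attachment (if any) via initialize_runtime before handing control to the agent.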
    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)
    initialize_runtime(runtime, instance)

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            initial_user_action=MessageAction(content=instruction),
            runtime=runtime,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                metadata.agent_class
            ],
        )
    )
    # ======= Attempt to evaluate the agent's edits =======
    # If you are working on a simpler benchmark that only evaluates the final model output (e.g., in a MessageAction),
    # you can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    model_answer_raw = ''
    # get the last message or thought from the agent
    for event in reversed(state.history):
        if event.source == 'agent':
            if isinstance(event, AgentFinishAction):
                model_answer_raw = event.thought
                break
            elif isinstance(event, CmdRunAction):
                model_answer_raw = event.thought
                break
            elif isinstance(event, MessageAction):
                model_answer_raw = event.content
                break

    # attempt to parse model_answer
    model_answer = re.findall(r'<solution>(.*?)</solution>', model_answer_raw)
    if len(model_answer) == 0:
        logger.warning(f'Failed to parse model answer: {model_answer_raw}')
        model_answer = model_answer_raw
    else:
        model_answer = model_answer[0]

    logger.info(
        f'Final message: {model_answer} | Ground truth: {instance["Final answer"]}'
    )
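
    # question_scorer applies the GAIA benchmark's answer-matching rules
    # (a quasi-exact match of the model answer against the ground truth).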
    score = question_scorer(
        model_answer=model_answer, ground_truth=instance['Final answer']
    )
    test_result = {
        'score': score,
        'model_answer_raw': model_answer_raw,
        'model_answer': model_answer,
        'ground_truth': instance['Final answer'],
    }
    metrics = state.metrics.get() if state.metrics else None

    # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
    # for compatibility with the existing output format, we can remake the pairs here
    # remove when it becomes unnecessary
    histories = compatibility_for_eval_history_pairs(state.history)

    # Save the output
    output = EvalOutput(
        instance_id=instance['instance_id'],
        instance=instance.to_dict(),
        instruction=instance['Question'],
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result=test_result,
    )
    return output


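# CLI entry point: parse evaluation arguments, download the requested GAIA
# level/split, and run process_instance over every test via run_evaluation.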
if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--level',
        type=str,
        help='gaia level to evaluate, e.g. 2023_level1',
    )
    parser.add_argument(
        '--data-split',
        type=str,
        help='data split to evaluate, e.g. test',
        default='validation',
    )
    args, _ = parser.parse_known_args()

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config=llm_config,
        dataset_name='gaia',
        agent_class=args.agent_cls,
        max_iterations=args.max_iterations,
        eval_note=args.eval_note,
        eval_output_dir=args.eval_output_dir,
        data_split=args.data_split,
        details={'gaia-level': args.level},
    )

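    # load_dataset pulls the question metadata for the requested GAIA level, while
    # snapshot_download mirrors the dataset files (including attachments) into
    # DATASET_CACHE_DIR so initialize_runtime can copy them into each sandbox.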
    dataset = load_dataset('gaia-benchmark/GAIA', args.level)
    huggingface_hub.snapshot_download(
        'gaia-benchmark/GAIA',
        repo_type='dataset',
        local_dir=DATASET_CACHE_DIR,
    )
    gaia_tests = dataset[metadata.data_split].to_pandas()
    gaia_tests.rename(columns={'task_id': 'instance_id'}, inplace=True)

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    prepared_dataset = prepare_dataset(gaia_tests, output_file, args.eval_n_limit)

    run_evaluation(
        dataset=prepared_dataset,
        metadata=metadata,
        output_file=output_file,
        num_workers=args.eval_num_workers,
        process_instance_func=process_instance,
    )