run_infer.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import asyncio
  2. import json
  3. import os
  4. import pandas as pd
  5. import requests
  6. from evaluation.benchmarks.gorilla.utils import encode_question, get_data_for_hub
  7. from evaluation.utils.shared import (
  8. EvalMetadata,
  9. EvalOutput,
  10. codeact_user_response,
  11. compatibility_for_eval_history_pairs,
  12. make_metadata,
  13. prepare_dataset,
  14. reset_logger_for_multiprocessing,
  15. run_evaluation,
  16. )
  17. from openhands.controller.state.state import State
  18. from openhands.core.config import (
  19. AppConfig,
  20. SandboxConfig,
  21. get_llm_config_arg,
  22. get_parser,
  23. )
  24. from openhands.core.logger import openhands_logger as logger
  25. from openhands.core.main import create_runtime, run_controller
  26. from openhands.events.action import MessageAction
  27. from openhands.utils.async_utils import call_async_from_sync
# Maps an agent class name to the function that fabricates the "user" reply in
# multi-turn runs, so the evaluation can proceed unattended.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

# Per-agent suffix appended to every task instruction; tells the agent how to
# signal completion so the controller can stop the episode.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the request, please finish the interaction using the "finish" tool.\n'
}
  34. def get_config(
  35. metadata: EvalMetadata,
  36. ) -> AppConfig:
  37. config = AppConfig(
  38. default_agent=metadata.agent_class,
  39. run_as_openhands=False,
  40. runtime='eventstream',
  41. max_iterations=metadata.max_iterations,
  42. sandbox=SandboxConfig(
  43. base_container_image='python:3.12-bookworm',
  44. enable_auto_lint=True,
  45. use_host_network=False,
  46. ),
  47. # do not mount workspace
  48. workspace_base=None,
  49. workspace_mount_path=None,
  50. )
  51. config.set_llm_config(metadata.llm_config)
  52. return config
  53. def process_instance(
  54. instance: pd.Series,
  55. metadata: EvalMetadata,
  56. reset_logger: bool = True,
  57. ) -> EvalOutput:
  58. config = get_config(metadata)
  59. instance_id = instance['question_id']
  60. question = instance['question']
  61. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  62. if reset_logger:
  63. log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
  64. reset_logger_for_multiprocessing(logger, instance_id, log_dir)
  65. else:
  66. logger.info(f'Starting evaluation for instance {instance_id}.')
  67. # Prepare instruction
  68. instruction = encode_question(question, instance['hub'])
  69. instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  70. # NOTE: You can actually set slightly different instruction for different agents
  71. instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
  72. # logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
  73. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  74. runtime = create_runtime(config)
  75. call_async_from_sync(runtime.connect)
  76. state: State | None = asyncio.run(
  77. run_controller(
  78. config=config,
  79. initial_user_action=MessageAction(content=instruction),
  80. runtime=runtime,
  81. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
  82. metadata.agent_class
  83. ),
  84. )
  85. )
  86. # ======= Attempt to evaluate the agent's edits =======
  87. # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  88. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  89. if state is None:
  90. raise ValueError('State should not be None.')
  91. # retrieve the last message from the agent
  92. last_agent_message = state.get_last_agent_message()
  93. model_answer_raw = last_agent_message.content if last_agent_message else ''
  94. # attempt to parse model_answer
  95. ast_eval_fn = instance['ast_eval']
  96. correct, hallucination = ast_eval_fn(instance_id, model_answer_raw)
  97. metrics = state.metrics.get() if state.metrics else None
  98. logger.info(
  99. f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
  100. )
  101. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  102. # for compatibility with the existing output format, we can remake the pairs here
  103. # remove when it becomes unnecessary
  104. histories = compatibility_for_eval_history_pairs(state.history)
  105. output = EvalOutput(
  106. instance_id=instance_id,
  107. metadata=metadata,
  108. history=histories,
  109. metrics=metrics,
  110. error=state.last_error if state and state.last_error else None,
  111. test_result={
  112. 'text': model_answer_raw,
  113. 'correct': correct,
  114. 'hallucination': hallucination,
  115. },
  116. )
  117. return output
  118. if __name__ == '__main__':
  119. parser = get_parser()
  120. parser.add_argument(
  121. '--hubs',
  122. type=str,
  123. help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, separated by commas. For example, the default is --hub hf,torch,tf.',
  124. default='hf,torch,tf',
  125. )
  126. args, _ = parser.parse_known_args()
  127. llm_config = None
  128. if args.llm_config:
  129. llm_config = get_llm_config_arg(args.llm_config)
  130. if llm_config is None:
  131. raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
  132. hubs = args.hubs.split(',')
  133. if len(hubs) == 0:
  134. raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')
  135. dfs = []
  136. for hub in hubs:
  137. logger.info(f'Evaluating APIBench {hub} test')
  138. df = get_data_for_hub(hub)
  139. dfs.append(df)
  140. dataset_df = pd.concat(dfs)
  141. dataset_df.rename(columns={'question_id': 'instance_id'}, inplace=True)
  142. metadata = make_metadata(
  143. llm_config=llm_config,
  144. dataset_name=f'gorilla-{hub}',
  145. agent_class=args.agent_cls,
  146. max_iterations=args.max_iterations,
  147. eval_note=args.eval_note,
  148. eval_output_dir=args.eval_output_dir,
  149. data_split=args.data_split,
  150. )
  151. output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
  152. dataset = prepare_dataset(
  153. dataset_df, output_file=output_file, eval_n_limit=args.eval_n_limit
  154. )
  155. file_path = os.path.join(os.path.dirname(__file__), 'my-languages.so')
  156. # Check if the file exists
  157. if not os.path.exists(file_path):
  158. url = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-scripts/codebleu/parser/my-languages.so'
  159. response = requests.get(url)
  160. with open(file_path, 'wb') as f:
  161. f.write(response.content)
  162. else:
  163. print('File already exists, skipping download.')
  164. run_evaluation(
  165. dataset=dataset,
  166. metadata=metadata,
  167. output_file=output_file,
  168. num_workers=args.eval_num_workers,
  169. process_instance_func=process_instance,
  170. )
  171. # Read the output file and calculate the accuracy
  172. total_correct = 0
  173. total_hallucination = 0
  174. output = []
  175. with open(output_file, 'r') as f:
  176. for line in f:
  177. data = json.loads(line)
  178. if data['test_result']['correct']:
  179. total_correct += 1
  180. if data['test_result']['hallucination']:
  181. total_hallucination += 1
  182. output.append(data)
  183. logger.info(
  184. f'Evaluation finished for {hub}. Total: {len(output)}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {total_correct / len(output)}'
  185. )