# run_infer.py — APIBench (Gorilla) evaluation driver for OpenHands.
  1. import asyncio
  2. import json
  3. import os
  4. import pandas as pd
  5. import requests
  6. from evaluation.gorilla.utils import encode_question, get_data_for_hub
  7. from evaluation.utils.shared import (
  8. EvalMetadata,
  9. EvalOutput,
  10. codeact_user_response,
  11. make_metadata,
  12. prepare_dataset,
  13. reset_logger_for_multiprocessing,
  14. run_evaluation,
  15. )
  16. from openhands.controller.state.state import State
  17. from openhands.core.config import (
  18. AppConfig,
  19. SandboxConfig,
  20. get_llm_config_arg,
  21. get_parser,
  22. )
  23. from openhands.core.logger import openhands_logger as logger
  24. from openhands.core.main import create_runtime, run_controller
  25. from openhands.events.action import MessageAction
# Maps agent class name -> callable that fabricates the "user" reply whenever
# the agent asks for input, so evaluation runs without a human in the loop.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}
# Per-agent suffix appended to the task instruction, telling the agent how to
# signal that it considers the task complete.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have completed the request, please run the following command: <execute_bash> exit </execute_bash>.\n'
}
  32. def get_config(
  33. metadata: EvalMetadata,
  34. ) -> AppConfig:
  35. config = AppConfig(
  36. default_agent=metadata.agent_class,
  37. run_as_openhands=False,
  38. runtime='eventstream',
  39. max_iterations=metadata.max_iterations,
  40. sandbox=SandboxConfig(
  41. base_container_image='python:3.12-bookworm',
  42. enable_auto_lint=True,
  43. use_host_network=False,
  44. ),
  45. # do not mount workspace
  46. workspace_base=None,
  47. workspace_mount_path=None,
  48. )
  49. config.set_llm_config(metadata.llm_config)
  50. return config
  51. def process_instance(
  52. instance: pd.Series,
  53. metadata: EvalMetadata,
  54. reset_logger: bool = True,
  55. ) -> EvalOutput:
  56. config = get_config(metadata)
  57. instance_id = instance['question_id']
  58. question = instance['question']
  59. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  60. if reset_logger:
  61. log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
  62. reset_logger_for_multiprocessing(logger, instance_id, log_dir)
  63. else:
  64. logger.info(f'Starting evaluation for instance {instance_id}.')
  65. # Prepare instruction
  66. instruction = encode_question(question, instance['hub'])
  67. instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  68. # NOTE: You can actually set slightly different instruction for different agents
  69. instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
  70. # logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
  71. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  72. runtime = create_runtime(config)
  73. state: State | None = asyncio.run(
  74. run_controller(
  75. config=config,
  76. initial_user_action=MessageAction(content=instruction),
  77. runtime=runtime,
  78. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
  79. metadata.agent_class
  80. ),
  81. )
  82. )
  83. # ======= Attempt to evaluate the agent's edits =======
  84. # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  85. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  86. if state is None:
  87. raise ValueError('State should not be None.')
  88. # retrieve the last message from the agent
  89. model_answer_raw = state.history.get_last_agent_message()
  90. # attempt to parse model_answer
  91. ast_eval_fn = instance['ast_eval']
  92. correct, hallucination = ast_eval_fn(instance_id, model_answer_raw)
  93. metrics = state.metrics.get() if state.metrics else None
  94. logger.info(
  95. f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
  96. )
  97. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  98. # for compatibility with the existing output format, we can remake the pairs here
  99. # remove when it becomes unnecessary
  100. histories = state.history.compatibility_for_eval_history_pairs()
  101. output = EvalOutput(
  102. instance_id=instance_id,
  103. metadata=metadata,
  104. history=histories,
  105. metrics=metrics,
  106. error=state.last_error if state and state.last_error else None,
  107. test_result={
  108. 'text': model_answer_raw,
  109. 'correct': correct,
  110. 'hallucination': hallucination,
  111. },
  112. )
  113. return output
  114. if __name__ == '__main__':
  115. parser = get_parser()
  116. parser.add_argument(
  117. '--hubs',
  118. type=str,
  119. help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, separated by commas. For example, the default is --hub hf,torch,tf.',
  120. default='hf,torch,tf',
  121. )
  122. args, _ = parser.parse_known_args()
  123. llm_config = None
  124. if args.llm_config:
  125. llm_config = get_llm_config_arg(args.llm_config)
  126. if llm_config is None:
  127. raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
  128. hubs = args.hubs.split(',')
  129. if len(hubs) == 0:
  130. raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')
  131. dfs = []
  132. for hub in hubs:
  133. logger.info(f'Evaluating APIBench {hub} test')
  134. df = get_data_for_hub(hub)
  135. dfs.append(df)
  136. dataset_df = pd.concat(dfs)
  137. dataset_df.rename(columns={'question_id': 'instance_id'}, inplace=True)
  138. metadata = make_metadata(
  139. llm_config=llm_config,
  140. dataset_name=f'gorilla-{hub}',
  141. agent_class=args.agent_cls,
  142. max_iterations=args.max_iterations,
  143. eval_note=args.eval_note,
  144. eval_output_dir=args.eval_output_dir,
  145. data_split=args.data_split,
  146. )
  147. output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
  148. dataset = prepare_dataset(
  149. dataset_df, output_file=output_file, eval_n_limit=args.eval_n_limit
  150. )
  151. file_path = os.path.join(os.path.dirname(__file__), 'my-languages.so')
  152. # Check if the file exists
  153. if not os.path.exists(file_path):
  154. url = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-scripts/codebleu/parser/my-languages.so'
  155. response = requests.get(url)
  156. with open(file_path, 'wb') as f:
  157. f.write(response.content)
  158. else:
  159. print('File already exists, skipping download.')
  160. run_evaluation(
  161. dataset=dataset,
  162. metadata=metadata,
  163. output_file=output_file,
  164. num_workers=args.eval_num_workers,
  165. process_instance_func=process_instance,
  166. )
  167. # Read the output file and calculate the accuracy
  168. total_correct = 0
  169. total_hallucination = 0
  170. output = []
  171. with open(output_file, 'r') as f:
  172. for line in f:
  173. data = json.loads(line)
  174. if data['test_result']['correct']:
  175. total_correct += 1
  176. if data['test_result']['hallucination']:
  177. total_hallucination += 1
  178. output.append(data)
  179. logger.info(
  180. f'Evaluation finished for {hub}. Total: {len(output)}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {total_correct / len(output)}'
  181. )