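"""run_infer.py: Evaluate an OpenHands agent on the EDA (Entity Deduction Arena)
benchmark. The agent plays the guesser in a 20-questions style game over things or
celebrities, while an OpenAI answerer model replies to its guesses through the
fake-user-response callback."""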
import asyncio
import os

import pandas as pd
from datasets import load_dataset

from evaluation.EDA.game import Q20Game, Q20GameCelebrity
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller

game = None
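# Module-level game state: process_instance() builds a fresh Q20Game /
# Q20GameCelebrity for each instance, and codeact_user_response_eda() below reads
# it to turn the agent's latest guess into the answerer's reply.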
def codeact_user_response_eda(state: State) -> str:
    global game
    model_guess = ''

    # retrieve the latest model message from history
    if state.history:
        model_guess = state.history.get_last_agent_message()

    assert game is not None, 'Game is not initialized.'
    msg = game.generate_user_response(model_guess)
    game.curr_turn += 1
    logger.info(f'Model guess: {model_guess}')
    logger.info(f'Answer response: {msg}')
    if 'bingo!' in msg.lower():
        return '/exit'
    return msg


AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response_eda,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}
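# The instruction suffix above is appended to the game's opening prompt in
# process_instance(), telling the agent to report its final answer as a message
# and then exit.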
def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            base_container_image='python:3.12-bookworm',
            enable_auto_lint=False,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    config = get_config(metadata)
    instance_id = instance['text'].strip()

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance_id}.')

    # Prepare instruction
    _game_class = {'eda-things': Q20Game, 'eda-celebs': Q20GameCelebrity}

    guesser_kargs = {
        'max_new_tokens': 64,
        'temperature': 0.8,
        'repetition_penalty': 1.0,  # no penalty
        'do_sample': True,
    }

    # Use CodeActAgent as guesser_model
    global game
    assert metadata.dataset is not None
    assert metadata.details is not None
    game = _game_class[metadata.dataset](
        item=instance['text'].strip(),
        answerer_model=metadata.details['answerer_model'],
        guesser_model=None,
        num_turns=metadata.max_iterations,
        openai_api_key=metadata.details['openai_api_key'],
        guesser_kargs=guesser_kargs,
    )

    instruction = f'{game.first_user_utterance}'
    logger.info(f'Instruction: {instruction}')
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    runtime = create_runtime(config, sid=instance['text'].strip())
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            task_str=instruction,
            runtime=runtime,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                metadata.agent_class
            ],
        )
    )

    # ======= Attempt to evaluate the agent's edits =======
    # If you are working on a simpler benchmark that only evaluates the final model output
    # (e.g., in a MessageAction), you can simply get the LAST `MessageAction` from the
    # returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    final_message = state.history.get_last_agent_message()
    logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
    test_result = game.reward()
    metrics = state.metrics.get() if state.metrics else None

    # History is now available as a stream of events, rather than a list of pairs of
    # (Action, Observation). For compatibility with the existing output format, we can
    # remake the pairs here; remove when it becomes unnecessary.
    histories = state.history.compatibility_for_eval_history_pairs()

    # Save the output
    output = EvalOutput(
        instance_id=instance_id,
        instance=instance.to_dict(),
        instruction=instruction,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result={
            'success': test_result,
            'final_message': final_message,
            'ground_truth': instance['text'],
        },
    )
    return output
if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--answerer_model', '-a', default='gpt-3.5-turbo', help='answerer model'
    )
    parser.add_argument(
        '--dataset',
        default='things',
        choices=['things', 'celebs'],
        type=str,
        help='dataset to be used',
    )
    parser.add_argument(
        '--OPENAI_API_KEY', type=str, required=True, help='Your OpenAI API key'
    )
    parser.add_argument(
        '--data-split',
        default='test',
        type=str,
        help='data split, e.g., test',
    )
    args, _ = parser.parse_known_args()
    eda_dataset = load_dataset(
        'yizheapple/entity-deduction-arena', name=args.dataset, split=args.data_split
    )
    # `datasets.Dataset` has no pandas-style `rename`; build the DataFrame here and
    # add the `instance_id` column the eval utilities expect, keeping the original
    # `text` column that process_instance() reads.
    eda_dataset_df = eda_dataset.to_pandas()
    eda_dataset_df['instance_id'] = eda_dataset_df['text'].str.strip()
    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
        if llm_config is None:
            raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        f'eda-{args.dataset}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        data_split=args.data_split,
        details={
            'answerer_model': str(args.answerer_model),
            'openai_api_key': str(args.OPENAI_API_KEY),
        },
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    prepared_dataset = prepare_dataset(
        eda_dataset_df, output_file, args.eval_n_limit
    )
    run_evaluation(
        prepared_dataset,
        metadata,
        output_file,
        args.eval_num_workers,
        process_instance,
    )
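# Example invocation (a sketch, not a verified command line: the flags defined in
# this file are --answerer_model, --dataset, --OPENAI_API_KEY, and --data-split;
# the remaining flag spellings come from openhands.core.config.get_parser and are
# assumptions here):
#
#   python evaluation/EDA/run_infer.py \
#       --agent-cls CodeActAgent \
#       --llm-config llm \
#       --dataset things \
#       --data-split test \
#       --answerer_model gpt-3.5-turbo \
#       --OPENAI_API_KEY <your-openai-key> \
#       --eval-n-limit 10 \
#       --eval-num-workers 1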