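"""Run the Entity Deduction Arena (EDA) benchmark: 20-questions style guessing games
in which an OpenHands agent plays the guesser against an answerer model.
"""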

import asyncio
import os

import pandas as pd
from datasets import load_dataset

from evaluation.EDA.game import Q20Game, Q20GameCelebrity
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    compatibility_for_eval_history_pairs,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)

from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
from openhands.utils.async_utils import call_async_from_sync
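
# A single Q20 game object is shared at module level: process_instance() creates it and
# the fake-user callback below advances it between agent turns.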
game = None


def codeact_user_response_eda(state: State) -> str:
    global game

    model_guess = ''

    # retrieve the latest model message from history
    if state.history:
        model_guess = state.get_last_agent_message()

    assert game is not None, 'Game is not initialized.'
    msg = game.generate_user_response(model_guess)
    game.curr_turn += 1
    logger.info(f'Model guess: {model_guess}')
    logger.info(f'Answer response: {msg}')
    if 'bingo!' in msg.lower():
        return '/exit'
    return msg
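

# Per-agent-class hooks: the fake-user function stands in for the human answerer, and the
# suffix is appended to the initial instruction so the agent reports its answer and exits.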
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response_eda,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}
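

# Per-instance app config: a plain Python sandbox container with no workspace mounted.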
def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            base_container_image='python:3.12-bookworm',
            enable_auto_lint=False,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
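    """Play one EDA game with the given instance as the hidden entity and score the result."""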
    config = get_config(metadata)
    instance_id = instance['text'].strip()

    # Set up the logger properly, so multiprocessing workers write per-instance log files
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance_id}.')

    # Prepare instruction
    _game_class = {'eda-things': Q20Game, 'eda-celebs': Q20GameCelebrity}

    guesser_kargs = {
        'max_new_tokens': 64,
        'temperature': 0.8,
        'repetition_penalty': 1.0,
        'do_sample': True,
    }  # no repetition penalty

    # The evaluated OpenHands agent plays the guesser, so no separate guesser model is
    # passed to the game; only the answerer model is queried through it.
    global game
    assert metadata.dataset is not None
    assert metadata.details is not None
    game = _game_class[metadata.dataset](
        item=instance['text'].strip(),
        answerer_model=metadata.details['answerer_model'],
        guesser_model=None,
        num_turns=metadata.max_iterations,
        openai_api_key=metadata.details['openai_api_key'],
        guesser_kargs=guesser_kargs,
    )

    instruction = f'{game.first_user_utterance}'
    logger.info(f'Instruction: {instruction}')
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            initial_user_action=MessageAction(content=instruction),
            runtime=runtime,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                metadata.agent_class
            ],
        )
    )
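
    # run_controller alternates agent turns with codeact_user_response_eda, which queries
    # the answerer model and returns '/exit' once the answer is confirmed with 'Bingo!'.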

    # ======= Attempt to evaluate the agent's edits =======
    # If you are working on a simpler benchmark that only evaluates the final model output
    # (e.g., in a MessageAction), you can simply take the LAST `MessageAction` from the
    # returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    final_message = state.get_last_agent_message()

    logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
    test_result = game.reward()
    metrics = state.metrics.get() if state.metrics else None

    # History is now available as a stream of events, rather than a list of
    # (Action, Observation) pairs; remake the pairs here for compatibility with the
    # existing output format. Remove this once it becomes unnecessary.
    histories = compatibility_for_eval_history_pairs(state.history)

    # Save the output
    output = EvalOutput(
        instance_id=instance_id,
        instance=instance.to_dict(),
        instruction=instruction,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result={
            'success': test_result,
            'final_message': final_message,
            'ground_truth': instance['text'],
        },
    )
    return output


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--answerer_model', '-a', default='gpt-3.5-turbo', help='answerer model'
    )
    parser.add_argument(
        '--dataset',
        default='things',
        choices=['things', 'celebs'],
        type=str,
        help='dataset to be used',
    )
    parser.add_argument(
        '--OPENAI_API_KEY', type=str, required=True, help='Your OpenAI API key'
    )
    parser.add_argument(
        '--data-split',
        default='test',
        type=str,
        help='data split, e.g., test',
    )
    args, _ = parser.parse_known_args()
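
    # Example invocation (the EDA-specific flags are defined above; the remaining flags,
    # e.g. --agent-cls and --llm-config, come from OpenHands' shared evaluation parser and
    # may differ by version):
    #   python -m evaluation.EDA.run_infer \
    #       --agent-cls CodeActAgent --llm-config <config name> \
    #       --dataset things --data-split test --OPENAI_API_KEY <key>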

    eda_dataset = load_dataset(
        'yizheapple/entity-deduction-arena', name=args.dataset, split=args.data_split
    )
    # The harness keys instances on 'instance_id', while process_instance reads 'text';
    # add the id column to a pandas copy instead of renaming 'text' away.
    eda_dataset_df = eda_dataset.to_pandas()
    eda_dataset_df['instance_id'] = eda_dataset_df['text'].str.strip()

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)

    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        f'eda-{args.dataset}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        data_split=args.data_split,
        details={
            'answerer_model': str(args.answerer_model),
            'openai_api_key': str(args.OPENAI_API_KEY),
        },
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')

    prepared_dataset = prepare_dataset(
        eda_dataset_df, output_file, args.eval_n_limit
    )

    run_evaluation(
        prepared_dataset,
        metadata,
        output_file,
        args.eval_num_workers,
        process_instance,
    )