run_infer.py

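"""Inference driver for the EDA (Entity-Deduction Arena) benchmark.

Each instance is a 20-questions-style game: the evaluated agent plays the
guesser, while an answerer model (``gpt-3.5-turbo`` by default) replies to its
questions through `Q20Game` / `Q20GameCelebrity`. The answerer's replies are
fed back into the agent loop via `fake_user_response_fn`, and an episode ends
when the answerer says 'bingo!' or the run hits `max_iterations`.
"""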
import asyncio
import os

import pandas as pd
from datasets import load_dataset

from evaluation.benchmarks.EDA.game import Q20Game, Q20GameCelebrity
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    compatibility_for_eval_history_pairs,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
from openhands.utils.async_utils import call_async_from_sync
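
# Module-level handle to the current Q20 game: `process_instance` creates it,
# and the fake-user callback below reads it to produce the answerer's reply.
# With process-based parallelism, each worker gets its own copy of this global.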
game = None


def codeact_user_response_eda(state: State) -> str:
    """Simulate the user's turn: feed the agent's latest guess to the answerer.

    The agent's most recent message is treated as its question/guess and passed
    to `game.generate_user_response`, which produces the answerer's reply. A
    reply containing 'bingo!' signals a correct guess, so '/exit' is returned
    to end the episode.
    """
    global game

    model_guess = ''

    # retrieve the latest model message from history
    if state.history:
        last_agent_message = state.get_last_agent_message()
        model_guess = last_agent_message.content if last_agent_message else ''

    assert game is not None, 'Game is not initialized.'
    msg = game.generate_user_response(model_guess)
    game.curr_turn += 1
    logger.info(f'Model guess: {model_guess}')
    logger.info(f'Answer response: {msg}')
    if 'bingo!' in msg.lower():
        return '/exit'
    return msg


AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response_eda,
}

AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}


def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            base_container_image='python:3.12-bookworm',
            enable_auto_lint=False,
            use_host_network=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    """Run a single EDA instance end-to-end and package the result as an `EvalOutput`."""
    config = get_config(metadata)
    instance_id = instance['text'].strip()

    # Set up the logger properly, so you can run multiprocessing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance_id}.')

    # Prepare the instruction
    _game_class = {'eda-things': Q20Game, 'eda-celebs': Q20GameCelebrity}

    guesser_kargs = {
        'max_new_tokens': 64,
        'temperature': 0.8,
        'repetition_penalty': 1.0,  # no penalty
        'do_sample': True,
    }

    # Use CodeActAgent as the guesser model
    global game
    assert metadata.dataset is not None
    assert metadata.details is not None
    game = _game_class[metadata.dataset](
        item=instance['text'].strip(),
        answerer_model=metadata.details['answerer_model'],
        guesser_model=None,
        num_turns=metadata.max_iterations,
        openai_api_key=metadata.details['openai_api_key'],
        guesser_kargs=guesser_kargs,
    )

    instruction = f'{game.first_user_utterance}'
    logger.info(f'Instruction: {instruction}')
    instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]

    # Run the agent (similar to the `main` function) and get the final task state
    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            initial_user_action=MessageAction(content=instruction),
            runtime=runtime,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                metadata.agent_class
            ],
        )
    )

    # ======= Evaluate the agent's final answer =======
    # For a simpler benchmark that only evaluates the final model output (e.g., a
    # MessageAction), you can simply take the LAST `MessageAction` from the returned
    # `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    last_agent_message = state.get_last_agent_message()
    final_message = last_agent_message.content if last_agent_message else ''

    logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
    test_result = game.reward()
    metrics = state.metrics.get() if state.metrics else None

    # History is now available as a stream of events, rather than a list of
    # (Action, Observation) pairs. For compatibility with the existing output
    # format, remake the pairs here; remove once it becomes unnecessary.
    histories = compatibility_for_eval_history_pairs(state.history)

    # Save the output
    output = EvalOutput(
        instance_id=instance_id,
        instance=instance.to_dict(),
        instruction=instruction,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result={
            'success': test_result,
            'final_message': final_message,
            'ground_truth': instance['text'],
        },
    )
    return output


if __name__ == '__main__':
    parser = get_parser()
    parser.add_argument(
        '--answerer_model', '-a', default='gpt-3.5-turbo', help='answerer model'
    )
    parser.add_argument(
        '--dataset',
        default='things',
        choices=['things', 'celebs'],
        type=str,
        help='dataset to be used',
    )
    parser.add_argument(
        '--OPENAI_API_KEY', type=str, required=True, help='Your OpenAI API key'
    )
    parser.add_argument(
        '--data-split',
        default='test',
        type=str,
        help='data split, e.g., test',
    )
    args, _ = parser.parse_known_args()

    eda_dataset = load_dataset(
        'yizheapple/entity-deduction-arena', name=args.dataset, split=args.data_split
    ).to_pandas()
    # The shared eval utilities key instances by `instance_id`, while
    # `process_instance` reads the ground-truth entity from `text`. A HuggingFace
    # `Dataset` has no pandas-style `rename`, so convert to a DataFrame first and
    # copy the column (stripped, matching `process_instance`) instead of renaming
    # it away.
    eda_dataset['instance_id'] = eda_dataset['text'].str.strip()

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
    metadata = make_metadata(
        llm_config,
        f'eda-{args.dataset}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        data_split=args.data_split,
        details={
            'answerer_model': str(args.answerer_model),
            'openai_api_key': str(args.OPENAI_API_KEY),
        },
    )

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    prepared_dataset = prepare_dataset(eda_dataset, output_file, args.eval_n_limit)

    run_evaluation(
        prepared_dataset,
        metadata,
        output_file,
        args.eval_num_workers,
        process_instance,
    )
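
# Illustrative invocation (a sketch, not the canonical command): the flags
# specific to this script (--answerer_model, --dataset, --data-split,
# --OPENAI_API_KEY) are defined above, while the generic evaluation flags
# (LLM config, agent class, iteration/worker limits) come from
# `openhands.core.config.get_parser`, so check `--help` for their exact
# spellings. Run from the repository root, e.g.:
#
#   python -m evaluation.benchmarks.EDA.run_infer \
#       --llm_config <llm-config-group> \
#       --agent-cls CodeActAgent \
#       --dataset things \
#       --data-split test \
#       --answerer_model gpt-3.5-turbo \
#       --OPENAI_API_KEY "$OPENAI_API_KEY" \
#       --eval-n-limit 10 \
#       --eval-num-workers 1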