# run_infer.py (7.3 KB)
  1. import asyncio
  2. import logging
  3. import os
  4. import pandas as pd
  5. # import huggingface_hub
  6. from datasets import load_dataset
  7. from evaluation.EDA.game import Q20Game, Q20GameCelebrity
  8. from evaluation.utils.shared import (
  9. EvalMetadata,
  10. make_metadata,
  11. prepare_dataset,
  12. run_evaluation,
  13. )
  14. from opendevin.controller.agent import Agent
  15. # from evaluation.EDA.scorer import question_scorer
  16. from opendevin.controller.state.state import State
  17. from opendevin.core.config import get_llm_config_arg, get_parser, load_app_config
  18. from opendevin.core.logger import get_console_handler
  19. from opendevin.core.logger import opendevin_logger as logger
  20. from opendevin.core.main import run_controller
  21. from opendevin.llm.llm import LLM
# Application-wide configuration, loaded once at module import time.
config = load_app_config()

# Module-level game instance shared with the fake-user-response callback;
# (re)initialized per dataset row inside `process_instance`.
game = None
  24. def codeact_user_response_eda(state: State) -> str:
  25. global game
  26. model_guess = ''
  27. # retrieve the latest model message from history
  28. if state.history:
  29. model_guess = state.history.get_last_agent_message()
  30. assert game is not None, 'Game is not initialized.'
  31. msg = game.generate_user_response(model_guess)
  32. game.curr_turn += 1
  33. logger.info(f'Model guess: {model_guess}')
  34. logger.info(f'Answer response: {msg}')
  35. if 'bingo!' in msg.lower():
  36. return '/exit'
  37. return msg
# Maps agent class name -> callable that fabricates the simulated user's next
# reply during evaluation (only CodeActAgent is wired up here).
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response_eda,
}

# Extra instruction appended to the task prompt, keyed by agent class name.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
):
    """Run the agent on one entity-deduction instance and return its record.

    Args:
        instance: Dataset row; ``instance['text']`` is the ground-truth entity
            to guess (its stripped value is also used as instance id, session
            id, and log-file name).
        metadata: Evaluation metadata; ``metadata.dataset`` selects the game
            class and ``metadata.details`` must carry 'answerer_model' and
            'openai_api_key'.
        reset_logger: If True, redirect logging to a per-instance file so that
            multi-process evaluation output does not interleave.

    Returns:
        dict with instance id, instruction, event history, metrics, last error,
        and the game's reward as ``test_result``.
    """
    # Create the agent
    agent = Agent.get_cls(metadata.agent_class)(llm=LLM(config=metadata.llm_config))

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    eval_output_dir = metadata.eval_output_dir
    if reset_logger:
        # Set up logger
        # NOTE(review): assumes eval_output_dir/'logs' already exists — confirm
        # FileHandler below would fail otherwise.
        log_file = os.path.join(
            eval_output_dir, 'logs', f'instance_{instance["text"].strip()}.log'
        )
        # Remove all existing handlers from logger
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
        # add back the console handler to print ONE line
        logger.addHandler(get_console_handler())
        logger.info(
            f'Starting evaluation for instance {instance["text"].strip()}.\nLOG: tail -f {log_file}'
        )
        # Remove all existing handlers from logger
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(file_handler)

    # Prepare instruction: pick the game variant from the dataset name.
    _game_class = {'things': Q20Game, 'celebs': Q20GameCelebrity}

    # Sampling parameters forwarded to the game's guesser backend.
    guesser_kargs = {
        'max_new_tokens': 64,
        'temperature': 0.8,
        'repetition_penalty': 1.0,
        'do_sample': True,
    }  # no penalty

    # Use codeactagent as guesser_model, hence guesser_model=None below —
    # the agent (driven via the fake-user-response callback) plays guesser.
    global game
    assert metadata.dataset is not None
    assert metadata.details is not None
    game = _game_class[metadata.dataset](
        item=instance['text'].strip(),
        answerer_model=metadata.details['answerer_model'],
        guesser_model=None,
        num_turns=metadata.max_iterations,
        openai_api_key=metadata.details['openai_api_key'],
        guesser_kargs=guesser_kargs,
    )

    instruction = f'{game.first_user_utterance}'
    logger.info(f'Instruction: {instruction}')
    # instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'

    # NOTE: You can actually set slightly different instruction for different agents
    instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]

    # Here's how you can run the agent (similar to the `main` function) and get the final task state
    config.max_iterations = metadata.max_iterations
    state: State | None = asyncio.run(
        run_controller(
            config=config,
            task_str=instruction,
            fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                agent.__class__.__name__
            ],
            agent=agent,
            sid=instance['text'].strip(),
        )
    )

    # ======= Attempt to evaluate the agent's edits =======
    # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
    # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
    if state is None:
        raise ValueError('State should not be None.')

    final_message = state.history.get_last_agent_message()
    logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
    # The game itself scores the run (did the agent identify the entity).
    test_result = game.reward()
    metrics = state.metrics.get() if state.metrics else None

    # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
    # for compatibility with the existing output format, we can remake the pairs here
    # remove when it becomes unnecessary
    histories = state.history.compatibility_for_eval_history_pairs()

    # Save the output
    output = {
        'instance_id': instance['text'].strip(),
        'instance': instance,
        'instruction': instruction,
        'metadata': metadata.model_dump(),
        'history': histories,
        'metrics': metrics,
        'error': state.last_error if state and state.last_error else None,
        'test_result': {
            'success': test_result,
            'final_message': final_message,
            'ground_truth': instance['text'],
        },
    }
    return output
if __name__ == '__main__':
    # Extend the shared evaluation parser with EDA-specific options.
    parser = get_parser()
    parser.add_argument(
        '--answerer_model', '-a', default='gpt-3.5-turbo', help='answerer model'
    )
    parser.add_argument(
        '--dataset',
        default='things',
        choices=['things', 'celebs'],
        type=str,
        help='dataset to be used',
    )
    parser.add_argument(
        '--OPENAI_API_KEY', type=str, required=True, help='Your OpenAI API key'
    )
    parser.add_argument(
        '--data-split',
        default='test',
        type=str,
        help='data split, eg, test',
    )
    args, _ = parser.parse_known_args()

    # Fall back to the app-level LLM config when no --llm-config is given.
    llm_config = get_llm_config_arg(args.llm_config) if args.llm_config else config.llm
    logger.info(f'Config for evaluation: {config}')

    # Pull the benchmark split from the Hugging Face hub.
    eda_dataset = load_dataset(
        'yizheapple/entity-deduction-arena', name=args.dataset, split=args.data_split
    )

    metadata = make_metadata(
        llm_config,
        f'eda-{args.dataset}',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        data_split=args.data_split,
        details={
            'answerer_model': str(args.answerer_model),
            'openai_api_key': str(args.OPENAI_API_KEY),
        },
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    # Deduplicate against existing output and apply --eval-n-limit, keyed on 'text'.
    prepared_dataset = prepare_dataset(
        eda_dataset.to_pandas(), output_file, args.eval_n_limit, 'text'
    )

    # NOTE(review): this agent instance appears unused — `process_instance`
    # constructs its own agent per row; confirm construction has no required
    # side effects before removing.
    agent = Agent.get_cls(args.agent_cls)(llm=LLM(config.llm))

    run_evaluation(
        prepared_dataset,
        metadata,
        output_file,
        args.eval_num_workers,
        process_instance,
        'text',
    )