# run_infer.py (12 KB)
  1. import asyncio
  2. import json
  3. import logging
  4. import multiprocessing as mp
  5. import os
  6. import pathlib
  7. import subprocess
  8. import time
  9. from concurrent.futures import ProcessPoolExecutor
  10. # import huggingface_hub
  11. from datasets import load_dataset
  12. from tqdm import tqdm
  13. from evaluation.EDA.game import Q20Game, Q20GameCelebrity
  14. # from evaluation.EDA.scorer import question_scorer
  15. from opendevin.controller.state.state import State
  16. from opendevin.core.config import config, get_llm_config_arg, get_parser
  17. from opendevin.core.logger import get_console_handler
  18. from opendevin.core.logger import opendevin_logger as logger
  19. from opendevin.core.main import main
  20. from opendevin.events.action import MessageAction
  21. from opendevin.events.serialization.event import event_to_dict
# Game object for the instance currently being evaluated. It is assigned in
# process_instance() and read/mutated by codeact_user_response() through
# `global game`; it stays None until process_instance() has run.
game = None
  23. def cleanup():
  24. print('Cleaning up child processes...')
  25. for process in mp.active_children():
  26. print(f'Terminating child process: {process.name}')
  27. process.terminate()
  28. process.join()
  29. def codeact_user_response(state: State) -> str:
  30. global game
  31. model_guess = ''
  32. if state.history:
  33. for act, _ in reversed(state.history):
  34. if isinstance(act, MessageAction) and act.source == 'agent':
  35. model_guess = act.content
  36. break
  37. msg = game.generate_user_response(model_guess)
  38. game.curr_turn += 1
  39. logger.info(f'Model guess: {model_guess}')
  40. logger.info(f'Answer response: {msg}')
  41. if 'bingo!' in msg.lower():
  42. return '/exit'
  43. return msg
  44. def monologue_user_response(state: State) -> str:
  45. raise NotImplementedError('MonologueAgent should never ask for user responses.')
# Maps an agent class name to the callable handed to `main` as
# `fake_user_response_fn`. The `__main__` section also uses the keys of this
# table to reject unsupported agent classes.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
    'MonologueAgent': monologue_user_response,
}
# Per-agent text appended to the game's opening instruction; agent classes
# without an entry get no suffix (looked up with `.get(agent_class, '')`).
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}
  53. def process_instance(
  54. instance, agent_class, metadata, openai_api_key, reset_logger: bool = True
  55. ):
  56. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  57. eval_output_dir = metadata['eval_output_dir']
  58. if reset_logger:
  59. # Set up logger
  60. log_file = os.path.join(
  61. eval_output_dir, 'logs', f'instance_{instance["text"].strip()}.log'
  62. )
  63. # Remove all existing handlers from logger
  64. for handler in logger.handlers[:]:
  65. logger.removeHandler(handler)
  66. # add back the console handler to print ONE line
  67. logger.addHandler(get_console_handler())
  68. logger.info(
  69. f'Starting evaluation for instance {instance["text"].strip()}.\nLOG: tail -f {log_file}'
  70. )
  71. # Remove all existing handlers from logger
  72. for handler in logger.handlers[:]:
  73. logger.removeHandler(handler)
  74. file_handler = logging.FileHandler(log_file)
  75. file_handler.setFormatter(
  76. logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
  77. )
  78. logger.addHandler(file_handler)
  79. # Prepare instruction
  80. _game_class = {'things': Q20Game, 'celebs': Q20GameCelebrity}
  81. guesser_kargs = {
  82. 'max_new_tokens': 64,
  83. 'temperature': 0.8,
  84. 'repetition_penalty': 1.0,
  85. 'do_sample': True,
  86. } # no penalty
  87. # Use codeactagent as guesser_model
  88. global game
  89. game = _game_class[metadata['dataset']](
  90. item=instance['text'].strip(),
  91. answerer_model=metadata['answerer_model'],
  92. guesser_model=None,
  93. num_turns=metadata['max_iterations'],
  94. openai_api_key=openai_api_key,
  95. guesser_kargs=guesser_kargs,
  96. )
  97. instruction = f'{game.first_user_utterance}'
  98. logger.info(f'Instruction: {instruction}')
  99. # instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  100. # NOTE: You can actually set slightly different instruction for different agents
  101. instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
  102. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  103. state: State = asyncio.run(
  104. main(
  105. instruction,
  106. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(agent_class),
  107. )
  108. )
  109. # ======= Attempt to evaluate the agent's edits =======
  110. # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  111. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  112. if state is None:
  113. raise ValueError('State should not be None.')
  114. final_message = ''
  115. for act, _ in reversed(state.history):
  116. if isinstance(act, MessageAction) and act.source == 'agent':
  117. final_message = act.content
  118. break
  119. logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
  120. test_result = game.reward()
  121. metrics = state.metrics.get() if state.metrics else None
  122. # Save the output
  123. output = {
  124. 'instance_id': instance['text'].strip(),
  125. 'instance': instance,
  126. 'instruction': instruction,
  127. 'metadata': metadata,
  128. 'history': [
  129. (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
  130. ],
  131. 'metrics': metrics,
  132. 'error': state.error if state and state.error else None,
  133. 'test_result': {
  134. 'success': test_result,
  135. 'final_message': final_message,
  136. 'ground_truth': instance['text'],
  137. },
  138. }
  139. return output
  140. if __name__ == '__main__':
  141. parser = get_parser()
  142. parser.add_argument(
  143. '--answerer_model', '-a', default='gpt-3.5-turbo', help='answerer model'
  144. )
  145. parser.add_argument(
  146. '--dataset',
  147. default='things',
  148. choices=['things', 'celebs'],
  149. type=str,
  150. help='dataset to be used',
  151. )
  152. parser.add_argument(
  153. '--OPENAI_API_KEY', type=str, required=True, help='Your OpenAI API key'
  154. )
  155. parser.add_argument(
  156. '--data-split',
  157. default='test',
  158. type=str,
  159. help='data split, eg, test',
  160. )
  161. args, _ = parser.parse_known_args()
  162. if args.directory:
  163. config.workspace_base = os.path.abspath(args.directory)
  164. print(f'Setting workspace base to {config.workspace_base}')
  165. # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
  166. # so we don't need to manage file uploading to OpenDevin's repo
  167. eda_dataset = load_dataset(
  168. 'yizheapple/entity-deduction-arena', name=args.dataset, split=args.data_split
  169. )
  170. logger.info(
  171. f'Evaluating Entity Deduction Arena {args.dataset} {args.data_split} split'
  172. )
  173. # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
  174. # for details of how to set `llm_config`
  175. if args.llm_config:
  176. specified_llm_config = get_llm_config_arg(args.llm_config)
  177. if specified_llm_config:
  178. config.llm = specified_llm_config
  179. logger.info(f'Config for evaluation: {config}')
  180. # TEST METADATA
  181. agent_class = args.agent_cls
  182. assert (
  183. agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
  184. ), f'Unsupported agent class: {agent_class}'
  185. model_name = config.llm.model.split('/')[-1]
  186. max_iterations = args.max_iterations
  187. eval_note = ''
  188. if args.eval_note is not None:
  189. eval_note += '_N_' + args.eval_note
  190. eval_output_dir = os.path.join(
  191. args.eval_output_dir,
  192. 'eda',
  193. agent_class,
  194. model_name + '_maxiter_' + str(max_iterations) + eval_note,
  195. )
  196. pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
  197. pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
  198. parents=True, exist_ok=True
  199. )
  200. logger.info(f'Using evaluation output directory: {eval_output_dir}')
  201. metadata = {
  202. 'dataset': args.dataset,
  203. 'data_split': args.data_split,
  204. 'answerer_model': args.answerer_model,
  205. 'agent_class': agent_class,
  206. 'model_name': model_name,
  207. 'max_iterations': max_iterations,
  208. 'eval_output_dir': eval_output_dir,
  209. 'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
  210. # get the commit id of current repo for reproducibility
  211. 'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
  212. .decode('utf-8')
  213. .strip(),
  214. }
  215. logger.info(f'Metadata: {metadata}')
  216. with open(os.path.join(eval_output_dir, 'metadata.json'), 'w') as f:
  217. json.dump(metadata, f)
  218. # LIMIT EVALUATION
  219. eval_n_limit = args.eval_n_limit
  220. if eval_n_limit:
  221. eda_dataset = eda_dataset.select(list(range(eval_n_limit)))
  222. logger.info(f'Limiting evaluation to first {eval_n_limit} instances.')
  223. # OUTPUT FILE
  224. output_file = os.path.join(eval_output_dir, 'output.jsonl')
  225. logger.info(f'Writing evaluation output to {output_file}')
  226. finished_items = set()
  227. if os.path.exists(output_file):
  228. with open(output_file, 'r') as f:
  229. for line in f:
  230. data = json.loads(line)
  231. finished_items.add(data['instance_id'])
  232. logger.warning(
  233. f'Output file {output_file} already exists. Loaded {len(finished_items)} finished instances.'
  234. )
  235. output_fp = open(output_file, 'a')
  236. logger.info(
  237. f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
  238. )
  239. # =============================================
  240. # filter out finished instances
  241. new_eda_dataset = []
  242. for instance in eda_dataset:
  243. if instance['text'].strip() in finished_items:
  244. logger.info(
  245. f'Skipping instance {instance["text"].strip()} as it is already finished.'
  246. )
  247. continue
  248. new_eda_dataset.append(instance)
  249. eda_dataset = new_eda_dataset
  250. logger.info(
  251. f'Finished instances: {len(finished_items)}, Remaining instances: {len(eda_dataset)}'
  252. )
  253. # =============================================
  254. pbar = tqdm(total=len(eda_dataset))
  255. # This function tracks the progress AND write the output to a JSONL file
  256. def update_progress(future):
  257. pbar.update(1)
  258. output = future.result()
  259. pbar.set_description(f'Instance {output["instance_id"]}')
  260. pbar.set_postfix_str(f'Test Result: {output["test_result"]}')
  261. logger.info(
  262. f'Finished evaluation for instance {output["instance_id"]}: {output["test_result"]}'
  263. )
  264. output_fp.write(json.dumps(output) + '\n')
  265. output_fp.flush()
  266. # This sets the multi-processing
  267. num_workers = args.eval_num_workers
  268. logger.info(f'Using {num_workers} workers for evaluation.')
  269. try:
  270. with ProcessPoolExecutor(num_workers) as executor:
  271. futures = []
  272. # This is how we perform multi-processing
  273. for instance in eda_dataset:
  274. future = executor.submit(
  275. process_instance,
  276. instance,
  277. agent_class,
  278. metadata,
  279. args.OPENAI_API_KEY,
  280. reset_logger=bool(num_workers > 1),
  281. )
  282. future.add_done_callback(update_progress)
  283. futures.append(future)
  284. # Wait for all futures to complete
  285. for future in futures:
  286. future.result()
  287. except KeyboardInterrupt:
  288. print('KeyboardInterrupt received. Cleaning up...')
  289. cleanup()
  290. output_fp.close()
  291. logger.info('Evaluation finished.')