run_infer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. import asyncio
  2. import json
  3. import logging
  4. import multiprocessing as mp
  5. import os
  6. import pathlib
  7. import subprocess
  8. import time
  9. from concurrent.futures import ProcessPoolExecutor
  10. # import huggingface_hub
  11. from datasets import load_dataset
  12. from tqdm import tqdm
  13. from evaluation.EDA.game import Q20Game, Q20GameCelebrity
  14. # from evaluation.EDA.scorer import question_scorer
  15. from opendevin.controller.state.state import State
  16. from opendevin.core.config import config, get_llm_config_arg, get_parser
  17. from opendevin.core.logger import get_console_handler
  18. from opendevin.core.logger import opendevin_logger as logger
  19. from opendevin.core.main import main
  20. from opendevin.events.action import MessageAction
  21. from opendevin.events.serialization.event import event_to_dict
  22. game = None
  23. def cleanup():
  24. print('Cleaning up child processes...')
  25. for process in mp.active_children():
  26. print(f'Terminating child process: {process.name}')
  27. process.terminate()
  28. process.join()
  29. def codeact_user_response(state: State) -> str:
  30. global game
  31. model_guess = ''
  32. if state.history:
  33. for act, _ in reversed(state.history):
  34. if isinstance(act, MessageAction) and act.source == 'agent':
  35. model_guess = act.content
  36. break
  37. msg = game.generate_user_response(model_guess)
  38. game.curr_turn += 1
  39. logger.info(f'Model guess: {model_guess}')
  40. logger.info(f'Anwser response: {msg}')
  41. if 'bingo!' in msg.lower():
  42. return '/exit'
  43. return msg
  44. def monologue_user_response(state: State) -> str:
  45. raise NotImplementedError('MonologueAgent should never ask for user responses.')
  46. AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
  47. 'CodeActAgent': codeact_user_response,
  48. 'MonologueAgent': monologue_user_response,
  49. }
  50. AGENT_CLS_TO_INST_SUFFIX = {
  51. 'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
  52. }
  53. def process_instance(instance, agent_class, metadata, reset_logger: bool = True):
  54. # Setup the logger properly, so you can run multi-processing to parallize the evaluation
  55. eval_output_dir = metadata['eval_output_dir']
  56. if reset_logger:
  57. # Set up logger
  58. log_file = os.path.join(
  59. eval_output_dir, 'logs', f'instance_{instance["text"].strip()}.log'
  60. )
  61. # Remove all existing handlers from logger
  62. for handler in logger.handlers[:]:
  63. logger.removeHandler(handler)
  64. # add back the console handler to print ONE line
  65. logger.addHandler(get_console_handler())
  66. logger.info(
  67. f'Starting evaluation for instance {instance["text"].strip()}.\nLOG: tail -f {log_file}'
  68. )
  69. # Remove all existing handlers from logger
  70. for handler in logger.handlers[:]:
  71. logger.removeHandler(handler)
  72. file_handler = logging.FileHandler(log_file)
  73. file_handler.setFormatter(
  74. logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
  75. )
  76. logger.addHandler(file_handler)
  77. # Prepare instruction
  78. _game_class = {'things': Q20Game, 'celebs': Q20GameCelebrity}
  79. guesser_kargs = {
  80. 'max_new_tokens': 64,
  81. 'temperature': 0.8,
  82. 'repetition_penalty': 1.0,
  83. 'do_sample': True,
  84. } # no penalty
  85. # Use codeactagent as guesser_model
  86. global game
  87. game = _game_class[metadata['dataset']](
  88. item=instance['text'].strip(),
  89. answerer_model=metadata['answerer_model'],
  90. guesser_model=None,
  91. num_turns=metadata['max_iterations'],
  92. openai_api_key=metadata['openai_api'],
  93. guesser_kargs=guesser_kargs,
  94. )
  95. instruction = f'{game.first_user_utterance}'
  96. logger.info(f'Instruction: {instruction}')
  97. # instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  98. # NOTE: You can actually set slightly different instruction for different agents
  99. instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
  100. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  101. state: State = asyncio.run(
  102. main(
  103. instruction,
  104. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(agent_class),
  105. )
  106. )
  107. # ======= Attempt to evaluate the agent's edits =======
  108. # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  109. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  110. if state is None:
  111. raise ValueError('State should not be None.')
  112. final_message = ''
  113. for act, _ in reversed(state.history):
  114. if isinstance(act, MessageAction) and act.source == 'agent':
  115. final_message = act.content
  116. break
  117. logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
  118. test_result = game.reward()
  119. # Save the output
  120. output = {
  121. 'instance_id': instance['text'].strip(),
  122. 'instance': instance,
  123. 'instruction': instruction,
  124. 'metadata': metadata,
  125. 'history': [
  126. (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
  127. ],
  128. 'error': state.error if state and state.error else None,
  129. 'test_result': {
  130. 'success': test_result,
  131. 'final_message': final_message,
  132. 'ground_truth': instance['text'],
  133. },
  134. }
  135. return output
  136. if __name__ == '__main__':
  137. parser = get_parser()
  138. parser.add_argument(
  139. '--answerer_model', '-a', default='gpt-3.5-turbo', help='answerer model'
  140. )
  141. parser.add_argument(
  142. '--dataset',
  143. default='things',
  144. choices=['things', 'celebs'],
  145. type=str,
  146. help='dataset to be used',
  147. )
  148. parser.add_argument(
  149. '--OPENAI_API_KEY', type=str, required=True, help='Your OpenAI API key'
  150. )
  151. parser.add_argument(
  152. '--data-split',
  153. default='test',
  154. type=str,
  155. help='data split, eg, test',
  156. )
  157. args, _ = parser.parse_known_args()
  158. if args.directory:
  159. config.workspace_base = os.path.abspath(args.directory)
  160. print(f'Setting workspace base to {config.workspace_base}')
  161. # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
  162. # so we don't need to manage file uploading to OpenDevin's repo
  163. eda_dataset = load_dataset(
  164. 'yizheapple/entity-deduction-arena', name=args.dataset, split=args.data_split
  165. )
  166. logger.info(
  167. f'Evaluating Entity Deduction Arena {args.dataset} {args.data_split} split'
  168. )
  169. # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
  170. # for details of how to set `llm_config`
  171. if args.llm_config:
  172. specified_llm_config = get_llm_config_arg(args.llm_config)
  173. if specified_llm_config:
  174. config.llm = specified_llm_config
  175. logger.info(f'Config for evaluation: {config}')
  176. # TEST METADATA
  177. agent_class = args.agent_cls
  178. assert (
  179. agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
  180. ), f'Unsupported agent class: {agent_class}'
  181. model_name = config.llm.model.split('/')[-1]
  182. max_iterations = args.max_iterations
  183. eval_note = ''
  184. if args.eval_note is not None:
  185. eval_note += '_N_' + args.eval_note
  186. eval_output_dir = os.path.join(
  187. args.eval_output_dir,
  188. 'eda',
  189. agent_class,
  190. model_name + '_maxiter_' + str(max_iterations) + eval_note,
  191. )
  192. pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
  193. pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
  194. parents=True, exist_ok=True
  195. )
  196. logger.info(f'Using evaluation output directory: {eval_output_dir}')
  197. metadata = {
  198. 'dataset': args.dataset,
  199. 'data_split': args.data_split,
  200. 'answerer_model': args.answerer_model,
  201. 'agent_class': agent_class,
  202. 'openai_api': args.OPENAI_API_KEY,
  203. 'model_name': model_name,
  204. 'max_iterations': max_iterations,
  205. 'eval_output_dir': eval_output_dir,
  206. 'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
  207. # get the commit id of current repo for reproducibility
  208. 'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
  209. .decode('utf-8')
  210. .strip(),
  211. }
  212. logger.info(f'Metadata: {metadata}')
  213. with open(os.path.join(eval_output_dir, 'metadata.json'), 'w') as f:
  214. json.dump(metadata, f)
  215. # LIMIT EVALUATION
  216. eval_n_limit = args.eval_n_limit
  217. if eval_n_limit:
  218. eda_dataset = eda_dataset.select(list(range(eval_n_limit)))
  219. logger.info(f'Limiting evaluation to first {eval_n_limit} instances.')
  220. # OUTPUT FILE
  221. output_file = os.path.join(eval_output_dir, 'output.jsonl')
  222. logger.info(f'Writing evaluation output to {output_file}')
  223. finished_items = set()
  224. if os.path.exists(output_file):
  225. with open(output_file, 'r') as f:
  226. for line in f:
  227. data = json.loads(line)
  228. finished_items.add(data['instance_id'])
  229. logger.warning(
  230. f'Output file {output_file} already exists. Loaded {len(finished_items)} finished instances.'
  231. )
  232. output_fp = open(output_file, 'a')
  233. logger.info(
  234. f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
  235. )
  236. # =============================================
  237. # filter out finished instances
  238. new_eda_dataset = []
  239. for instance in eda_dataset:
  240. if instance['text'].strip() in finished_items:
  241. logger.info(
  242. f'Skipping instance {instance["text"].strip()} as it is already finished.'
  243. )
  244. continue
  245. new_eda_dataset.append(instance)
  246. eda_dataset = new_eda_dataset
  247. logger.info(
  248. f'Finished instances: {len(finished_items)}, Remaining instances: {len(eda_dataset)}'
  249. )
  250. # =============================================
  251. pbar = tqdm(total=len(eda_dataset))
  252. # This function tracks the progress AND write the output to a JSONL file
  253. def update_progress(future):
  254. pbar.update(1)
  255. output = future.result()
  256. pbar.set_description(f'Instance {output["instance_id"]}')
  257. pbar.set_postfix_str(f'Test Result: {output["test_result"]}')
  258. logger.info(
  259. f'Finished evaluation for instance {output["instance_id"]}: {output["test_result"]}'
  260. )
  261. output_fp.write(json.dumps(output) + '\n')
  262. output_fp.flush()
  263. # This sets the multi-processing
  264. num_workers = args.eval_num_workers
  265. logger.info(f'Using {num_workers} workers for evaluation.')
  266. try:
  267. with ProcessPoolExecutor(num_workers) as executor:
  268. futures = []
  269. # This is how we perform multi-processing
  270. for instance in eda_dataset:
  271. future = executor.submit(
  272. process_instance,
  273. instance,
  274. agent_class,
  275. metadata,
  276. reset_logger=bool(num_workers > 1),
  277. )
  278. future.add_done_callback(update_progress)
  279. futures.append(future)
  280. # Wait for all futures to complete
  281. for future in futures:
  282. future.result()
  283. except KeyboardInterrupt:
  284. print('KeyboardInterrupt received. Cleaning up...')
  285. cleanup()
  286. output_fp.close()
  287. logger.info('Evaluation finished.')