# run_infer.py — logic-reasoning evaluation harness
  1. import asyncio
  2. import logging
  3. import os
  4. import pathlib
  5. import shutil
  6. import pandas as pd
  7. from datasets import load_dataset
  8. from evaluation.swe_bench.swe_env_box import DockerSSHBox
  9. from evaluation.utils.shared import (
  10. EvalMetadata,
  11. codeact_user_response,
  12. make_metadata,
  13. prepare_dataset,
  14. run_evaluation,
  15. )
  16. from opendevin.controller.agent import Agent
  17. from opendevin.controller.state.state import State
  18. from opendevin.core.config import get_llm_config_arg, get_parser, load_app_config
  19. from opendevin.core.logger import get_console_handler
  20. from opendevin.core.logger import opendevin_logger as logger
  21. from opendevin.core.main import run_controller
  22. from opendevin.llm.llm import LLM
# Global application config, loaded once at import time and mutated per
# instance inside process_instance (then restored).
config = load_app_config()

# Maps an agent class name to the function that fabricates the scripted
# "user" reply during automated evaluation.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

# Per-agent suffix appended to the task instruction so the agent knows how
# to terminate the episode.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
}
  30. def get_choice(answer_str):
  31. choices = [
  32. 'A',
  33. 'B',
  34. 'C',
  35. 'D',
  36. 'E',
  37. 'F',
  38. 'G',
  39. 'H',
  40. 'A)',
  41. 'B)',
  42. 'C)',
  43. 'D)',
  44. 'E)',
  45. 'F)',
  46. 'G)',
  47. 'H)',
  48. 'A.',
  49. 'B.',
  50. 'C.',
  51. 'D.',
  52. 'E.',
  53. 'F.',
  54. 'G.',
  55. 'H.',
  56. ]
  57. for c in choices:
  58. if answer_str.startswith(c):
  59. return c.replace(')', '')
  60. if answer_str.startswith(':'):
  61. return answer_str.replace(':', '').replace('.', '').strip()
  62. return None
  63. def get_test_result(
  64. model_answer: str,
  65. ground_truth: str,
  66. ) -> dict[str, bool]:
  67. gold_answer = ground_truth.replace('(', '').replace(')', '').strip()
  68. answer_str = model_answer if model_answer is not None else ''
  69. prediction = get_choice(answer_str)
  70. indicators = [
  71. 'the correct option is',
  72. 'the correct answer is',
  73. 'The correct answer is',
  74. 'The correct option is',
  75. 'Thus, the answer is',
  76. ]
  77. if prediction is None:
  78. for indicator in indicators:
  79. if answer_str.find(indicator) >= 0:
  80. answer_str = answer_str.split(indicator)[1].strip()
  81. prediction = get_choice(answer_str)
  82. break
  83. isTrue = prediction == gold_answer
  84. test_result = {'result': isTrue}
  85. return test_result
  86. def process_instance(
  87. instance: pd.Series,
  88. metadata: EvalMetadata,
  89. reset_logger: bool = True,
  90. ):
  91. # Create the agent
  92. agent = Agent.get_cls(metadata.agent_class)(llm=LLM(config=metadata.llm_config))
  93. old_workspace_mount_path = config.workspace_mount_path
  94. old_workspace_base = config.workspace_base
  95. try:
  96. workspace_mount_path = os.path.join(
  97. config.workspace_mount_path, '_eval_workspace'
  98. )
  99. # create process-specific workspace dir
  100. workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
  101. pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
  102. # reset workspace to config
  103. config.workspace_base = workspace_mount_path
  104. config.workspace_mount_path = workspace_mount_path
  105. # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
  106. if reset_logger:
  107. # Set up logger
  108. log_file = os.path.join(
  109. metadata.eval_output_dir, 'logs', f'instance_{instance["id"]}.log'
  110. )
  111. # Remove all existing handlers from logger
  112. for handler in logger.handlers[:]:
  113. logger.removeHandler(handler)
  114. # add back the console handler to print ONE line
  115. logger.addHandler(get_console_handler())
  116. logger.info(
  117. f'Starting evaluation for instance {instance["id"]}.\nLOG: tail -f {log_file}'
  118. )
  119. # Remove all existing handlers from logger
  120. for handler in logger.handlers[:]:
  121. logger.removeHandler(handler)
  122. file_handler = logging.FileHandler(log_file)
  123. file_handler.setFormatter(
  124. logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
  125. )
  126. logger.addHandler(file_handler)
  127. logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')
  128. # sandbox = DockerSSHBox()
  129. logic_inference_path = os.path.join(workspace_mount_path, 'logic_inference.py')
  130. if not os.path.exists(logic_inference_path):
  131. shutil.copyfile(
  132. './evaluation/logic_reasoning/logic_inference.py', logic_inference_path
  133. )
  134. logger.info(f'logic_inference.py copied to {workspace_mount_path}')
  135. cache_dir = os.path.join(workspace_mount_path, '.cache_program')
  136. if not os.path.exists(cache_dir):
  137. os.makedirs(cache_dir)
  138. # Prepare instruction
  139. with open('./evaluation/logic_reasoning/instruction.txt', 'r') as f:
  140. instruction = f.read()
  141. instance_logic_programs = instance['raw_logic_programs'][0].strip()
  142. instruction = instruction.replace('[[dataset_name]]', dataset_name)
  143. instruction = instruction.replace('[[logic_programs]]', instance_logic_programs)
  144. instruction = instruction.replace(
  145. '[[logic_inference_path.py]]', logic_inference_path
  146. )
  147. # NOTE: You can actually set slightly different instruction for different agents
  148. instruction += AGENT_CLS_TO_INST_SUFFIX[agent.__class__.__name__]
  149. # use a session id for concurrent evaluation
  150. sid = instance['id'] + '_' + str(os.getpid())
  151. sandbox = DockerSSHBox(
  152. config=config.sandbox,
  153. persist_sandbox=False,
  154. workspace_mount_path=config.workspace_mount_path,
  155. sandbox_workspace_dir=config.workspace_mount_path_in_sandbox,
  156. cache_dir=config.cache_dir,
  157. run_as_devin=config.run_as_devin,
  158. sid=sid,
  159. )
  160. exit_code, command_output = sandbox.execute('pip install scitools-pyke')
  161. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  162. config.max_iterations = metadata.max_iterations
  163. state: State | None = asyncio.run(
  164. run_controller(
  165. config=config,
  166. task_str=instruction,
  167. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
  168. agent.__class__.__name__
  169. ),
  170. agent=agent,
  171. sandbox=sandbox,
  172. sid=sid,
  173. )
  174. )
  175. # ======= Attempt to evaluate the agent's edits =======
  176. # If you are working on simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
  177. # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
  178. if state is None:
  179. raise ValueError('State should not be None.')
  180. final_message = ''
  181. messages = []
  182. for event in state.history.get_events(reverse=True):
  183. # will this be a MessageAction?
  184. # TODO we can filter for types of events if we know what to expect
  185. messages.append(event.content)
  186. if str(event.content) in ["'A'", "'B'", "'C'"]:
  187. final_message = event.content
  188. break
  189. final_message = final_message.strip("'")
  190. logger.info(
  191. f'Predicted answer: {final_message}, Ground truth: {instance["answer"]}'
  192. )
  193. test_result = get_test_result(
  194. model_answer=final_message, ground_truth=instance['answer']
  195. )
  196. metrics = state.metrics.get() if state.metrics else None
  197. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  198. # for compatibility with the existing output format, we can remake the pairs here
  199. # remove when it becomes unnecessary
  200. histories = state.history.compatibility_for_eval_history_pairs()
  201. # Save the output
  202. output = {
  203. 'id': instance['id'],
  204. 'instance': instance,
  205. 'instruction': instruction,
  206. # 'metadata': metadata.model_dump(),
  207. 'history': histories,
  208. 'metrics': metrics,
  209. 'final_message': final_message,
  210. 'messages': messages,
  211. 'error': state.last_error if state and state.last_error else None,
  212. 'test_result': test_result,
  213. }
  214. except Exception:
  215. logger.error('Process instance failed')
  216. raise
  217. finally:
  218. config.workspace_mount_path = old_workspace_mount_path
  219. config.workspace_base = old_workspace_base
  220. # Close the sandbox
  221. sandbox.close()
  222. return output
  223. if __name__ == '__main__':
  224. parser = get_parser()
  225. parser.add_argument(
  226. '--dataset',
  227. type=str,
  228. help='the logic reasoning dataset to evaluate on {ProntoQA, ProofWriter}',
  229. default='ProntoQA',
  230. )
  231. parser.add_argument(
  232. '--data_split',
  233. type=str,
  234. help='data split to evaluate on {validation}', # right now we only support validation split
  235. default='validation',
  236. )
  237. args, _ = parser.parse_known_args()
  238. if args.directory:
  239. config.workspace_base = os.path.abspath(args.directory)
  240. print(f'Setting workspace base to {config.workspace_base}')
  241. dataset_name = args.dataset
  242. data_split = args.data_split
  243. dataset = load_dataset(f'renma/{dataset_name}')
  244. logic_reasoning_tests = dataset[data_split]
  245. id_column = 'id'
  246. llm_config = get_llm_config_arg(args.llm_config) if args.llm_config else config.llm
  247. logger.info(f'Config for evaluation: {config}')
  248. metadata = make_metadata(
  249. llm_config,
  250. args.dataset_name,
  251. args.agent_cls,
  252. args.max_iterations,
  253. args.eval_note,
  254. args.eval_output_dir,
  255. )
  256. output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
  257. instances = prepare_dataset(dataset, output_file, args.eval_n_limit, id_column)
  258. run_evaluation(
  259. instances,
  260. metadata,
  261. output_file,
  262. args.eval_num_workers,
  263. process_instance,
  264. id_column,
  265. )