run_infer.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. import asyncio
  2. import json
  3. import os
  4. import git
  5. import pandas as pd
  6. from evaluation.benchmarks.discoverybench.eval_utils.eval_w_subhypo_gen import (
  7. run_eval_gold_vs_gen_NL_hypo_workflow,
  8. )
  9. from evaluation.benchmarks.discoverybench.eval_utils.response_parser import (
  10. extract_gen_hypo_from_logs,
  11. )
  12. from evaluation.utils.shared import (
  13. EvalMetadata,
  14. EvalOutput,
  15. codeact_user_response,
  16. compatibility_for_eval_history_pairs,
  17. make_metadata,
  18. prepare_dataset,
  19. reset_logger_for_multiprocessing,
  20. run_evaluation,
  21. )
  22. from openhands.controller.state.state import State
  23. from openhands.core.config import (
  24. AgentConfig,
  25. AppConfig,
  26. SandboxConfig,
  27. get_llm_config_arg,
  28. parse_arguments,
  29. )
  30. from openhands.core.logger import openhands_logger as logger
  31. from openhands.core.main import create_runtime, run_controller
  32. from openhands.events.action import AgentFinishAction, CmdRunAction, MessageAction
  33. from openhands.events.observation import CmdOutputObservation
  34. from openhands.runtime.base import Runtime
  35. from openhands.utils.async_utils import call_async_from_sync
# Model name passed as `llm_used` to the DiscoveryBench hypothesis evaluator.
EVALUATION_LLM = 'gpt-4-1106-preview'

# Maps a data file's name to its full path inside the cloned repository.
# Filled in by create_dataset() while walking the data directory and read by
# list_csv_files().
DATA_FILES: dict[str, str] = {}

# Python packages that initialize_runtime() pip-installs into the sandbox,
# one `pip install` command per entry.
LIBRARIES = [
    'pandas',
    'numpy',
    'scipy',
    'matplotlib',
    'seaborn',
    'scikit-learn',
    'statsmodels',
]

# Per-agent function used as `fake_user_response_fn` when running the
# controller; looked up with `.get`, so unknown agents get None.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}

# Per-agent text appended to the end of the instruction. Looked up with `[]`,
# so an agent class missing from this mapping raises KeyError.
# NOTE(review): the wording ("fixed the issue through code changes") looks
# carried over from an issue-fixing benchmark -- confirm it is intended here.
AGENT_CLS_TO_INST_SUFFIX = {
    'CodeActAgent': 'When you think you have fixed the issue through code changes, please finish the interaction using the "finish" tool.\n'
}
  53. def get_config(
  54. metadata: EvalMetadata,
  55. ) -> AppConfig:
  56. config = AppConfig(
  57. default_agent=metadata.agent_class,
  58. run_as_openhands=False,
  59. runtime='eventstream',
  60. max_iterations=metadata.max_iterations,
  61. sandbox=SandboxConfig(
  62. base_container_image='python:3.12-bookworm',
  63. enable_auto_lint=True,
  64. use_host_network=False,
  65. ),
  66. # do not mount workspace
  67. workspace_base=None,
  68. workspace_mount_path=None,
  69. )
  70. config.set_llm_config(metadata.llm_config)
  71. agent_config = AgentConfig(
  72. function_calling=False,
  73. codeact_enable_jupyter=True,
  74. codeact_enable_browsing_delegate=True,
  75. )
  76. config.set_agent_config(agent_config)
  77. return config
  78. def get_dv_query_for_real(
  79. datasets, question, domain_knowledge=None, workflow_tags=None
  80. ):
  81. """
  82. Prepare a structured query for the agent to execute on the specified datasets.
  83. This function constructs a query by compiling metadata from the provided datasets, along with any relevant domain knowledge and workflow tags.
  84. Args:
  85. datasets: List of datasets
  86. question: Query to be answered
  87. domain_knowledge: Domain knowledge if any
  88. workflow_tags: Workflow tags if any
  89. Returns:
  90. query_to_dv: Query to be run on the dataset
  91. dataset_meta: Metadata of the dataset
  92. """
  93. dataset_meta = ''
  94. for dataset_metadata in datasets:
  95. dataset_meta += 'Dataset name: ' + dataset_metadata['name']
  96. dataset_meta += 'Dataset description: ' + dataset_metadata['description']
  97. dataset_meta += '\nBrief description of columns: '
  98. for col in dataset_metadata['columns']['raw']:
  99. dataset_meta += col['name'] + ': ' + col['description'] + ', '
  100. query_to_dv = dataset_meta
  101. query_to_dv += f'\nQuery: {question}'
  102. if domain_knowledge:
  103. query_to_dv += (
  104. '\nAdditionally, we provide some hints that might be useful to solve the task. Domain Knowledge: \n'
  105. + domain_knowledge
  106. + '.\n'
  107. )
  108. if workflow_tags:
  109. query_to_dv += 'The meta tags are: ' + workflow_tags + '.\n'
  110. query_to_dv += (
  111. 'In the final answer, please write down a scientific hypothesis in '
  112. 'natural language, derived from the provided dataset, clearly stating the '
  113. 'context of hypothesis (if any), variables chosen (if any) and '
  114. 'relationship between those variables (if any) including any statistical significance.'
  115. 'Also generate a summary of the full workflow starting from data loading that led to the final answer as WORKFLOW SUMMARY:'
  116. )
  117. # Run the NL query through datavoyager
  118. return query_to_dv, dataset_meta
  119. def initialize_runtime(runtime: Runtime, data_files: list[str]):
  120. """
  121. Initialize the runtime for the agent.
  122. This function is called before the runtime is used to run the agent.
  123. """
  124. logger.info(f"{'-' * 50} BEGIN Runtime Initialization Fn {'-' * 50}")
  125. obs: CmdOutputObservation
  126. action = CmdRunAction(command='mkdir -p /workspace')
  127. logger.info(action, extra={'msg_type': 'ACTION'})
  128. obs = runtime.run_action(action)
  129. assert obs.exit_code == 0
  130. action = CmdRunAction(command='cd /workspace')
  131. logger.info(action, extra={'msg_type': 'ACTION'})
  132. obs = runtime.run_action(action)
  133. assert obs.exit_code == 0
  134. for file in data_files:
  135. runtime.copy_to(
  136. file,
  137. '/workspace',
  138. )
  139. for lib in LIBRARIES:
  140. action = CmdRunAction(command=f'pip install {lib}')
  141. logger.info(action, extra={'msg_type': 'ACTION'})
  142. obs = runtime.run_action(action)
  143. assert obs.exit_code == 0
  144. logger.info(f"{'-' * 50} END Runtime Initialization Fn {'-' * 50}")
  145. def get_last_agent_finish_action(state: State) -> AgentFinishAction:
  146. for event in reversed(state.history):
  147. if isinstance(event, AgentFinishAction):
  148. return event
  149. return None
  150. def get_last_message_action(state: State) -> MessageAction:
  151. for event in reversed(state.history):
  152. if isinstance(event, MessageAction):
  153. return event
  154. return None
  155. def complete_runtime(state: State):
  156. last_agent_finish_action = get_last_agent_finish_action(state)
  157. last_agent_message_action = get_last_message_action(state)
  158. if last_agent_finish_action is not None:
  159. final_message_1 = last_agent_finish_action.thought
  160. gen_hypo_1, gen_workflow_1, error_1 = extract_gen_hypo_from_logs(
  161. final_message_1
  162. )
  163. else:
  164. gen_hypo_1, gen_workflow_1, error_1 = '', '', ''
  165. if last_agent_message_action is not None:
  166. final_message_2 = last_agent_message_action.content
  167. gen_hypo_2, gen_workflow_2, error_2 = extract_gen_hypo_from_logs(
  168. final_message_2
  169. )
  170. else:
  171. gen_hypo_2, gen_workflow_2, error_2 = '', '', ''
  172. if gen_hypo_1 and gen_hypo_2:
  173. test_result = {
  174. 'gen_hypo': last_agent_finish_action.thought
  175. if last_agent_finish_action
  176. else last_agent_message_action.content,
  177. 'gen_workflow': '',
  178. 'error': '',
  179. }
  180. return test_result
  181. test_result = {
  182. 'gen_hypo': gen_hypo_1 if gen_hypo_1 else gen_hypo_2,
  183. 'gen_workflow': gen_workflow_1 if gen_workflow_1 else gen_workflow_2,
  184. 'error': error_1 if error_1 else error_2,
  185. }
  186. return test_result
  187. def process_instance(
  188. instance: pd.Series,
  189. metadata: EvalMetadata,
  190. reset_logger: bool = True,
  191. ):
  192. """
  193. Process and evaluate a single instance of the dataset.
  194. This function executes the OpenHands agent
  195. for a specific instance of the dataset. It retrieves
  196. the agent's results and evaluates them against the gold
  197. hypothesis.
  198. Args:
  199. instance: A single row of the dataset
  200. metadata: Metadata for the evaluation
  201. reset_logger: Whether to reset the logger
  202. Returns:
  203. output: EvalOutput object
  204. """
  205. config = get_config(metadata)
  206. # Setup the logger properly, so you can run
  207. # multi-processing to parallelize the evaluation
  208. if reset_logger:
  209. log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
  210. reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
  211. else:
  212. logger.info(f'Starting evaluation for instance {instance.instance_id}.')
  213. problem_statement, dataset_metadata = get_dv_query_for_real(
  214. datasets=instance.datasets,
  215. question=instance.query,
  216. domain_knowledge=instance.domain_knowledge,
  217. workflow_tags=instance.workflow_tags,
  218. )
  219. # Prepare instruction
  220. instruction = (
  221. f'You are a discovery agent who can execute a python code only once to answer a query based on one or more datasets. The datasets will be present in the current directory.\n\n'
  222. 'Environment has been set up for you to start working. You may assume all necessary tools and datasets are installed.\n\n'
  223. '# Problem Statement\n'
  224. f'{problem_statement}\n\n'
  225. )
  226. instruction += (
  227. 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
  228. 'You should NOT modify any existing test case files. If needed, you can add new test cases in a NEW file to reproduce the issue.\n'
  229. 'You SHOULD INCLUDE PROPER INDENTATION in your edit commands.\n'
  230. )
  231. # NOTE: You can actually set slightly different instruction for different agents
  232. instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
  233. # Here's how you can run the agent (similar to the `main` function) and get the final task state
  234. runtime = create_runtime(config)
  235. call_async_from_sync(runtime.connect)
  236. initialize_runtime(runtime, instance.data_files)
  237. state: State | None = asyncio.run(
  238. run_controller(
  239. config=config,
  240. initial_user_action=MessageAction(content=instruction),
  241. runtime=runtime,
  242. fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
  243. metadata.agent_class
  244. ),
  245. )
  246. )
  247. if state is None:
  248. raise ValueError('State should not be None.')
  249. metrics = state.metrics.get() if state.metrics else None
  250. test_result = complete_runtime(state)
  251. # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
  252. # for compatibility with the existing output format, we can remake the pairs here
  253. # remove when it becomes unnecessary
  254. histories = compatibility_for_eval_history_pairs(state.history)
  255. # DiscoveryBench Evaluation
  256. eval_rec = run_eval_gold_vs_gen_NL_hypo_workflow(
  257. query=instance.query,
  258. gold_hypo=instance.gold_hypo,
  259. gold_workflow='',
  260. gen_hypo=test_result['gen_hypo'],
  261. gen_workflow='',
  262. dataset_meta=instance.dataset_metadata,
  263. llm_used=EVALUATION_LLM,
  264. dataset_type='real',
  265. )
  266. test_result['eval_rec'] = eval_rec
  267. output = EvalOutput(
  268. instance_id=str(instance.instance_id),
  269. instruction=instruction,
  270. metadata=metadata,
  271. history=histories,
  272. metrics=metrics,
  273. error=state.last_error if state and state.last_error else None,
  274. test_result=test_result,
  275. )
  276. return output
  277. def update_csv_name(name):
  278. name = name.replace('-', '_')
  279. if 'meta_regression' in name:
  280. name = name.replace('meta_regression', 'meta-regression')
  281. if 'ML_enabled' in name:
  282. name = name.replace('ML_enabled', 'ML-enabled')
  283. return name
  284. def list_csv_files(list_of_datasets):
  285. res = []
  286. for ele in list_of_datasets:
  287. for key, value in ele.items():
  288. if key == 'name':
  289. csv_file_name = update_csv_name(value)
  290. res.append(DATA_FILES[csv_file_name])
  291. return res
  292. def create_dataset(repo_location: str, split: str = 'test'):
  293. """
  294. Create a dataset from the discoverybench repository
  295. by walking through the repository and extracting metadata
  296. from the metadata_{}.json files
  297. Args:
  298. repo_location: Location of the repository
  299. split: Split of the dataset to use
  300. Returns:
  301. df: DataFrame containing the dataset instances
  302. """
  303. data_dict = {}
  304. data_location = os.path.join(repo_location, 'discoverybench', 'real', split)
  305. answer_key_location = os.path.join(repo_location, 'eval', 'answer_key_real.csv')
  306. idx = 0
  307. for root, dirs, files in os.walk(data_location):
  308. for file in files:
  309. if file.endswith('.json'):
  310. if 'metadata' in file:
  311. metadata = json.load(open(os.path.join(root, file)))
  312. dataset = root.split('/')[-1]
  313. metadata_id = file.split('_')[-1].split('.')[0]
  314. domain = metadata.get('domain', '')
  315. domain_knowledge = metadata.get('domain_knowledge', '')
  316. workflow_tags = metadata.get('workflow_tags', '')
  317. datasets = metadata.get('datasets', [])
  318. queries = metadata.get('queries', [])
  319. gold_workflow = metadata.get('workflow')
  320. # loop through queries list to get queries
  321. # and each query has qid; add that to dictionary
  322. for query in queries[0]:
  323. qid = query.get('qid', '')
  324. data = {
  325. 'dataset': dataset,
  326. 'metadata_id': metadata_id,
  327. 'qid': qid,
  328. 'domain': domain,
  329. 'domain_knowledge': domain_knowledge,
  330. 'workflow_tags': workflow_tags,
  331. 'datasets': datasets,
  332. 'question_type': query['question_type'],
  333. 'query': query['question'],
  334. 'gold_workflow': gold_workflow,
  335. 'dataset_metadata': metadata,
  336. }
  337. data_dict[idx] = data
  338. idx += 1
  339. if file.endswith('.csv'):
  340. DATA_FILES[file] = os.path.join(root, file)
  341. if file.endswith('.txt'):
  342. DATA_FILES[file] = os.path.join(root, file)
  343. df = pd.DataFrame.from_dict(data_dict, orient='index')
  344. df['instance_id'] = df.index
  345. df['data_files'] = df['datasets'].apply(lambda x: list_csv_files(x))
  346. answer_key = pd.read_csv(answer_key_location)
  347. answer_key = answer_key.rename(
  348. columns={
  349. 'metadataid': 'metadata_id',
  350. 'query_id': 'qid',
  351. 'gold_hypothesis': 'gold_hypothesis',
  352. }
  353. )
  354. df['qid'] = df['qid'].astype(int)
  355. df['metadata_id'] = df['metadata_id'].astype(int)
  356. answer_key['qid'] = answer_key['qid'].astype(int)
  357. answer_key['metadata_id'] = answer_key['metadata_id'].astype(int)
  358. df = pd.merge(df, answer_key, on=['dataset', 'metadata_id', 'qid'], how='left')
  359. return df
  360. if __name__ == '__main__':
  361. args = parse_arguments()
  362. # clone git repositor for csv files
  363. repo_url = 'https://github.com/allenai/discoverybench.git'
  364. repo_location = 'git-discoverybench-allenai'
  365. try:
  366. git.Repo.clone_from(repo_url, repo_location)
  367. except git.exc.GitCommandError:
  368. print('Repository already exists')
  369. dataset = create_dataset(repo_location)
  370. # check if there is any empty csv_file
  371. if dataset['data_files'].isnull().any():
  372. raise ValueError('Some csv files are missing.')
  373. llm_config = None
  374. if args.llm_config:
  375. llm_config = get_llm_config_arg(args.llm_config)
  376. if llm_config is None:
  377. raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
  378. metadata = make_metadata(
  379. llm_config,
  380. 'discoverybench-python',
  381. args.agent_cls,
  382. args.max_iterations,
  383. args.eval_note,
  384. args.eval_output_dir,
  385. )
  386. output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
  387. instances = prepare_dataset(dataset, output_file, args.eval_n_limit)
  388. run_evaluation(
  389. instances,
  390. metadata,
  391. output_file,
  392. args.eval_num_workers,
  393. process_instance,
  394. )