shared.py

import json
import logging
import multiprocessing as mp
import os
import pathlib
import subprocess
import time
import traceback
from typing import Any, Callable, TextIO

import pandas as pd
from pydantic import BaseModel
from tqdm import tqdm

from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.logger import get_console_handler
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import Action
from openhands.events.action.message import MessageAction


class EvalMetadata(BaseModel):
    agent_class: str
    llm_config: LLMConfig
    max_iterations: int
    eval_output_dir: str
    start_time: str
    git_commit: str
    dataset: str | None = None
    data_split: str | None = None
    details: dict[str, Any] | None = None

    def model_dump(self, *args, **kwargs):
        dumped_dict = super().model_dump(*args, **kwargs)
        # avoid leaking sensitive information
        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
        return dumped_dict

    def model_dump_json(self, *args, **kwargs):
        dumped = super().model_dump_json(*args, **kwargs)
        dumped_dict = json.loads(dumped)
        logger.debug(f'Dumped metadata: {dumped_dict}')
        # avoid leaking sensitive information
        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
        return json.dumps(dumped_dict)


class EvalOutput(BaseModel):
    # NOTE: User-specified
    instance_id: str
    # Output of the evaluation: store anything that is needed for score calculation
    test_result: dict[str, Any]

    instruction: str | None = None

    # Interaction info
    metadata: EvalMetadata | None = None
    # list[tuple[dict[str, Any], dict[str, Any]]] - for compatibility with the old format
    history: (
        list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
    ) = None
    llm_completions: list[dict[str, Any]]
    metrics: dict[str, Any] | None = None
    error: str | None = None

    # Optionally save the input test instance
    instance: dict[str, Any] | None = None

    def model_dump(self, *args, **kwargs):
        dumped_dict = super().model_dump(*args, **kwargs)
        # Remove None values
        dumped_dict = {k: v for k, v in dumped_dict.items() if v is not None}
        # Apply custom serialization for metadata (to avoid leaking sensitive information)
        if self.metadata is not None:
            dumped_dict['metadata'] = self.metadata.model_dump()
        return dumped_dict

    def model_dump_json(self, *args, **kwargs):
        dumped = super().model_dump_json(*args, **kwargs)
        dumped_dict = json.loads(dumped)
        # Apply custom serialization for metadata (to avoid leaking sensitive information)
        if self.metadata is not None:
            dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
        return json.dumps(dumped_dict)
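

# A minimal usage sketch for the models above (field values are purely illustrative):
#
#   output = EvalOutput(
#       instance_id='example-001',
#       test_result={'resolved': True},
#       llm_completions=[],
#       metadata=metadata,
#   )
#   line = output.model_dump_json()  # llm_config inside metadata is replaced by its safe dict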


def codeact_user_response(
    state: State,
    encapsulate_solution: bool = False,
    try_parse: Callable[[Action], str] | None = None,
) -> str:
    encaps_str = (
        (
            'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
            'For example: The answer to the question is <solution> 42 </solution>.\n'
        )
        if encapsulate_solution
        else ''
    )
    msg = (
        'Please continue working on the task using whatever approach you think is suitable.\n'
        'If you think you have solved the task, please first send your answer to the user through a message and then <execute_bash> exit </execute_bash>.\n'
        f'{encaps_str}'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
    )
    if state.history:
        # if the last action already contains a parseable answer, exit early
        if try_parse is not None:
            last_action = state.history.get_last_action()
            ans = try_parse(last_action)
            if ans is not None:
                return '/exit'
        # if the agent has already tried to talk to the user several times,
        # let it know that it is allowed to give up
        user_msgs = [
            event
            for event in state.history.get_events()
            if isinstance(event, MessageAction) and event.source == 'user'
        ]
        if len(user_msgs) >= 2:
            return (
                msg
                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
            )
    return msg
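

# How this is typically wired up (a sketch; the surrounding names are benchmark-specific
# and illustrative): a run_infer script passes this function as the automatic "fake user"
# reply so the agent keeps working instead of waiting for human input, e.g.:
#
#   import functools
#   fake_user_response_fn = functools.partial(codeact_user_response, encapsulate_solution=True)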


def cleanup():
    print('Cleaning up child processes...')
    for process in mp.active_children():
        print(f'Terminating child process: {process.name}')
        process.terminate()
        process.join()


def make_metadata(
    llm_config: LLMConfig,
    dataset_name: str,
    agent_class: str,
    max_iterations: int,
    eval_note: str | None,
    eval_output_dir: str,
    data_split: str | None = None,
    details: dict[str, Any] | None = None,
) -> EvalMetadata:
    model_name = llm_config.model.split('/')[-1]
    model_path = model_name.replace(':', '_')
    eval_note = f'_N_{eval_note}' if eval_note else ''
    eval_output_path = os.path.join(
        eval_output_dir,
        dataset_name,
        agent_class,
        f'{model_path}_maxiter_{max_iterations}{eval_note}',
    )
    pathlib.Path(eval_output_path).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(eval_output_path, 'logs')).mkdir(
        parents=True, exist_ok=True
    )
    logger.info(f'Using evaluation output directory: {eval_output_path}')
    metadata = EvalMetadata(
        agent_class=agent_class,
        llm_config=llm_config,
        max_iterations=max_iterations,
        eval_output_dir=eval_output_path,
        start_time=time.strftime('%Y-%m-%d %H:%M:%S'),
        git_commit=subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        .decode('utf-8')
        .strip(),
        dataset=dataset_name,
        data_split=data_split,
        details=details,
    )
    metadata_json = metadata.model_dump_json()
    logger.info(f'Metadata: {metadata_json}')
    with open(os.path.join(eval_output_path, 'metadata.json'), 'w') as f:
        f.write(metadata_json)
    return metadata
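

# Example call (paths and names are illustrative, not from a real run):
#
#   metadata = make_metadata(
#       llm_config=llm_config,
#       dataset_name='my_benchmark',
#       agent_class='CodeActAgent',
#       max_iterations=30,
#       eval_note=None,
#       eval_output_dir='evaluation/outputs',
#   )
#   # writes evaluation/outputs/my_benchmark/CodeActAgent/<model>_maxiter_30/metadata.json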


def prepare_dataset(
    dataset: pd.DataFrame,
    output_file: str,
    eval_n_limit: int,
    eval_ids: list[str] | None = None,
    skip_num: int | None = None,
):
    assert (
        'instance_id' in dataset.columns
    ), "Expected 'instance_id' column in the dataset. You should define your own unique identifier for each instance and use it as the 'instance_id' column."
    id_column = 'instance_id'
    logger.info(f'Writing evaluation output to {output_file}')
    finished_ids: set[str] = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                finished_ids.add(str(data[id_column]))
        logger.warning(
            f'\nOutput file {output_file} already exists. Loaded {len(finished_ids)} finished instances.'
        )
    if eval_ids:
        eval_ids_converted = [dataset[id_column].dtype.type(id) for id in eval_ids]
        dataset = dataset[dataset[id_column].isin(eval_ids_converted)]
        logger.info(f'Limiting evaluation to {len(eval_ids)} specific instances.')
    elif skip_num and skip_num >= 0:
        skip_num = min(skip_num, len(dataset))
        dataset = dataset.iloc[skip_num:]
        logger.info(
            f'Starting evaluation, skipping the first {skip_num} instances ({len(dataset)} instances to run).'
        )
        if eval_n_limit and eval_n_limit > 0:
            dataset = dataset.head(eval_n_limit)
            logger.info(f'Limiting evaluation to {eval_n_limit} instances.')
    elif eval_n_limit and eval_n_limit > 0:
        dataset = dataset.head(eval_n_limit)
        logger.info(f'Limiting evaluation to the first {eval_n_limit} instances.')
    new_dataset = [
        instance
        for _, instance in dataset.iterrows()
        if str(instance[id_column]) not in finished_ids
    ]
    logger.info(
        f'Finished instances: {len(finished_ids)}, Remaining instances: {len(new_dataset)}'
    )
    return pd.DataFrame(new_dataset)
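

# Example (illustrative): resume a partially finished run, capped at 50 instances.
#
#   dataset = pd.read_json('instances.jsonl', lines=True)
#   output_file = 'evaluation/outputs/my_benchmark/output.jsonl'
#   remaining = prepare_dataset(dataset, output_file, eval_n_limit=50)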


def update_progress(
    result: EvalOutput,
    pbar: tqdm,
    output_fp: TextIO,
):
    """Update the progress bar and write the result to the output file."""
    pbar.update(1)
    pbar.set_description(f'Instance {result.instance_id}')
    pbar.set_postfix_str(f'Test Result: {result.test_result}')
    logger.info(
        f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
    )
    output_fp.write(json.dumps(result.model_dump()) + '\n')
    output_fp.flush()


def _process_instance_wrapper(
    process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
    instance: pd.Series,
    metadata: EvalMetadata,
    use_mp: bool,
    max_retries: int = 5,
) -> EvalOutput:
    """Wrap process_instance_func to handle retries and errors.

    Retry an instance up to max_retries times if it fails (e.g., due to transient
    network/runtime issues).
    """
    for attempt in range(max_retries + 1):
        try:
            result = process_instance_func(instance, metadata, use_mp)
            return result
        except Exception as e:
            error = str(e)
            stacktrace = traceback.format_exc()
            if attempt == max_retries:
                msg = (
                    '-' * 10
                    + '\n'
                    + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
                    + '\n'
                    + f'[Encountered after {max_retries} retries. Please check the logs and report the issue.]'
                    + '-' * 10
                )
                # Raise an error after all retries & stop the evaluation
                logger.exception(e)
                logger.error(msg)
                raise RuntimeError(
                    f'Maximum error retries reached for instance {instance.instance_id}'
                ) from e
            msg = (
                '-' * 10
                + '\n'
                + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
                + '\n'
                + '-' * 10
                + f'[The above error occurred. Retrying... (attempt {attempt + 1} of {max_retries})]'
                + '-' * 10
                + '\n'
            )
            logger.error(msg)
            if use_mp:
                print(msg)  # use print to write directly to the console
            time.sleep(5)


def _process_instance_wrapper_mp(args):
    """Wrapper for multiprocessing, especially for imap_unordered."""
    return _process_instance_wrapper(*args)


def run_evaluation(
    dataset: pd.DataFrame,
    metadata: EvalMetadata | None,
    output_file: str,
    num_workers: int,
    process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
    max_retries: int = 5,  # number of retries for each instance
):
    use_multiprocessing = num_workers > 1
    if metadata is not None:
        logger.info(
            f'Evaluation started with Agent {metadata.agent_class}:\n'
            f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
        )
    else:
        logger.info(f'Evaluation started with {num_workers} workers.')
    total_instances = len(dataset)
    pbar = tqdm(total=total_instances, desc='Instances processed')
    output_fp = open(output_file, 'a')
    try:
        if use_multiprocessing:
            with mp.Pool(num_workers) as pool:
                args_iter = (
                    (process_instance_func, instance, metadata, True, max_retries)
                    for _, instance in dataset.iterrows()
                )
                results = pool.imap_unordered(_process_instance_wrapper_mp, args_iter)
                for result in results:
                    update_progress(result, pbar, output_fp)
        else:
            for _, instance in dataset.iterrows():
                result = _process_instance_wrapper(
                    process_instance_func=process_instance_func,
                    instance=instance,
                    metadata=metadata,
                    use_mp=False,
                    max_retries=max_retries,
                )
                update_progress(result, pbar, output_fp)
    except KeyboardInterrupt:
        print('\nKeyboardInterrupt received. Cleaning up...\n')
        cleanup()
    output_fp.close()
    logger.info('\nEvaluation finished.\n')
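

# Putting it together: a minimal driver sketch. `process_instance` and the paths are
# hypothetical and benchmark-specific; only the helpers defined in this file are real.
#
#   metadata = make_metadata(llm_config, 'my_benchmark', 'CodeActAgent', 30, None, 'evaluation/outputs')
#   output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
#   instances = prepare_dataset(full_dataset, output_file, eval_n_limit=100)
#   run_evaluation(instances, metadata, output_file, num_workers=4,
#                  process_instance_func=process_instance)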


def reset_logger_for_multiprocessing(
    logger: logging.Logger, instance_id: str, log_dir: str
):
    """Reset the logger for multiprocessing.

    Save logs to a separate file for each process, instead of trying to write to the
    same file/console from multiple processes.
    """
    # Set up the per-instance log file
    log_file = os.path.join(
        log_dir,
        f'instance_{instance_id}.log',
    )
    # Remove all existing handlers from the logger
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    # Add a console handler to print ONE line
    console_handler = get_console_handler(log_level=logging.INFO)
    console_handler.setFormatter(
        logging.Formatter(
            f'Instance {instance_id} - ' + '%(asctime)s - %(levelname)s - %(message)s'
        )
    )
    logger.addHandler(console_handler)
    logger.info(
        f'Starting evaluation for instance {instance_id}.\n'
        f'Hint: run "tail -f {log_file}" to see live logs in a separate shell'
    )
    # After that one line, only log WARNING or higher to the console
    console_handler.setLevel(logging.WARNING)
    # Log INFO and above to the file
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
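

# Typical use inside a benchmark's per-instance worker (a sketch; names are illustrative):
#
#   def process_instance(instance, metadata, reset_logger=True):
#       if reset_logger:
#           log_dir = os.path.join(metadata.eval_output_dir, 'logs')
#           reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
#       ...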