shared.py

import json
import logging
import multiprocessing as mp
import os
import pathlib
import signal
import subprocess
import time
import traceback
from contextlib import contextmanager
from inspect import signature
from typing import Any, Awaitable, Callable, TextIO

import pandas as pd
from pydantic import BaseModel
from tqdm import tqdm

from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.exceptions import (
    AgentRuntimeBuildError,
    AgentRuntimeDisconnectedError,
    AgentRuntimeError,
    AgentRuntimeNotFoundError,
    AgentRuntimeNotReadyError,
    AgentRuntimeTimeoutError,
    AgentRuntimeUnavailableError,
)
from openhands.core.logger import get_console_handler
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import Action
from openhands.events.action.message import MessageAction
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_dict
from openhands.events.utils import get_pairs_from_events


class EvalMetadata(BaseModel):
    agent_class: str
    llm_config: LLMConfig
    max_iterations: int
    eval_output_dir: str
    start_time: str
    git_commit: str
    dataset: str | None = None
    data_split: str | None = None
    details: dict[str, Any] | None = None

    def model_dump(self, *args, **kwargs):
        dumped_dict = super().model_dump(*args, **kwargs)
        # avoid leaking sensitive information
        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
        return dumped_dict

    def model_dump_json(self, *args, **kwargs):
        dumped = super().model_dump_json(*args, **kwargs)
        dumped_dict = json.loads(dumped)
        # avoid leaking sensitive information
        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
        logger.debug(f'Dumped metadata: {dumped_dict}')
        return json.dumps(dumped_dict)


class EvalOutput(BaseModel):
    # NOTE: User-specified
    instance_id: str
    # Output of the evaluation: store anything that is needed for the score calculation
    test_result: dict[str, Any]
    instruction: str | None = None
    # Interaction info
    metadata: EvalMetadata | None = None
    # list[tuple[dict[str, Any], dict[str, Any]]] - for compatibility with the old format
    history: (
        list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
    ) = None
    metrics: dict[str, Any] | None = None
    error: str | None = None
    # Optionally save the input test instance
    instance: dict[str, Any] | None = None

    def model_dump(self, *args, **kwargs):
        dumped_dict = super().model_dump(*args, **kwargs)
        # Remove None values
        dumped_dict = {k: v for k, v in dumped_dict.items() if v is not None}
        # Apply custom serialization for metadata (to avoid leaking sensitive information)
        if self.metadata is not None:
            dumped_dict['metadata'] = self.metadata.model_dump()
        return dumped_dict

    def model_dump_json(self, *args, **kwargs):
        dumped = super().model_dump_json(*args, **kwargs)
        dumped_dict = json.loads(dumped)
        # Apply custom serialization for metadata (to avoid leaking sensitive information)
        if 'metadata' in dumped_dict:
            dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
        return json.dumps(dumped_dict)


class EvalException(Exception):
    pass


class EvalTimeoutException(Exception):
    pass


@contextmanager
def timeout(seconds: int):
    def timeout_handler(signum, frame):
        raise EvalTimeoutException(f'Function timed out after {seconds} seconds')

    # Set up the signal handler
    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)

    try:
        yield
    finally:
        # Restore the original handler and disable the alarm
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)
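
# Usage sketch for the `timeout` context manager (illustrative; `slow_call` is
# a hypothetical stand-in for any long-running step). Note that SIGALRM-based
# timeouts only work in the main thread of a POSIX process:
#
#     try:
#         with timeout(60):
#             slow_call()
#     except EvalTimeoutException:
#         logger.warning('Step timed out after 60 seconds')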


def codeact_user_response(
    state: State,
    encapsulate_solution: bool = False,
    try_parse: Callable[[Action], str] | None = None,
) -> str:
    encaps_str = (
        (
            'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
            'For example: The answer to the question is <solution> 42 </solution>.\n'
        )
        if encapsulate_solution
        else ''
    )
    msg = (
        'Please continue working on the task using whatever approach you think is suitable.\n'
        'If you think you have solved the task, please first send your answer to the user through a message and then finish the interaction.\n'
        f'{encaps_str}'
        'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
    )

    if state.history:
        # If the last agent action already contains a parsable answer, exit early.
        if try_parse is not None:
            last_action = next(
                (
                    event
                    for event in reversed(state.history)
                    if isinstance(event, Action)
                ),
                None,
            )
            ans = try_parse(last_action)
            if ans is not None:
                return '/exit'

        # If the agent has already been prompted twice (i.e. this would be the
        # third user response), let it know that it is allowed to give up.
        user_msgs = [
            event
            for event in state.history
            if isinstance(event, MessageAction) and event.source == 'user'
        ]
        if len(user_msgs) >= 2:
            return (
                msg
                + 'If you want to give up, use the "finish" tool to finish the interaction.\n'
            )
    return msg
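
# Wiring sketch: in OpenHands evaluation scripts this function is typically
# passed to `run_controller` as the fake user response function (illustrative;
# `config` and `instruction` are assumed to be built elsewhere):
#
#     import asyncio
#     from functools import partial
#     from openhands.core.main import run_controller
#
#     state = asyncio.run(
#         run_controller(
#             config=config,
#             initial_user_action=MessageAction(content=instruction),
#             fake_user_response_fn=partial(
#                 codeact_user_response, encapsulate_solution=True
#             ),
#         )
#     )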


def cleanup():
    print('Cleaning up child processes...')
    for process in mp.active_children():
        print(f'Terminating child process: {process.name}')
        process.terminate()
        process.join()


def make_metadata(
    llm_config: LLMConfig,
    dataset_name: str,
    agent_class: str,
    max_iterations: int,
    eval_note: str | None,
    eval_output_dir: str,
    data_split: str | None = None,
    details: dict[str, Any] | None = None,
) -> EvalMetadata:
    model_name = llm_config.model.split('/')[-1]
    model_path = model_name.replace(':', '_').replace('@', '-')
    eval_note = f'_N_{eval_note}' if eval_note else ''

    eval_output_path = os.path.join(
        eval_output_dir,
        dataset_name,
        agent_class,
        f'{model_path}_maxiter_{max_iterations}{eval_note}',
    )

    pathlib.Path(eval_output_path).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(eval_output_path, 'logs')).mkdir(
        parents=True, exist_ok=True
    )
    logger.info(f'Using evaluation output directory: {eval_output_path}')

    metadata = EvalMetadata(
        agent_class=agent_class,
        llm_config=llm_config,
        max_iterations=max_iterations,
        eval_output_dir=eval_output_path,
        start_time=time.strftime('%Y-%m-%d %H:%M:%S'),
        git_commit=subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        .decode('utf-8')
        .strip(),
        dataset=dataset_name,
        data_split=data_split,
        details=details,
    )
    metadata_json = metadata.model_dump_json()
    logger.info(f'Metadata: {metadata_json}')
    with open(os.path.join(eval_output_path, 'metadata.json'), 'w') as f:
        f.write(metadata_json)
    return metadata
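
# Call sketch (illustrative values; `llm_config` is assumed to come from the
# evaluation script's argument parsing):
#
#     metadata = make_metadata(
#         llm_config=llm_config,
#         dataset_name='swe-bench-lite',
#         agent_class='CodeActAgent',
#         max_iterations=30,
#         eval_note=None,
#         eval_output_dir='evaluation/evaluation_outputs',
#     )
#     # -> writes metadata.json under
#     #    evaluation/evaluation_outputs/swe-bench-lite/CodeActAgent/<model>_maxiter_30/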


def prepare_dataset(
    dataset: pd.DataFrame,
    output_file: str,
    eval_n_limit: int,
    eval_ids: list[str] | None = None,
    skip_num: int | None = None,
):
    assert (
        'instance_id' in dataset.columns
    ), "Expected 'instance_id' column in the dataset. You should define your own unique identifier for each instance and use it as the 'instance_id' column."
    id_column = 'instance_id'
    logger.info(f'Writing evaluation output to {output_file}')

    finished_ids: set[str] = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                finished_ids.add(str(data[id_column]))
        logger.warning(
            f'\nOutput file {output_file} already exists. Loaded {len(finished_ids)} finished instances.'
        )

    if eval_ids:
        eval_ids_converted = [dataset[id_column].dtype.type(id) for id in eval_ids]
        dataset = dataset[dataset[id_column].isin(eval_ids_converted)]
        logger.info(f'Limiting evaluation to {len(eval_ids)} specific instances.')
    elif skip_num and skip_num >= 0:
        skip_num = min(skip_num, len(dataset))
        dataset = dataset.iloc[skip_num:]
        logger.info(
            f'Starting evaluation, skipping the first {skip_num} instances ({len(dataset)} instances to run).'
        )
        if eval_n_limit and eval_n_limit > 0:
            dataset = dataset.head(eval_n_limit)
            logger.info(f'Limiting evaluation to {eval_n_limit} instances.')
    elif eval_n_limit and eval_n_limit > 0:
        dataset = dataset.head(eval_n_limit)
        logger.info(f'Limiting evaluation to first {eval_n_limit} instances.')

    new_dataset = [
        instance
        for _, instance in dataset.iterrows()
        if str(instance[id_column]) not in finished_ids
    ]
    logger.info(
        f'Finished instances: {len(finished_ids)}, Remaining instances: {len(new_dataset)}'
    )
    return pd.DataFrame(new_dataset)
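
# Resumption sketch: `prepare_dataset` drops instances whose ids already appear
# in the output file, so re-running a partially finished evaluation is safe
# (file names are illustrative):
#
#     dataset = pd.read_json('instances.jsonl', lines=True)
#     remaining = prepare_dataset(
#         dataset,
#         output_file=os.path.join(metadata.eval_output_dir, 'output.jsonl'),
#         eval_n_limit=100,
#     )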


def update_progress(
    result: EvalOutput,
    pbar: tqdm,
    output_fp: TextIO,
):
    """Update the progress bar and write the result to the output file."""
    pbar.update(1)
    pbar.set_description(f'Instance {result.instance_id}')
    pbar.set_postfix_str(f'Test Result: {str(result.test_result)[:300]}...')
    logger.info(
        f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
    )
    output_fp.write(json.dumps(result.model_dump()) + '\n')
    output_fp.flush()


def assert_and_raise(condition: bool, msg: str):
    """Raise an EvalException if the condition is not met.

    This is used in conjunction with _process_instance_wrapper to handle retries:
    an EvalException triggers a retry.
    """
    if not condition:
        raise EvalException(msg)
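
# Usage sketch inside a per-instance evaluation function (illustrative;
# `runtime.run_action` and `obs` follow the usual OpenHands runtime pattern):
#
#     from openhands.events.action import CmdRunAction
#
#     obs = runtime.run_action(CmdRunAction(command='git diff'))
#     assert_and_raise(obs.exit_code == 0, f'git diff failed: {obs.content}')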


def _process_instance_wrapper(
    process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
    instance: pd.Series,
    metadata: EvalMetadata,
    use_mp: bool,
    max_retries: int = 5,
    timeout_seconds: int | None = None,
) -> EvalOutput:
    """Wrap process_instance_func to handle retries and errors."""
    runtime_failure_count = 0
    for attempt in range(max_retries + 1):
        try:
            kwargs = {}
            # Check if process_instance_func accepts a runtime_failure_count parameter
            sig = signature(process_instance_func)
            if 'runtime_failure_count' in sig.parameters:
                kwargs['runtime_failure_count'] = runtime_failure_count

            if timeout_seconds is not None:
                with timeout(timeout_seconds):
                    result = process_instance_func(instance, metadata, use_mp, **kwargs)
            else:
                result = process_instance_func(instance, metadata, use_mp, **kwargs)
            return result
        except EvalTimeoutException as e:
            error = f'Timeout after {timeout_seconds} seconds'
            msg = (
                '-' * 10
                + '\n'
                + f'Timeout ({timeout_seconds} seconds) in instance [{instance.instance_id}]. Stopped evaluation for this instance.'
                + '\n'
                + '-' * 10
            )
            logger.error(msg)
            logger.exception(e)
            return EvalOutput(
                instance_id=instance.instance_id,
                test_result={},
                error=error,
            )
        except Exception as e:
            error = str(e)
            stacktrace = traceback.format_exc()
            if attempt == max_retries:
                msg = (
                    '-' * 10
                    + '\n'
                    + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
                    + '\n'
                    + f'[Encountered after {max_retries} retries. Please check the logs and report the issue.]'
                    + '\n'
                    + '-' * 10
                )
                # Raise an error after all retries & stop the evaluation
                logger.exception(e)
                raise RuntimeError(
                    f'Maximum error retries reached for instance {instance.instance_id}'
                ) from e
            msg = (
                '-' * 10
                + '\n'
                + f'Error in instance [{instance.instance_id}]: {error}. Stacktrace:\n{stacktrace}'
                + '\n'
                + '-' * 10
                + '\n'
                + f'[The above error occurred. Retrying... (attempt {attempt + 1} of {max_retries})]'
                + '\n'
                + '-' * 10
                + '\n'
            )
            if isinstance(
                e, (AgentRuntimeDisconnectedError, AgentRuntimeUnavailableError)
            ):
                runtime_failure_count += 1
                msg += f'Runtime disconnected error detected for instance {instance.instance_id}, runtime failure count: {runtime_failure_count}'
            logger.error(msg)
            if use_mp:
                print(msg)  # use print to directly print to console
            time.sleep(5)


def _process_instance_wrapper_mp(args):
    """Wrapper for multiprocessing, especially for imap_unordered."""
    return _process_instance_wrapper(*args)


def run_evaluation(
    dataset: pd.DataFrame,
    metadata: EvalMetadata | None,
    output_file: str,
    num_workers: int,
    process_instance_func: Callable[
        [pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
    ],
    max_retries: int = 5,  # number of retries for each instance
    timeout_seconds: int | None = None,
):
    use_multiprocessing = num_workers > 1

    if metadata is not None:
        logger.info(
            f'Evaluation started with Agent {metadata.agent_class}:\n'
            f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
        )
    else:
        logger.warning('Running evaluation without metadata.')

    logger.info(f'Evaluation started with {num_workers} workers.')
    total_instances = len(dataset)
    pbar = tqdm(total=total_instances, desc='Instances processed')
    output_fp = open(output_file, 'a')

    try:
        if use_multiprocessing:
            with mp.Pool(num_workers) as pool:
                args_iter = (
                    (
                        process_instance_func,
                        instance,
                        metadata,
                        True,
                        max_retries,
                        timeout_seconds,
                    )
                    for _, instance in dataset.iterrows()
                )
                results = pool.imap_unordered(_process_instance_wrapper_mp, args_iter)
                for result in results:
                    update_progress(result, pbar, output_fp)
        else:
            for _, instance in dataset.iterrows():
                result = _process_instance_wrapper(
                    process_instance_func=process_instance_func,
                    instance=instance,
                    metadata=metadata,
                    use_mp=False,
                    max_retries=max_retries,
                    timeout_seconds=timeout_seconds,
                )
                update_progress(result, pbar, output_fp)
    except KeyboardInterrupt:
        print('\nKeyboardInterrupt received. Cleaning up...\n')
        cleanup()

    output_fp.close()
    logger.info('\nEvaluation finished.\n')
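
# End-to-end sketch of a typical evaluation script built on these helpers
# (illustrative; `process_instance` is the benchmark-specific function each
# script defines, and argument values are placeholders):
#
#     metadata = make_metadata(
#         llm_config, 'my-benchmark', 'CodeActAgent',
#         max_iterations=30, eval_note=None,
#         eval_output_dir='evaluation/evaluation_outputs',
#     )
#     output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
#     instances = prepare_dataset(dataset, output_file, eval_n_limit=50)
#     run_evaluation(instances, metadata, output_file, num_workers=4,
#                    process_instance_func=process_instance)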


def reset_logger_for_multiprocessing(
    logger: logging.Logger, instance_id: str, log_dir: str
):
    """Reset the logger for multiprocessing.

    Save logs to a separate file for each process, instead of trying to write to the
    same file/console from multiple processes.
    """
    # Set up logger
    log_file = os.path.join(
        log_dir,
        f'instance_{instance_id}.log',
    )
    # Remove all existing handlers from logger
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # Add a console handler, print ONE start-up line at INFO, then raise the
    # console level to WARNING so per-instance chatter goes to the file only.
    console_handler = get_console_handler(log_level=logging.INFO)
    console_handler.setFormatter(
        logging.Formatter(
            f'Instance {instance_id} - ' + '%(asctime)s - %(levelname)s - %(message)s'
        )
    )
    logger.addHandler(console_handler)
    logger.info(
        f'Starting evaluation for instance {instance_id}.\n'
        f'Hint: run "tail -f {log_file}" to see live logs in a separate shell'
    )
    # Only log WARNING or higher to console
    console_handler.setLevel(logging.WARNING)

    # Log INFO and above to file
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
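
# Usage sketch at the top of a per-instance function running under
# multiprocessing (illustrative; the 'logs' directory layout follows
# make_metadata above):
#
#     def process_instance(instance, metadata, reset_logger=True):
#         if reset_logger:
#             log_dir = os.path.join(metadata.eval_output_dir, 'logs')
#             reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
#         ...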


def update_llm_config_for_completions_logging(
    llm_config: LLMConfig,
    eval_output_dir: str,
    instance_id: str,
) -> LLMConfig:
    """Update the LLM config for logging completions."""
    if llm_config.log_completions:
        llm_config.log_completions_folder = os.path.join(
            eval_output_dir, 'llm_completions', instance_id
        )
        logger.info(
            f'Logging LLM completions for instance {instance_id} to '
            f'{llm_config.log_completions_folder}'
        )
    return llm_config
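
# Call sketch (illustrative): point each instance's completion logs at its own
# folder before constructing the agent's LLM:
#
#     llm_config = update_llm_config_for_completions_logging(
#         metadata.llm_config, metadata.eval_output_dir, instance.instance_id
#     )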


# History is now available as a filtered stream of events, rather than a list of
# (Action, Observation) pairs. We rebuild the pairs here for compatibility with
# the existing output format in evaluations.
# TODO: remove this when it's no longer necessary.
def compatibility_for_eval_history_pairs(
    history: list[Event],
) -> list[tuple[dict, dict]]:
    history_pairs = []
    for action, observation in get_pairs_from_events(history):
        history_pairs.append((event_to_dict(action), event_to_dict(observation)))
    return history_pairs
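
# Usage sketch when assembling an EvalOutput (illustrative; `state` is assumed
# to be the State returned by the controller run, and `state.metrics.get()`
# follows the usual OpenHands pattern):
#
#     output = EvalOutput(
#         instance_id=instance.instance_id,
#         test_result=test_result,
#         history=compatibility_for_eval_history_pairs(state.history),
#         metrics=state.metrics.get() if state.metrics else None,
#         metadata=metadata,
#     )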


def is_fatal_evaluation_error(error: str | None) -> bool:
    if not error:
        return False

    FATAL_EXCEPTIONS = [
        AgentRuntimeError,
        AgentRuntimeBuildError,
        AgentRuntimeTimeoutError,
        AgentRuntimeUnavailableError,
        AgentRuntimeNotReadyError,
        AgentRuntimeDisconnectedError,
        AgentRuntimeNotFoundError,
    ]
    if any(exception.__name__ in error for exception in FATAL_EXCEPTIONS):
        logger.error(f'Fatal evaluation error detected: {error}')
        return True
    return False
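
# Check sketch: scan an existing output.jsonl and flag instances that failed on
# runtime (infrastructure) errors rather than on the task itself (illustrative):
#
#     with open(output_file) as f:
#         for line in f:
#             out = json.loads(line)
#             if is_fatal_evaluation_error(out.get('error')):
#                 print('fatal runtime error:', out['instance_id'])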