run_infer.py

import asyncio
import importlib.util
import os

import pandas as pd

from evaluation.integration_tests.tests.base import BaseIntegrationTest, TestResult
from evaluation.utils.shared import (
    EvalMetadata,
    EvalOutput,
    codeact_user_response,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
    update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
    AgentConfig,
    AppConfig,
    SandboxConfig,
    get_llm_config_arg,
    parse_arguments,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
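
# Canned follow-up responses, keyed by agent class. `codeact_user_response` acts as the
# fake user so the agent can keep iterating on the task without a human in the loop.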
FAKE_RESPONSES = {
    'CodeActAgent': codeact_user_response,
}
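

# Build the per-instance AppConfig: sandbox settings, the LLM config (with completion
# logging routed under the eval output directory), and the CodeAct agent's capabilities.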
def get_config(
    metadata: EvalMetadata,
    instance_id: str,
) -> AppConfig:
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime=os.environ.get('RUNTIME', 'eventstream'),
        max_iterations=metadata.max_iterations,
        sandbox=SandboxConfig(
            # use default base_container_image
            enable_auto_lint=True,
            use_host_network=False,
            timeout=300,
            # Add platform to the sandbox config to solve issue 4401
            platform='linux/amd64',
            api_key=os.environ.get('ALLHANDS_API_KEY', None),
            remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
            keep_runtime_alive=False,
            remote_runtime_init_timeout=3600,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
        # debug
        debug=True,
    )
    config.set_llm_config(
        update_llm_config_for_completions_logging(
            metadata.llm_config, metadata.eval_output_dir, instance_id
        )
    )
    agent_config = AgentConfig(
        codeact_enable_jupyter=True,
        codeact_enable_browsing=True,
        codeact_enable_llm_editor=False,
    )
    config.set_agent_config(agent_config)
    return config
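

# Run a single integration test end to end: import the test module, spin up a sandbox
# runtime, drive the agent with the test's INSTRUCTION, verify the outcome, and return
# an EvalOutput record.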
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    config = get_config(metadata, instance.instance_id)

    # Set up the logger properly so the evaluation can be parallelized across processes.
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, str(instance.instance_id), log_dir)
    else:
        logger.info(
            f'\nStarting evaluation for instance {str(instance.instance_id)}.\n'
        )

    # =============================================
    # import test instance
    # =============================================
    instance_id = instance.instance_id
    spec = importlib.util.spec_from_file_location(instance_id, instance.file_path)
    test_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(test_module)

    assert hasattr(
        test_module, 'Test'
    ), f'Test module {instance_id} does not have a Test class'
    test_class: type[BaseIntegrationTest] = test_module.Test
    assert issubclass(
        test_class, BaseIntegrationTest
    ), f'Test class {instance_id} does not inherit from BaseIntegrationTest'

    instruction = test_class.INSTRUCTION

    # =============================================
    # create sandbox and run the agent
    # =============================================
    runtime: Runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)
    try:
        test_class.initialize_runtime(runtime)

        # Run the agent (similar to the `main` function) and get the final task state.
        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=MessageAction(content=instruction),
                runtime=runtime,
                fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
            )
        )
        if state is None:
            raise ValueError('State should not be None.')

        # =============================================
        # result evaluation
        # =============================================
        histories = state.history

        # some basic checks
        logger.info(f'Total events in history: {len(histories)}')
        assert len(histories) > 0, 'History should not be empty'

        test_result: TestResult = test_class.verify_result(runtime, histories)
        metrics = state.metrics.get() if state.metrics else None
    finally:
        runtime.close()

    # Save the output
    output = EvalOutput(
        instance_id=str(instance.instance_id),
        instance=instance.to_dict(),
        instruction=instruction,
        metadata=metadata,
        history=[event_to_dict(event) for event in histories],
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
        test_result=test_result.model_dump(),
    )
    return output
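

# For illustration only: a minimal sketch of what a test module under ./tests is assumed
# to look like, inferred from the attributes accessed above (INSTRUCTION,
# initialize_runtime, verify_result returning a TestResult with `success` and `reason`).
# The real base-class signatures live in evaluation/integration_tests/tests/base.py and
# may differ.
#
# class Test(BaseIntegrationTest):
#     INSTRUCTION = 'Create a file named hello.txt containing "hello world".'
#
#     @classmethod
#     def initialize_runtime(cls, runtime: Runtime) -> None:
#         pass  # e.g. seed the sandbox with fixture files
#
#     @classmethod
#     def verify_result(cls, runtime: Runtime, histories) -> TestResult:
#         return TestResult(success=True, reason='hello.txt found')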


def load_integration_tests() -> pd.DataFrame:
    """Load tests from python files under ./tests"""
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    test_dir = os.path.join(cur_dir, 'tests')
    test_files = [
        os.path.join(test_dir, f)
        for f in os.listdir(test_dir)
        if f.startswith('t') and f.endswith('.py')
    ]
    df = pd.DataFrame(test_files, columns=['file_path'])
    df['instance_id'] = df['file_path'].apply(
        # use removesuffix rather than rstrip: rstrip('.py') strips a trailing character
        # set, not the literal '.py' extension
        lambda x: os.path.basename(x).removesuffix('.py')
    )
    return df
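

# Entry point: parse CLI args, collect the test modules, prepare the dataset, run the
# evaluation across worker processes, then aggregate output.jsonl into a success/cost
# report. Example invocation (flag spellings other than --llm_config are assumptions
# based on the argument names used below):
#
#   python evaluation/integration_tests/run_infer.py \
#       --llm_config eval_llm --agent-cls CodeActAgent \
#       --max-iterations 30 --eval-num-workers 2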
if __name__ == '__main__':
    args = parse_arguments()
    integration_tests = load_integration_tests()

    llm_config = None
    if args.llm_config:
        llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    metadata = make_metadata(
        llm_config,
        'integration_tests',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
    )
    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')

    # Parse dataset IDs if provided
    eval_ids = None
    if args.eval_ids:
        eval_ids = str(args.eval_ids).split(',')
        logger.info(f'\nUsing specific dataset IDs: {eval_ids}\n')

    instances = prepare_dataset(
        integration_tests,
        output_file,
        args.eval_n_limit,
        eval_ids=eval_ids,
    )
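
    # Fan the selected instances out across `eval_num_workers` processes; each call to
    # process_instance contributes one JSON line to output_file.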
    run_evaluation(
        instances,
        metadata,
        output_file,
        args.eval_num_workers,
        process_instance,
    )

    df = pd.read_json(output_file, lines=True, orient='records')

    # record success and reason for failure for the final report
    df['success'] = df['test_result'].apply(lambda x: x['success'])
    df['reason'] = df['test_result'].apply(lambda x: x['reason'])
    logger.info('-' * 100)
    logger.info(
        f'Success rate: {df["success"].mean():.2%} ({df["success"].sum()}/{len(df)})'
    )
    logger.info(
        '\nEvaluation Results:'
        + '\n'
        + df[['instance_id', 'success', 'reason']].to_string(index=False)
    )
    logger.info('-' * 100)

    # record cost for each instance, with 3 decimal places
    df['cost'] = df['metrics'].apply(lambda x: round(x['accumulated_cost'], 3))
    logger.info(f'Total cost: USD {df["cost"].sum():.2f}')

    report_file = os.path.join(metadata.eval_output_dir, 'report.md')
    with open(report_file, 'w') as f:
        f.write(
            f'Success rate: {df["success"].mean():.2%} ({df["success"].sum()}/{len(df)})\n'
        )
        f.write(f'\nTotal cost: USD {df["cost"].sum():.2f}\n')
        f.write(
            df[['instance_id', 'success', 'reason', 'cost']].to_markdown(index=False)
        )