test_stress_remote_runtime.py

  1. """Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox."""

import asyncio
import os
import tempfile
from unittest.mock import MagicMock

import pandas as pd
import pytest
from conftest import TEST_IN_CI

from evaluation.utils.shared import (
    EvalException,
    EvalMetadata,
    EvalOutput,
    assert_and_raise,
    codeact_user_response,
    make_metadata,
    prepare_dataset,
    reset_logger_for_multiprocessing,
    run_evaluation,
)
from openhands.agenthub import Agent
from openhands.controller.state.state import State
from openhands.core.config import (
    AgentConfig,
    AppConfig,
    LLMConfig,
    SandboxConfig,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.events.serialization.event import event_to_dict
from openhands.llm import LLM
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}
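
# Each eval worker builds its own AppConfig, connects to a fresh remote sandbox,
# drives a mocked CodeActAgent through a series of `ls -lah` commands, and tears
# the sandbox down again, so the remote runtime API is exercised under parallel load.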


def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    assert (
        os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL') is not None
    ), 'SANDBOX_REMOTE_RUNTIME_API_URL must be set.'
    assert (
        os.environ.get('ALLHANDS_API_KEY') is not None
    ), 'ALLHANDS_API_KEY must be set.'
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        max_iterations=metadata.max_iterations,
        runtime='remote',
        sandbox=SandboxConfig(
            base_container_image='python:3.11-bookworm',
            enable_auto_lint=True,
            use_host_network=False,
            # large enough timeout, since some testcases take very long to run
            timeout=300,
            api_key=os.environ.get('ALLHANDS_API_KEY', None),
            remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
            keep_runtime_alive=False,
        ),
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    agent_config = AgentConfig(
        codeact_enable_jupyter=False,
        codeact_enable_browsing=False,
        codeact_enable_llm_editor=False,
    )
    config.set_agent_config(agent_config)
    return config


def initialize_runtime(
    runtime: Runtime,
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    """
    logger.info('-' * 30)
    logger.info('BEGIN Runtime Initialization Fn')
    logger.info('-' * 30)
    obs: CmdOutputObservation

    action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
    action.timeout = 600
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')

    action = CmdRunAction(command='mkdir -p /dummy_dir')
    action.timeout = 600
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert_and_raise(
        obs.exit_code == 0,
        f'Failed to create /dummy_dir: {str(obs)}',
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        # Construct the full path for the desired file name within the temporary directory
        temp_file_path = os.path.join(temp_dir, 'dummy_file')
        # Write to the file with the desired name within the temporary directory
        with open(temp_file_path, 'w') as f:
            f.write('dummy content')
        # Copy the file to the desired location
        runtime.copy_to(temp_file_path, '/dummy_dir/')

    logger.info('-' * 30)
    logger.info('END Runtime Initialization Fn')
    logger.info('-' * 30)


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    config = get_config(metadata)

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.instance_id}.')

    runtime = create_runtime(config, headless_mode=False)
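    # runtime.connect is a coroutine; call_async_from_sync runs it to completion
    # from this synchronous worker process before any actions are sent.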
    call_async_from_sync(runtime.connect)

    try:
        initialize_runtime(runtime)

        instruction = 'dummy instruction'
        agent = Agent.get_cls(metadata.agent_class)(
            llm=LLM(config=metadata.llm_config),
            config=config.get_agent_config(metadata.agent_class),
        )
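
        # Replace the agent's step function with a mock that always issues `ls -lah`,
        # so each iteration exercises the sandbox without ever calling the LLM.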
        def next_command(*args, **kwargs):
            return CmdRunAction(command='ls -lah')

        agent.step = MagicMock(side_effect=next_command)

        # Here's how you can run the agent (similar to the `main` function) and get the final task state
        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=MessageAction(content=instruction),
                runtime=runtime,
                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                    metadata.agent_class
                ],
                agent=agent,
            )
        )

        # if fatal error, throw EvalError to trigger re-run
        if (
            state.last_error
            and 'fatal error during agent execution' in state.last_error
            and 'stuck in a loop' not in state.last_error
        ):
            raise EvalException('Fatal error detected: ' + state.last_error)
    finally:
        runtime.close()

    test_result = {}
    if state is None:
        raise ValueError('State should not be None.')

    histories = [event_to_dict(event) for event in state.history]
    metrics = state.metrics.get() if state.metrics else None

    # Save the output
    output = EvalOutput(
        instance_id=instance.instance_id,
        instruction=instruction,
        instance=instance.to_dict(),  # SWE Bench specific
        test_result=test_result,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
    )
    return output


@pytest.mark.skipif(
    TEST_IN_CI,
    reason='This test should only be run locally, not in CI.',
)
def test_stress_remote_runtime(n_eval_workers: int = 64):
    """Mimic evaluation setting to test remote runtime in a multi-processing setting."""
    llm_config = LLMConfig()
    metadata = make_metadata(
        llm_config,
        'dummy_dataset_description',
        'CodeActAgent',
        max_iterations=10,
        eval_note='dummy_eval_note',
        eval_output_dir='./dummy_eval_output_dir',
        details={},
    )

    # generate 300 random dummy instances
    dummy_instance = pd.DataFrame(
        {
            'instance_id': [f'dummy_instance_{i}' for i in range(300)],
        }
    )

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(
        dummy_instance, output_file, eval_n_limit=len(dummy_instance)
    )
    run_evaluation(instances, metadata, output_file, n_eval_workers, process_instance)
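

# Illustrative local invocation (the file path below is an assumption; adapt it to
# where this test lives in your checkout). The test is skipped in CI and requires
# both environment variables checked in get_config():
#   export SANDBOX_REMOTE_RUNTIME_API_URL=<runtime-api-url>
#   export ALLHANDS_API_KEY=<api-key>
#   pytest -vvs test_stress_remote_runtime.py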