runtime.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. import os
  2. import tempfile
  3. import threading
  4. from typing import Callable, Optional
  5. from zipfile import ZipFile
  6. import requests
  7. from requests.exceptions import Timeout
  8. from tenacity import (
  9. retry,
  10. retry_if_exception_type,
  11. stop_after_attempt,
  12. wait_fixed,
  13. )
  14. from openhands.core.config import AppConfig
  15. from openhands.core.logger import openhands_logger as logger
  16. from openhands.events import EventStream
  17. from openhands.events.action import (
  18. BrowseInteractiveAction,
  19. BrowseURLAction,
  20. CmdRunAction,
  21. FileReadAction,
  22. FileWriteAction,
  23. IPythonRunCellAction,
  24. )
  25. from openhands.events.action.action import Action
  26. from openhands.events.observation import (
  27. ErrorObservation,
  28. NullObservation,
  29. Observation,
  30. )
  31. from openhands.events.serialization import event_to_dict, observation_from_dict
  32. from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
  33. from openhands.runtime.builder.remote import RemoteRuntimeBuilder
  34. from openhands.runtime.plugins import PluginRequirement
  35. from openhands.runtime.runtime import Runtime
  36. from openhands.runtime.utils.request import (
  37. DEFAULT_RETRY_EXCEPTIONS,
  38. is_404_error,
  39. send_request_with_retry,
  40. )
  41. from openhands.runtime.utils.runtime_build import build_runtime_image
  42. from openhands.utils.tenacity_stop import stop_if_should_exit
  43. class RemoteRuntime(Runtime):
  44. """This runtime will connect to a remote oh-runtime-client."""
  45. port: int = 60000 # default port for the remote runtime client
  46. def __init__(
  47. self,
  48. config: AppConfig,
  49. event_stream: EventStream,
  50. sid: str = 'default',
  51. plugins: list[PluginRequirement] | None = None,
  52. env_vars: dict[str, str] | None = None,
  53. status_message_callback: Optional[Callable] = None,
  54. ):
  55. self.config = config
  56. self.status_message_callback = status_message_callback
  57. if self.config.sandbox.api_key is None:
  58. raise ValueError(
  59. 'API key is required to use the remote runtime. '
  60. 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
  61. )
  62. self.session = requests.Session()
  63. self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
  64. self.action_semaphore = threading.Semaphore(1)
  65. if self.config.workspace_base is not None:
  66. logger.warning(
  67. 'Setting workspace_base is not supported in the remote runtime.'
  68. )
  69. self.runtime_builder = RemoteRuntimeBuilder(
  70. self.config.sandbox.remote_runtime_api_url, self.config.sandbox.api_key
  71. )
  72. self.runtime_id: str | None = None
  73. self.runtime_url: str | None = None
  74. self.instance_id = sid
  75. self._start_or_attach_to_runtime(plugins)
  76. # Initialize the eventstream and env vars
  77. super().__init__(
  78. config, event_stream, sid, plugins, env_vars, status_message_callback
  79. )
  80. self._wait_until_alive()
  81. self.setup_initial_env()
  82. def _start_or_attach_to_runtime(self, plugins: list[PluginRequirement] | None):
  83. existing_runtime = self._check_existing_runtime()
  84. if existing_runtime:
  85. logger.info(f'Using existing runtime with ID: {self.runtime_id}')
  86. else:
  87. self.send_status_message('STATUS$STARTING_CONTAINER')
  88. if self.config.sandbox.runtime_container_image is None:
  89. logger.info(
  90. f'Building remote runtime with base image: {self.config.sandbox.base_container_image}'
  91. )
  92. self._build_runtime()
  93. else:
  94. logger.info(
  95. f'Running remote runtime with image: {self.config.sandbox.runtime_container_image}'
  96. )
  97. self.container_image = self.config.sandbox.runtime_container_image
  98. self._start_runtime(plugins)
  99. assert (
  100. self.runtime_id is not None
  101. ), 'Runtime ID is not set. This should never happen.'
  102. assert (
  103. self.runtime_url is not None
  104. ), 'Runtime URL is not set. This should never happen.'
  105. self.send_status_message('STATUS$WAITING_FOR_CLIENT')
  106. self._wait_until_alive()
  107. def _check_existing_runtime(self) -> bool:
  108. try:
  109. response = send_request_with_retry(
  110. self.session,
  111. 'GET',
  112. f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.instance_id}',
  113. timeout=5,
  114. )
  115. except Exception as e:
  116. logger.debug(f'Error while looking for remote runtime: {e}')
  117. return False
  118. if response.status_code == 200:
  119. data = response.json()
  120. status = data.get('status')
  121. if status == 'running':
  122. self._parse_runtime_response(response)
  123. return True
  124. elif status == 'stopped':
  125. logger.info('Found existing remote runtime, but it is stopped')
  126. return False
  127. elif status == 'paused':
  128. logger.info('Found existing remote runtime, but it is paused')
  129. self._parse_runtime_response(response)
  130. self._resume_runtime()
  131. return True
  132. else:
  133. logger.error(f'Invalid response from runtime API: {data}')
  134. return False
  135. else:
  136. logger.info('Could not find existing remote runtime')
  137. return False
  138. def _build_runtime(self):
  139. logger.debug(f'RemoteRuntime `{self.instance_id}` config:\n{self.config}')
  140. response = send_request_with_retry(
  141. self.session,
  142. 'GET',
  143. f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
  144. timeout=30,
  145. )
  146. response_json = response.json()
  147. registry_prefix = response_json['registry_prefix']
  148. os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = (
  149. registry_prefix.rstrip('/') + '/runtime'
  150. )
  151. logger.info(
  152. f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}'
  153. )
  154. if self.config.sandbox.runtime_extra_deps:
  155. logger.info(
  156. f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
  157. )
  158. # Build the container image
  159. self.container_image = build_runtime_image(
  160. self.config.sandbox.base_container_image,
  161. self.runtime_builder,
  162. extra_deps=self.config.sandbox.runtime_extra_deps,
  163. force_rebuild=self.config.sandbox.force_rebuild_runtime,
  164. )
  165. response = send_request_with_retry(
  166. self.session,
  167. 'GET',
  168. f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
  169. params={'image': self.container_image},
  170. timeout=30,
  171. )
  172. if response.status_code != 200 or not response.json()['exists']:
  173. raise RuntimeError(f'Container image {self.container_image} does not exist')
  174. def _start_runtime(self, plugins: list[PluginRequirement] | None):
  175. # Prepare the request body for the /start endpoint
  176. plugin_arg = ''
  177. if plugins is not None and len(plugins) > 0:
  178. plugin_arg = f'--plugins {" ".join([plugin.name for plugin in plugins])} '
  179. browsergym_arg = (
  180. f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
  181. if self.config.sandbox.browsergym_eval_env is not None
  182. else ''
  183. )
  184. start_request = {
  185. 'image': self.container_image,
  186. 'command': (
  187. f'/openhands/micromamba/bin/micromamba run -n openhands '
  188. 'poetry run '
  189. f'python -u -m openhands.runtime.client.client {self.port} '
  190. f'--working-dir {self.config.workspace_mount_path_in_sandbox} '
  191. f'{plugin_arg}'
  192. f'--username {"openhands" if self.config.run_as_openhands else "root"} '
  193. f'--user-id {self.config.sandbox.user_id} '
  194. f'{browsergym_arg}'
  195. ),
  196. 'working_dir': '/openhands/code/',
  197. 'environment': {'DEBUG': 'true'} if self.config.debug else {},
  198. 'runtime_id': self.instance_id,
  199. }
  200. # Start the sandbox using the /start endpoint
  201. response = send_request_with_retry(
  202. self.session,
  203. 'POST',
  204. f'{self.config.sandbox.remote_runtime_api_url}/start',
  205. json=start_request,
  206. timeout=300,
  207. )
  208. if response.status_code != 201:
  209. raise RuntimeError(f'Failed to start sandbox: {response.text}')
  210. self._parse_runtime_response(response)
  211. logger.info(
  212. f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
  213. )
  214. def _resume_runtime(self):
  215. response = send_request_with_retry(
  216. self.session,
  217. 'POST',
  218. f'{self.config.sandbox.remote_runtime_api_url}/resume',
  219. json={'runtime_id': self.runtime_id},
  220. timeout=30,
  221. )
  222. if response.status_code != 200:
  223. raise RuntimeError(f'Failed to resume sandbox: {response.text}')
  224. logger.info(f'Sandbox resumed. Runtime ID: {self.runtime_id}')
  225. def _parse_runtime_response(self, response: requests.Response):
  226. start_response = response.json()
  227. self.runtime_id = start_response['runtime_id']
  228. self.runtime_url = start_response['url']
  229. if 'session_api_key' in start_response:
  230. self.session.headers.update(
  231. {'X-Session-API-Key': start_response['session_api_key']}
  232. )
  233. @retry(
  234. stop=stop_after_attempt(60) | stop_if_should_exit(),
  235. wait=wait_fixed(2),
  236. retry=retry_if_exception_type(RuntimeError),
  237. reraise=True,
  238. )
  239. def _wait_until_alive(self):
  240. logger.info(f'Waiting for runtime to be alive at url: {self.runtime_url}')
  241. response = send_request_with_retry(
  242. self.session,
  243. 'GET',
  244. f'{self.runtime_url}/alive',
  245. # Retry 404 errors for the /alive endpoint
  246. # because the runtime might just be starting up
  247. # and have not registered the endpoint yet
  248. retry_fns=[is_404_error],
  249. # leave enough time for the runtime to start up
  250. timeout=600,
  251. )
  252. if response.status_code != 200:
  253. msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
  254. logger.warning(msg)
  255. raise RuntimeError(msg)
  256. def close(self, timeout: int = 10):
  257. if self.config.sandbox.keep_remote_runtime_alive:
  258. self.session.close()
  259. return
  260. if self.runtime_id:
  261. try:
  262. response = send_request_with_retry(
  263. self.session,
  264. 'POST',
  265. f'{self.config.sandbox.remote_runtime_api_url}/stop',
  266. json={'runtime_id': self.runtime_id},
  267. timeout=timeout,
  268. )
  269. if response.status_code != 200:
  270. logger.error(f'Failed to stop sandbox: {response.text}')
  271. else:
  272. logger.info(f'Sandbox stopped. Runtime ID: {self.runtime_id}')
  273. except Exception as e:
  274. raise e
  275. finally:
  276. self.session.close()
  277. def run_action(self, action: Action) -> Observation:
  278. if action.timeout is None:
  279. action.timeout = self.config.sandbox.timeout
  280. with self.action_semaphore:
  281. if not action.runnable:
  282. return NullObservation('')
  283. action_type = action.action # type: ignore[attr-defined]
  284. if action_type not in ACTION_TYPE_TO_CLASS:
  285. return ErrorObservation(f'Action {action_type} does not exist.')
  286. if not hasattr(self, action_type):
  287. return ErrorObservation(
  288. f'Action {action_type} is not supported in the current runtime.'
  289. )
  290. assert action.timeout is not None
  291. try:
  292. logger.info('Executing action')
  293. request_body = {'action': event_to_dict(action)}
  294. logger.debug(f'Request body: {request_body}')
  295. response = send_request_with_retry(
  296. self.session,
  297. 'POST',
  298. f'{self.runtime_url}/execute_action',
  299. json=request_body,
  300. timeout=action.timeout,
  301. retry_exceptions=list(
  302. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  303. ),
  304. # Retry 404 errors for the /execute_action endpoint
  305. # because the runtime might just be starting up
  306. # and have not registered the endpoint yet
  307. retry_fns=[is_404_error],
  308. )
  309. if response.status_code == 200:
  310. output = response.json()
  311. obs = observation_from_dict(output)
  312. obs._cause = action.id # type: ignore[attr-defined]
  313. return obs
  314. else:
  315. error_message = response.text
  316. logger.error(f'Error from server: {error_message}')
  317. obs = ErrorObservation(f'Action execution failed: {error_message}')
  318. except Timeout:
  319. logger.error('No response received within the timeout period.')
  320. obs = ErrorObservation('Action execution timed out')
  321. except Exception as e:
  322. logger.error(f'Error during action execution: {e}')
  323. obs = ErrorObservation(f'Action execution failed: {str(e)}')
  324. return obs
  325. def run(self, action: CmdRunAction) -> Observation:
  326. return self.run_action(action)
  327. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  328. return self.run_action(action)
  329. def read(self, action: FileReadAction) -> Observation:
  330. return self.run_action(action)
  331. def write(self, action: FileWriteAction) -> Observation:
  332. return self.run_action(action)
  333. def browse(self, action: BrowseURLAction) -> Observation:
  334. return self.run_action(action)
  335. def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  336. return self.run_action(action)
  337. def copy_to(
  338. self, host_src: str, sandbox_dest: str, recursive: bool = False
  339. ) -> None:
  340. if not os.path.exists(host_src):
  341. raise FileNotFoundError(f'Source file {host_src} does not exist')
  342. try:
  343. if recursive:
  344. with tempfile.NamedTemporaryFile(
  345. suffix='.zip', delete=False
  346. ) as temp_zip:
  347. temp_zip_path = temp_zip.name
  348. with ZipFile(temp_zip_path, 'w') as zipf:
  349. for root, _, files in os.walk(host_src):
  350. for file in files:
  351. file_path = os.path.join(root, file)
  352. arcname = os.path.relpath(
  353. file_path, os.path.dirname(host_src)
  354. )
  355. zipf.write(file_path, arcname)
  356. upload_data = {'file': open(temp_zip_path, 'rb')}
  357. else:
  358. upload_data = {'file': open(host_src, 'rb')}
  359. params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
  360. response = send_request_with_retry(
  361. self.session,
  362. 'POST',
  363. f'{self.runtime_url}/upload_file',
  364. files=upload_data,
  365. params=params,
  366. retry_exceptions=list(
  367. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  368. ),
  369. timeout=300,
  370. )
  371. if response.status_code == 200:
  372. logger.info(
  373. f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}'
  374. )
  375. return
  376. else:
  377. error_message = response.text
  378. raise Exception(f'Copy operation failed: {error_message}')
  379. except TimeoutError:
  380. raise TimeoutError('Copy operation timed out')
  381. except Exception as e:
  382. raise RuntimeError(f'Copy operation failed: {str(e)}')
  383. finally:
  384. if recursive:
  385. os.unlink(temp_zip_path)
  386. logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
  387. def list_files(self, path: str | None = None) -> list[str]:
  388. try:
  389. data = {}
  390. if path is not None:
  391. data['path'] = path
  392. response = send_request_with_retry(
  393. self.session,
  394. 'POST',
  395. f'{self.runtime_url}/list_files',
  396. json=data,
  397. retry_exceptions=list(
  398. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  399. ),
  400. timeout=30,
  401. )
  402. if response.status_code == 200:
  403. response_json = response.json()
  404. assert isinstance(response_json, list)
  405. return response_json
  406. else:
  407. error_message = response.text
  408. raise Exception(f'List files operation failed: {error_message}')
  409. except TimeoutError:
  410. raise TimeoutError('List files operation timed out')
  411. except Exception as e:
  412. raise RuntimeError(f'List files operation failed: {str(e)}')
  413. def copy_from(self, path: str) -> bytes:
  414. """Zip all files in the sandbox and return as a stream of bytes."""
  415. self._wait_until_alive()
  416. try:
  417. params = {'path': path}
  418. response = send_request_with_retry(
  419. self.session,
  420. 'GET',
  421. f'{self.runtime_url}/download_files',
  422. params=params,
  423. timeout=30,
  424. retry_exceptions=list(
  425. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  426. ),
  427. )
  428. if response.status_code == 200:
  429. return response.content
  430. else:
  431. error_message = response.text
  432. raise Exception(f'Copy operation failed: {error_message}')
  433. except requests.Timeout:
  434. raise TimeoutError('Copy operation timed out')
  435. except Exception as e:
  436. raise RuntimeError(f'Copy operation failed: {str(e)}')
  437. def send_status_message(self, message: str):
  438. """Sends a status message if the callback function was provided."""
  439. if self.status_message_callback:
  440. self.status_message_callback(message)