runtime.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. import os
  2. import tempfile
  3. import threading
  4. import uuid
  5. from zipfile import ZipFile
  6. import requests
  7. from requests.exceptions import Timeout
  8. from tenacity import (
  9. retry,
  10. retry_if_exception_type,
  11. stop_after_attempt,
  12. wait_exponential,
  13. )
  14. from openhands.core.config import AppConfig
  15. from openhands.core.logger import openhands_logger as logger
  16. from openhands.events import EventStream
  17. from openhands.events.action import (
  18. BrowseInteractiveAction,
  19. BrowseURLAction,
  20. CmdRunAction,
  21. FileReadAction,
  22. FileWriteAction,
  23. IPythonRunCellAction,
  24. )
  25. from openhands.events.action.action import Action
  26. from openhands.events.observation import (
  27. ErrorObservation,
  28. NullObservation,
  29. Observation,
  30. )
  31. from openhands.events.serialization import event_to_dict, observation_from_dict
  32. from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
  33. from openhands.runtime.builder.remote import RemoteRuntimeBuilder
  34. from openhands.runtime.plugins import PluginRequirement
  35. from openhands.runtime.runtime import Runtime
  36. from openhands.runtime.utils.request import (
  37. DEFAULT_RETRY_EXCEPTIONS,
  38. is_404_error,
  39. send_request,
  40. )
  41. from openhands.runtime.utils.runtime_build import build_runtime_image
  42. class RemoteRuntime(Runtime):
  43. """This runtime will connect to a remote oh-runtime-client."""
  44. port: int = 60000 # default port for the remote runtime client
  45. def __init__(
  46. self,
  47. config: AppConfig,
  48. event_stream: EventStream,
  49. sid: str = 'default',
  50. plugins: list[PluginRequirement] | None = None,
  51. env_vars: dict[str, str] | None = None,
  52. ):
  53. self.config = config
  54. if self.config.sandbox.api_hostname == 'localhost':
  55. self.config.sandbox.api_hostname = 'api.all-hands.dev/v0/runtime'
  56. logger.warning(
  57. 'Using localhost as the API hostname is not supported in the RemoteRuntime. Please set a proper hostname.\n'
  58. 'Setting it to default value: api.all-hands.dev/v0/runtime'
  59. )
  60. self.api_url = f'https://{self.config.sandbox.api_hostname.rstrip("/")}'
  61. if self.config.sandbox.api_key is None:
  62. raise ValueError(
  63. 'API key is required to use the remote runtime. '
  64. 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
  65. )
  66. self.session = requests.Session()
  67. self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
  68. self.action_semaphore = threading.Semaphore(1)
  69. if self.config.workspace_base is not None:
  70. logger.warning(
  71. 'Setting workspace_base is not supported in the remote runtime.'
  72. )
  73. self.runtime_builder = RemoteRuntimeBuilder(
  74. self.api_url, self.config.sandbox.api_key
  75. )
  76. self.runtime_id: str | None = None
  77. self.runtime_url: str | None = None
  78. self.instance_id = (
  79. sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
  80. )
  81. if self.config.sandbox.runtime_container_image is not None:
  82. raise ValueError(
  83. 'Setting runtime_container_image is not supported in the remote runtime.'
  84. )
  85. self.container_image: str = self.config.sandbox.base_container_image
  86. self.container_name = 'oh-remote-runtime-' + self.instance_id
  87. logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
  88. response = send_request(self.session, 'GET', f'{self.api_url}/registry_prefix')
  89. response_json = response.json()
  90. registry_prefix = response_json['registry_prefix']
  91. os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = (
  92. registry_prefix.rstrip('/') + '/runtime'
  93. )
  94. logger.info(
  95. f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}'
  96. )
  97. if self.config.sandbox.runtime_extra_deps:
  98. logger.info(
  99. f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
  100. )
  101. # Build the container image
  102. self.container_image = build_runtime_image(
  103. self.container_image,
  104. self.runtime_builder,
  105. extra_deps=self.config.sandbox.runtime_extra_deps,
  106. )
  107. # Use the /image_exists endpoint to check if the image exists
  108. response = send_request(
  109. self.session,
  110. 'GET',
  111. f'{self.api_url}/image_exists',
  112. params={'image': self.container_image},
  113. )
  114. if response.status_code != 200 or not response.json()['exists']:
  115. raise RuntimeError(f'Container image {self.container_image} does not exist')
  116. # Prepare the request body for the /start endpoint
  117. plugin_arg = ''
  118. if plugins is not None and len(plugins) > 0:
  119. plugin_arg = f'--plugins {" ".join([plugin.name for plugin in plugins])} '
  120. browsergym_arg = (
  121. f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
  122. if self.config.sandbox.browsergym_eval_env is not None
  123. else ''
  124. )
  125. start_request = {
  126. 'image': self.container_image,
  127. 'command': (
  128. f'/openhands/miniforge3/bin/mamba run --no-capture-output -n base '
  129. 'PYTHONUNBUFFERED=1 poetry run '
  130. f'python -u -m openhands.runtime.client.client {self.port} '
  131. f'--working-dir {self.config.workspace_mount_path_in_sandbox} '
  132. f'{plugin_arg}'
  133. f'--username {"openhands" if self.config.run_as_openhands else "root"} '
  134. f'--user-id {self.config.sandbox.user_id} '
  135. f'{browsergym_arg}'
  136. ),
  137. 'working_dir': '/openhands/code/',
  138. 'name': self.container_name,
  139. 'environment': {'DEBUG': 'true'} if self.config.debug else {},
  140. }
  141. # Start the sandbox using the /start endpoint
  142. response = send_request(
  143. self.session, 'POST', f'{self.api_url}/start', json=start_request
  144. )
  145. if response.status_code != 201:
  146. raise RuntimeError(f'Failed to start sandbox: {response.text}')
  147. start_response = response.json()
  148. self.runtime_id = start_response['runtime_id']
  149. self.runtime_url = start_response['url']
  150. logger.info(
  151. f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
  152. )
  153. # Initialize the eventstream and env vars
  154. super().__init__(config, event_stream, sid, plugins, env_vars)
  155. logger.info(
  156. f'Runtime initialized with plugins: {[plugin.name for plugin in self.plugins]}'
  157. )
  158. logger.info(f'Runtime initialized with env vars: {env_vars}')
  159. assert (
  160. self.runtime_id is not None
  161. ), 'Runtime ID is not set. This should never happen.'
  162. assert (
  163. self.runtime_url is not None
  164. ), 'Runtime URL is not set. This should never happen.'
  165. @retry(
  166. stop=stop_after_attempt(10),
  167. wait=wait_exponential(multiplier=1, min=4, max=60),
  168. retry=retry_if_exception_type(RuntimeError),
  169. reraise=True,
  170. )
  171. def _wait_until_alive(self):
  172. logger.info('Waiting for sandbox to be alive...')
  173. response = send_request(
  174. self.session,
  175. 'GET',
  176. f'{self.runtime_url}/alive',
  177. # Retry 404 errors for the /alive endpoint
  178. # because the runtime might just be starting up
  179. # and have not registered the endpoint yet
  180. retry_fns=[is_404_error],
  181. # leave enough time for the runtime to start up
  182. timeout=600,
  183. )
  184. if response.status_code != 200:
  185. msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
  186. logger.warning(msg)
  187. raise RuntimeError(msg)
  188. def close(self):
  189. if self.runtime_id:
  190. try:
  191. response = send_request(
  192. self.session,
  193. 'POST',
  194. f'{self.api_url}/stop',
  195. json={'runtime_id': self.runtime_id},
  196. )
  197. if response.status_code != 200:
  198. logger.error(f'Failed to stop sandbox: {response.text}')
  199. else:
  200. logger.info(f'Sandbox stopped. Runtime ID: {self.runtime_id}')
  201. except Exception as e:
  202. raise e
  203. finally:
  204. self.session.close()
  205. def run_action(self, action: Action) -> Observation:
  206. if action.timeout is None:
  207. action.timeout = self.config.sandbox.timeout
  208. with self.action_semaphore:
  209. if not action.runnable:
  210. return NullObservation('')
  211. action_type = action.action # type: ignore[attr-defined]
  212. if action_type not in ACTION_TYPE_TO_CLASS:
  213. return ErrorObservation(f'Action {action_type} does not exist.')
  214. if not hasattr(self, action_type):
  215. return ErrorObservation(
  216. f'Action {action_type} is not supported in the current runtime.'
  217. )
  218. self._wait_until_alive()
  219. assert action.timeout is not None
  220. try:
  221. logger.info('Executing action')
  222. request_body = {'action': event_to_dict(action)}
  223. logger.debug(f'Request body: {request_body}')
  224. response = send_request(
  225. self.session,
  226. 'POST',
  227. f'{self.runtime_url}/execute_action',
  228. json=request_body,
  229. timeout=action.timeout,
  230. retry_exceptions=list(
  231. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  232. ),
  233. # Retry 404 errors for the /execute_action endpoint
  234. # because the runtime might just be starting up
  235. # and have not registered the endpoint yet
  236. retry_fns=[is_404_error],
  237. )
  238. if response.status_code == 200:
  239. output = response.json()
  240. obs = observation_from_dict(output)
  241. obs._cause = action.id # type: ignore[attr-defined]
  242. return obs
  243. else:
  244. error_message = response.text
  245. logger.error(f'Error from server: {error_message}')
  246. obs = ErrorObservation(f'Action execution failed: {error_message}')
  247. except Timeout:
  248. logger.error('No response received within the timeout period.')
  249. obs = ErrorObservation('Action execution timed out')
  250. except Exception as e:
  251. logger.error(f'Error during action execution: {e}')
  252. obs = ErrorObservation(f'Action execution failed: {str(e)}')
  253. return obs
  254. def run(self, action: CmdRunAction) -> Observation:
  255. return self.run_action(action)
  256. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  257. return self.run_action(action)
  258. def read(self, action: FileReadAction) -> Observation:
  259. return self.run_action(action)
  260. def write(self, action: FileWriteAction) -> Observation:
  261. return self.run_action(action)
  262. def browse(self, action: BrowseURLAction) -> Observation:
  263. return self.run_action(action)
  264. def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  265. return self.run_action(action)
  266. def copy_to(
  267. self, host_src: str, sandbox_dest: str, recursive: bool = False
  268. ) -> None:
  269. if not os.path.exists(host_src):
  270. raise FileNotFoundError(f'Source file {host_src} does not exist')
  271. self._wait_until_alive()
  272. try:
  273. if recursive:
  274. with tempfile.NamedTemporaryFile(
  275. suffix='.zip', delete=False
  276. ) as temp_zip:
  277. temp_zip_path = temp_zip.name
  278. with ZipFile(temp_zip_path, 'w') as zipf:
  279. for root, _, files in os.walk(host_src):
  280. for file in files:
  281. file_path = os.path.join(root, file)
  282. arcname = os.path.relpath(
  283. file_path, os.path.dirname(host_src)
  284. )
  285. zipf.write(file_path, arcname)
  286. upload_data = {'file': open(temp_zip_path, 'rb')}
  287. else:
  288. upload_data = {'file': open(host_src, 'rb')}
  289. params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
  290. response = send_request(
  291. self.session,
  292. 'POST',
  293. f'{self.runtime_url}/upload_file',
  294. files=upload_data,
  295. params=params,
  296. retry_exceptions=list(
  297. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  298. ),
  299. )
  300. if response.status_code == 200:
  301. logger.info(
  302. f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}'
  303. )
  304. return
  305. else:
  306. error_message = response.text
  307. raise Exception(f'Copy operation failed: {error_message}')
  308. except TimeoutError:
  309. raise TimeoutError('Copy operation timed out')
  310. except Exception as e:
  311. raise RuntimeError(f'Copy operation failed: {str(e)}')
  312. finally:
  313. if recursive:
  314. os.unlink(temp_zip_path)
  315. logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
  316. def list_files(self, path: str | None = None) -> list[str]:
  317. self._wait_until_alive()
  318. try:
  319. data = {}
  320. if path is not None:
  321. data['path'] = path
  322. response = send_request(
  323. self.session,
  324. 'POST',
  325. f'{self.runtime_url}/list_files',
  326. json=data,
  327. retry_exceptions=list(
  328. filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
  329. ),
  330. )
  331. if response.status_code == 200:
  332. response_json = response.json()
  333. assert isinstance(response_json, list)
  334. return response_json
  335. else:
  336. error_message = response.text
  337. raise Exception(f'List files operation failed: {error_message}')
  338. except TimeoutError:
  339. raise TimeoutError('List files operation timed out')
  340. except Exception as e:
  341. raise RuntimeError(f'List files operation failed: {str(e)}')