base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import atexit
  2. import copy
  3. import json
  4. import os
  5. from abc import abstractmethod
  6. from pathlib import Path
  7. from typing import Callable
  8. from requests.exceptions import ConnectionError
  9. from openhands.core.config import AppConfig, SandboxConfig
  10. from openhands.core.logger import openhands_logger as logger
  11. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  12. from openhands.events.action import (
  13. Action,
  14. ActionConfirmationStatus,
  15. BrowseInteractiveAction,
  16. BrowseURLAction,
  17. CmdRunAction,
  18. FileReadAction,
  19. FileWriteAction,
  20. IPythonRunCellAction,
  21. )
  22. from openhands.events.event import Event
  23. from openhands.events.observation import (
  24. CmdOutputObservation,
  25. ErrorObservation,
  26. NullObservation,
  27. Observation,
  28. UserRejectObservation,
  29. )
  30. from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
  31. from openhands.runtime.plugins import (
  32. JupyterRequirement,
  33. PluginRequirement,
  34. VSCodeRequirement,
  35. )
  36. from openhands.runtime.utils.edit import FileEditRuntimeMixin
  37. from openhands.utils.async_utils import call_sync_from_async
  38. STATUS_MESSAGES = {
  39. 'STATUS$STARTING_RUNTIME': 'Starting runtime...',
  40. 'STATUS$STARTING_CONTAINER': 'Starting container...',
  41. 'STATUS$PREPARING_CONTAINER': 'Preparing container...',
  42. 'STATUS$CONTAINER_STARTED': 'Container started.',
  43. 'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
  44. }
  45. class RuntimeUnavailableError(Exception):
  46. pass
  47. class RuntimeNotReadyError(RuntimeUnavailableError):
  48. pass
  49. class RuntimeDisconnectedError(RuntimeUnavailableError):
  50. pass
  51. class RuntimeNotFoundError(RuntimeUnavailableError):
  52. pass
  53. def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
  54. ret = {}
  55. for key in os.environ:
  56. if key.startswith('SANDBOX_ENV_'):
  57. sandbox_key = key.removeprefix('SANDBOX_ENV_')
  58. ret[sandbox_key] = os.environ[key]
  59. if sandbox_config.enable_auto_lint:
  60. ret['ENABLE_AUTO_LINT'] = 'true'
  61. return ret
  62. class Runtime(FileEditRuntimeMixin):
  63. """The runtime is how the agent interacts with the external environment.
  64. This includes a bash sandbox, a browser, and filesystem interactions.
  65. sid is the session id, which is used to identify the current user session.
  66. """
  67. sid: str
  68. config: AppConfig
  69. initial_env_vars: dict[str, str]
  70. attach_to_existing: bool
  71. status_callback: Callable | None
  72. def __init__(
  73. self,
  74. config: AppConfig,
  75. event_stream: EventStream,
  76. sid: str = 'default',
  77. plugins: list[PluginRequirement] | None = None,
  78. env_vars: dict[str, str] | None = None,
  79. status_callback: Callable | None = None,
  80. attach_to_existing: bool = False,
  81. headless_mode: bool = False,
  82. ):
  83. self.sid = sid
  84. self.event_stream = event_stream
  85. self.event_stream.subscribe(
  86. EventStreamSubscriber.RUNTIME, self.on_event, self.sid
  87. )
  88. self.plugins = (
  89. copy.deepcopy(plugins) if plugins is not None and len(plugins) > 0 else []
  90. )
  91. # add VSCode plugin if not in headless mode
  92. if not headless_mode:
  93. self.plugins.append(VSCodeRequirement())
  94. self.status_callback = status_callback
  95. self.attach_to_existing = attach_to_existing
  96. self.config = copy.deepcopy(config)
  97. atexit.register(self.close)
  98. self.initial_env_vars = _default_env_vars(config.sandbox)
  99. if env_vars is not None:
  100. self.initial_env_vars.update(env_vars)
  101. self._vscode_enabled = any(
  102. isinstance(plugin, VSCodeRequirement) for plugin in self.plugins
  103. )
  104. # Load mixins
  105. FileEditRuntimeMixin.__init__(self)
  106. def setup_initial_env(self) -> None:
  107. if self.attach_to_existing:
  108. return
  109. logger.debug(f'Adding env vars: {self.initial_env_vars}')
  110. self.add_env_vars(self.initial_env_vars)
  111. if self.config.sandbox.runtime_startup_env_vars:
  112. self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
  113. def close(self) -> None:
  114. pass
  115. def log(self, level: str, message: str) -> None:
  116. message = f'[runtime {self.sid}] {message}'
  117. getattr(logger, level)(message, stacklevel=2)
  118. def send_status_message(self, message_id: str):
  119. """Sends a status message if the callback function was provided."""
  120. if self.status_callback:
  121. msg = STATUS_MESSAGES.get(message_id, '')
  122. self.status_callback('info', message_id, msg)
  123. def send_error_message(self, message_id: str, message: str):
  124. if self.status_callback:
  125. self.status_callback('error', message_id, message)
  126. # ====================================================================
  127. def add_env_vars(self, env_vars: dict[str, str]) -> None:
  128. # Add env vars to the IPython shell (if Jupyter is used)
  129. if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
  130. code = 'import os\n'
  131. for key, value in env_vars.items():
  132. # Note: json.dumps gives us nice escaping for free
  133. code += f'os.environ["{key}"] = {json.dumps(value)}\n'
  134. code += '\n'
  135. obs = self.run_ipython(IPythonRunCellAction(code))
  136. self.log('debug', f'Added env vars to IPython: code={code}, obs={obs}')
  137. # Add env vars to the Bash shell
  138. cmd = ''
  139. for key, value in env_vars.items():
  140. # Note: json.dumps gives us nice escaping for free
  141. cmd += f'export {key}={json.dumps(value)}; '
  142. if not cmd:
  143. return
  144. cmd = cmd.strip()
  145. logger.debug(f'Adding env var: {cmd}')
  146. obs = self.run(CmdRunAction(cmd))
  147. if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
  148. raise RuntimeError(
  149. f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
  150. )
  151. async def on_event(self, event: Event) -> None:
  152. if isinstance(event, Action):
  153. # set timeout to default if not set
  154. if event.timeout is None:
  155. event.timeout = self.config.sandbox.timeout
  156. assert event.timeout is not None
  157. try:
  158. observation: Observation = await call_sync_from_async(
  159. self.run_action, event
  160. )
  161. except Exception as e:
  162. err_id = ''
  163. if isinstance(e, ConnectionError) or isinstance(
  164. e, RuntimeDisconnectedError
  165. ):
  166. err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
  167. logger.error(
  168. 'Unexpected error while running action',
  169. exc_info=True,
  170. stack_info=True,
  171. )
  172. self.log('error', f'Problematic action: {str(event)}')
  173. self.send_error_message(err_id, str(e))
  174. self.close()
  175. return
  176. observation._cause = event.id # type: ignore[attr-defined]
  177. observation.tool_call_metadata = event.tool_call_metadata
  178. # this might be unnecessary, since source should be set by the event stream when we're here
  179. source = event.source if event.source else EventSource.AGENT
  180. self.event_stream.add_event(observation, source) # type: ignore[arg-type]
  181. def run_action(self, action: Action) -> Observation:
  182. """Run an action and return the resulting observation.
  183. If the action is not runnable in any runtime, a NullObservation is returned.
  184. If the action is not supported by the current runtime, an ErrorObservation is returned.
  185. """
  186. if not action.runnable:
  187. return NullObservation('')
  188. if (
  189. hasattr(action, 'confirmation_state')
  190. and action.confirmation_state
  191. == ActionConfirmationStatus.AWAITING_CONFIRMATION
  192. ):
  193. return NullObservation('')
  194. action_type = action.action # type: ignore[attr-defined]
  195. if action_type not in ACTION_TYPE_TO_CLASS:
  196. return ErrorObservation(f'Action {action_type} does not exist.')
  197. if not hasattr(self, action_type):
  198. return ErrorObservation(
  199. f'Action {action_type} is not supported in the current runtime.'
  200. )
  201. if (
  202. getattr(action, 'confirmation_state', None)
  203. == ActionConfirmationStatus.REJECTED
  204. ):
  205. return UserRejectObservation(
  206. 'Action has been rejected by the user! Waiting for further user input.'
  207. )
  208. observation = getattr(self, action_type)(action)
  209. return observation
  210. # ====================================================================
  211. # Context manager
  212. # ====================================================================
  213. def __enter__(self) -> 'Runtime':
  214. return self
  215. def __exit__(self, exc_type, exc_value, traceback) -> None:
  216. self.close()
  217. @abstractmethod
  218. async def connect(self) -> None:
  219. pass
  220. # ====================================================================
  221. # Action execution
  222. # ====================================================================
  223. @abstractmethod
  224. def run(self, action: CmdRunAction) -> Observation:
  225. pass
  226. @abstractmethod
  227. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  228. pass
  229. @abstractmethod
  230. def read(self, action: FileReadAction) -> Observation:
  231. pass
  232. @abstractmethod
  233. def write(self, action: FileWriteAction) -> Observation:
  234. pass
  235. @abstractmethod
  236. def browse(self, action: BrowseURLAction) -> Observation:
  237. pass
  238. @abstractmethod
  239. def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  240. pass
  241. # ====================================================================
  242. # File operations
  243. # ====================================================================
  244. @abstractmethod
  245. def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
  246. raise NotImplementedError('This method is not implemented in the base class.')
  247. @abstractmethod
  248. def list_files(self, path: str | None = None) -> list[str]:
  249. """List files in the sandbox.
  250. If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
  251. """
  252. raise NotImplementedError('This method is not implemented in the base class.')
  253. @abstractmethod
  254. def copy_from(self, path: str) -> Path:
  255. """Zip all files in the sandbox and return a path in the local filesystem."""
  256. raise NotImplementedError('This method is not implemented in the base class.')
  257. # ====================================================================
  258. # VSCode
  259. # ====================================================================
  260. @property
  261. def vscode_enabled(self) -> bool:
  262. return self._vscode_enabled
  263. @property
  264. def vscode_url(self) -> str | None:
  265. raise NotImplementedError('This method is not implemented in the base class.')