base.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. import atexit
  2. import copy
  3. import json
  4. import os
  5. from abc import abstractmethod
  6. from typing import Callable
  7. from requests.exceptions import ConnectionError
  8. from openhands.core.config import AppConfig, SandboxConfig
  9. from openhands.core.logger import openhands_logger as logger
  10. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  11. from openhands.events.action import (
  12. Action,
  13. ActionConfirmationStatus,
  14. BrowseInteractiveAction,
  15. BrowseURLAction,
  16. CmdRunAction,
  17. FileReadAction,
  18. FileWriteAction,
  19. IPythonRunCellAction,
  20. )
  21. from openhands.events.event import Event
  22. from openhands.events.observation import (
  23. CmdOutputObservation,
  24. ErrorObservation,
  25. NullObservation,
  26. Observation,
  27. UserRejectObservation,
  28. )
  29. from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
  30. from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
  31. from openhands.runtime.utils.edit import FileEditRuntimeMixin
  32. from openhands.utils.async_utils import call_sync_from_async
  33. STATUS_MESSAGES = {
  34. 'STATUS$STARTING_RUNTIME': 'Starting runtime...',
  35. 'STATUS$STARTING_CONTAINER': 'Starting container...',
  36. 'STATUS$PREPARING_CONTAINER': 'Preparing container...',
  37. 'STATUS$CONTAINER_STARTED': 'Container started.',
  38. 'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
  39. }
  40. class RuntimeNotReadyError(Exception):
  41. pass
  42. class RuntimeDisconnectedError(Exception):
  43. pass
  44. def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
  45. ret = {}
  46. for key in os.environ:
  47. if key.startswith('SANDBOX_ENV_'):
  48. sandbox_key = key.removeprefix('SANDBOX_ENV_')
  49. ret[sandbox_key] = os.environ[key]
  50. if sandbox_config.enable_auto_lint:
  51. ret['ENABLE_AUTO_LINT'] = 'true'
  52. return ret
  53. class Runtime(FileEditRuntimeMixin):
  54. """The runtime is how the agent interacts with the external environment.
  55. This includes a bash sandbox, a browser, and filesystem interactions.
  56. sid is the session id, which is used to identify the current user session.
  57. """
  58. sid: str
  59. config: AppConfig
  60. initial_env_vars: dict[str, str]
  61. attach_to_existing: bool
  62. status_callback: Callable | None
  63. def __init__(
  64. self,
  65. config: AppConfig,
  66. event_stream: EventStream,
  67. sid: str = 'default',
  68. plugins: list[PluginRequirement] | None = None,
  69. env_vars: dict[str, str] | None = None,
  70. status_callback: Callable | None = None,
  71. attach_to_existing: bool = False,
  72. ):
  73. self.sid = sid
  74. self.event_stream = event_stream
  75. self.event_stream.subscribe(
  76. EventStreamSubscriber.RUNTIME, self.on_event, self.sid
  77. )
  78. self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
  79. self.status_callback = status_callback
  80. self.attach_to_existing = attach_to_existing
  81. self.config = copy.deepcopy(config)
  82. atexit.register(self.close)
  83. self.initial_env_vars = _default_env_vars(config.sandbox)
  84. if env_vars is not None:
  85. self.initial_env_vars.update(env_vars)
  86. # Load mixins
  87. FileEditRuntimeMixin.__init__(self)
  88. def setup_initial_env(self) -> None:
  89. if self.attach_to_existing:
  90. return
  91. logger.debug(f'Adding env vars: {self.initial_env_vars}')
  92. self.add_env_vars(self.initial_env_vars)
  93. if self.config.sandbox.runtime_startup_env_vars:
  94. self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
  95. def close(self) -> None:
  96. pass
  97. def log(self, level: str, message: str) -> None:
  98. message = f'[runtime {self.sid}] {message}'
  99. getattr(logger, level)(message, stacklevel=2)
  100. def send_status_message(self, message_id: str):
  101. """Sends a status message if the callback function was provided."""
  102. if self.status_callback:
  103. msg = STATUS_MESSAGES.get(message_id, '')
  104. self.status_callback('info', message_id, msg)
  105. def send_error_message(self, message_id: str, message: str):
  106. if self.status_callback:
  107. self.status_callback('error', message_id, message)
  108. # ====================================================================
  109. def add_env_vars(self, env_vars: dict[str, str]) -> None:
  110. # Add env vars to the IPython shell (if Jupyter is used)
  111. if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
  112. code = 'import os\n'
  113. for key, value in env_vars.items():
  114. # Note: json.dumps gives us nice escaping for free
  115. code += f'os.environ["{key}"] = {json.dumps(value)}\n'
  116. code += '\n'
  117. obs = self.run_ipython(IPythonRunCellAction(code))
  118. self.log('debug', f'Added env vars to IPython: code={code}, obs={obs}')
  119. # Add env vars to the Bash shell
  120. cmd = ''
  121. for key, value in env_vars.items():
  122. # Note: json.dumps gives us nice escaping for free
  123. cmd += f'export {key}={json.dumps(value)}; '
  124. if not cmd:
  125. return
  126. cmd = cmd.strip()
  127. logger.debug(f'Adding env var: {cmd}')
  128. obs = self.run(CmdRunAction(cmd))
  129. if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
  130. raise RuntimeError(
  131. f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
  132. )
  133. async def on_event(self, event: Event) -> None:
  134. if isinstance(event, Action):
  135. # set timeout to default if not set
  136. if event.timeout is None:
  137. event.timeout = self.config.sandbox.timeout
  138. assert event.timeout is not None
  139. try:
  140. observation: Observation = await call_sync_from_async(
  141. self.run_action, event
  142. )
  143. except Exception as e:
  144. err_id = ''
  145. if isinstance(e, ConnectionError) or isinstance(
  146. e, RuntimeDisconnectedError
  147. ):
  148. err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
  149. self.log('error', f'Unexpected error while running action {e}')
  150. self.log('error', f'Problematic action: {str(event)}')
  151. self.send_error_message(err_id, str(e))
  152. self.close()
  153. return
  154. observation._cause = event.id # type: ignore[attr-defined]
  155. observation.tool_call_metadata = event.tool_call_metadata
  156. # this might be unnecessary, since source should be set by the event stream when we're here
  157. source = event.source if event.source else EventSource.AGENT
  158. self.event_stream.add_event(observation, source) # type: ignore[arg-type]
  159. def run_action(self, action: Action) -> Observation:
  160. """Run an action and return the resulting observation.
  161. If the action is not runnable in any runtime, a NullObservation is returned.
  162. If the action is not supported by the current runtime, an ErrorObservation is returned.
  163. """
  164. if not action.runnable:
  165. return NullObservation('')
  166. if (
  167. hasattr(action, 'confirmation_state')
  168. and action.confirmation_state
  169. == ActionConfirmationStatus.AWAITING_CONFIRMATION
  170. ):
  171. return NullObservation('')
  172. action_type = action.action # type: ignore[attr-defined]
  173. if action_type not in ACTION_TYPE_TO_CLASS:
  174. return ErrorObservation(f'Action {action_type} does not exist.')
  175. if not hasattr(self, action_type):
  176. return ErrorObservation(
  177. f'Action {action_type} is not supported in the current runtime.'
  178. )
  179. if (
  180. getattr(action, 'confirmation_state', None)
  181. == ActionConfirmationStatus.REJECTED
  182. ):
  183. return UserRejectObservation(
  184. 'Action has been rejected by the user! Waiting for further user input.'
  185. )
  186. observation = getattr(self, action_type)(action)
  187. return observation
  188. # ====================================================================
  189. # Context manager
  190. # ====================================================================
  191. def __enter__(self) -> 'Runtime':
  192. return self
  193. def __exit__(self, exc_type, exc_value, traceback) -> None:
  194. self.close()
  195. @abstractmethod
  196. async def connect(self) -> None:
  197. pass
  198. # ====================================================================
  199. # Action execution
  200. # ====================================================================
  201. @abstractmethod
  202. def run(self, action: CmdRunAction) -> Observation:
  203. pass
  204. @abstractmethod
  205. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  206. pass
  207. @abstractmethod
  208. def read(self, action: FileReadAction) -> Observation:
  209. pass
  210. @abstractmethod
  211. def write(self, action: FileWriteAction) -> Observation:
  212. pass
  213. @abstractmethod
  214. def browse(self, action: BrowseURLAction) -> Observation:
  215. pass
  216. @abstractmethod
  217. def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  218. pass
  219. # ====================================================================
  220. # File operations
  221. # ====================================================================
  222. @abstractmethod
  223. def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
  224. raise NotImplementedError('This method is not implemented in the base class.')
  225. @abstractmethod
  226. def list_files(self, path: str | None = None) -> list[str]:
  227. """List files in the sandbox.
  228. If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
  229. """
  230. raise NotImplementedError('This method is not implemented in the base class.')
  231. @abstractmethod
  232. def copy_from(self, path: str) -> bytes:
  233. """Zip all files in the sandbox and return as a stream of bytes."""
  234. raise NotImplementedError('This method is not implemented in the base class.')