base.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
import atexit
import copy
import json
import os
import shlex
from abc import abstractmethod
from typing import Callable

from openhands.core.config import AppConfig, SandboxConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
    Action,
    ActionConfirmationStatus,
    BrowseInteractiveAction,
    BrowseURLAction,
    CmdRunAction,
    FileReadAction,
    FileWriteAction,
    IPythonRunCellAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
    CmdOutputObservation,
    ErrorObservation,
    NullObservation,
    Observation,
    UserRejectObservation,
)
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
from openhands.runtime.utils.edit import FileEditRuntimeMixin
from openhands.utils.async_utils import call_sync_from_async
  32. def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
  33. ret = {}
  34. for key in os.environ:
  35. if key.startswith('SANDBOX_ENV_'):
  36. sandbox_key = key.removeprefix('SANDBOX_ENV_')
  37. ret[sandbox_key] = os.environ[key]
  38. if sandbox_config.enable_auto_lint:
  39. ret['ENABLE_AUTO_LINT'] = 'true'
  40. return ret
  41. class Runtime(FileEditRuntimeMixin):
  42. """The runtime is how the agent interacts with the external environment.
  43. This includes a bash sandbox, a browser, and filesystem interactions.
  44. sid is the session id, which is used to identify the current user session.
  45. """
  46. sid: str
  47. config: AppConfig
  48. initial_env_vars: dict[str, str]
  49. attach_to_existing: bool
  50. def __init__(
  51. self,
  52. config: AppConfig,
  53. event_stream: EventStream,
  54. sid: str = 'default',
  55. plugins: list[PluginRequirement] | None = None,
  56. env_vars: dict[str, str] | None = None,
  57. status_message_callback: Callable | None = None,
  58. attach_to_existing: bool = False,
  59. ):
  60. self.sid = sid
  61. self.event_stream = event_stream
  62. self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
  63. self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
  64. self.status_message_callback = status_message_callback
  65. self.attach_to_existing = attach_to_existing
  66. self.config = copy.deepcopy(config)
  67. atexit.register(self.close)
  68. self.initial_env_vars = _default_env_vars(config.sandbox)
  69. if env_vars is not None:
  70. self.initial_env_vars.update(env_vars)
  71. # Load mixins
  72. FileEditRuntimeMixin.__init__(self)
  73. def setup_initial_env(self) -> None:
  74. if self.attach_to_existing:
  75. return
  76. logger.debug(f'Adding env vars: {self.initial_env_vars}')
  77. self.add_env_vars(self.initial_env_vars)
  78. if self.config.sandbox.runtime_startup_env_vars:
  79. self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
  80. def close(self) -> None:
  81. pass
  82. # ====================================================================
  83. def add_env_vars(self, env_vars: dict[str, str]) -> None:
  84. # Add env vars to the IPython shell (if Jupyter is used)
  85. if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
  86. code = 'import os\n'
  87. for key, value in env_vars.items():
  88. # Note: json.dumps gives us nice escaping for free
  89. code += f'os.environ["{key}"] = {json.dumps(value)}\n'
  90. code += '\n'
  91. obs = self.run_ipython(IPythonRunCellAction(code))
  92. logger.info(f'Added env vars to IPython: code={code}, obs={obs}')
  93. # Add env vars to the Bash shell
  94. cmd = ''
  95. for key, value in env_vars.items():
  96. # Note: json.dumps gives us nice escaping for free
  97. cmd += f'export {key}={json.dumps(value)}; '
  98. if not cmd:
  99. return
  100. cmd = cmd.strip()
  101. logger.debug(f'Adding env var: {cmd}')
  102. obs = self.run(CmdRunAction(cmd))
  103. if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
  104. raise RuntimeError(
  105. f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
  106. )
  107. async def on_event(self, event: Event) -> None:
  108. if isinstance(event, Action):
  109. # set timeout to default if not set
  110. if event.timeout is None:
  111. event.timeout = self.config.sandbox.timeout
  112. assert event.timeout is not None
  113. observation = await call_sync_from_async(self.run_action, event)
  114. observation._cause = event.id # type: ignore[attr-defined]
  115. source = event.source if event.source else EventSource.AGENT
  116. await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
  117. def run_action(self, action: Action) -> Observation:
  118. """Run an action and return the resulting observation.
  119. If the action is not runnable in any runtime, a NullObservation is returned.
  120. If the action is not supported by the current runtime, an ErrorObservation is returned.
  121. """
  122. if not action.runnable:
  123. return NullObservation('')
  124. if (
  125. hasattr(action, 'confirmation_state')
  126. and action.confirmation_state
  127. == ActionConfirmationStatus.AWAITING_CONFIRMATION
  128. ):
  129. return NullObservation('')
  130. action_type = action.action # type: ignore[attr-defined]
  131. if action_type not in ACTION_TYPE_TO_CLASS:
  132. return ErrorObservation(f'Action {action_type} does not exist.')
  133. if not hasattr(self, action_type):
  134. return ErrorObservation(
  135. f'Action {action_type} is not supported in the current runtime.'
  136. )
  137. if (
  138. getattr(action, 'confirmation_state', None)
  139. == ActionConfirmationStatus.REJECTED
  140. ):
  141. return UserRejectObservation(
  142. 'Action has been rejected by the user! Waiting for further user input.'
  143. )
  144. observation = getattr(self, action_type)(action)
  145. return observation
  146. # ====================================================================
  147. # Context manager
  148. # ====================================================================
  149. def __enter__(self) -> 'Runtime':
  150. return self
  151. def __exit__(self, exc_type, exc_value, traceback) -> None:
  152. self.close()
  153. @abstractmethod
  154. async def connect(self) -> None:
  155. pass
  156. # ====================================================================
  157. # Action execution
  158. # ====================================================================
  159. @abstractmethod
  160. def run(self, action: CmdRunAction) -> Observation:
  161. pass
  162. @abstractmethod
  163. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  164. pass
  165. @abstractmethod
  166. def read(self, action: FileReadAction) -> Observation:
  167. pass
  168. @abstractmethod
  169. def write(self, action: FileWriteAction) -> Observation:
  170. pass
  171. @abstractmethod
  172. def browse(self, action: BrowseURLAction) -> Observation:
  173. pass
  174. @abstractmethod
  175. def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  176. pass
  177. # ====================================================================
  178. # File operations
  179. # ====================================================================
  180. @abstractmethod
  181. def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
  182. raise NotImplementedError('This method is not implemented in the base class.')
  183. @abstractmethod
  184. def list_files(self, path: str | None = None) -> list[str]:
  185. """List files in the sandbox.
  186. If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
  187. """
  188. raise NotImplementedError('This method is not implemented in the base class.')
  189. @abstractmethod
  190. def copy_from(self, path: str) -> bytes:
  191. """Zip all files in the sandbox and return as a stream of bytes."""
  192. raise NotImplementedError('This method is not implemented in the base class.')