# base.py (13 KB)
  1. import atexit
  2. import copy
  3. import json
  4. import os
  5. import random
  6. import string
  7. from abc import abstractmethod
  8. from pathlib import Path
  9. from typing import Callable
  10. from requests.exceptions import ConnectionError
  11. from openhands.core.config import AppConfig, SandboxConfig
  12. from openhands.core.exceptions import AgentRuntimeDisconnectedError
  13. from openhands.core.logger import openhands_logger as logger
  14. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  15. from openhands.events.action import (
  16. Action,
  17. ActionConfirmationStatus,
  18. BrowseInteractiveAction,
  19. BrowseURLAction,
  20. CmdRunAction,
  21. FileReadAction,
  22. FileWriteAction,
  23. IPythonRunCellAction,
  24. )
  25. from openhands.events.event import Event
  26. from openhands.events.observation import (
  27. CmdOutputObservation,
  28. ErrorObservation,
  29. NullObservation,
  30. Observation,
  31. UserRejectObservation,
  32. )
  33. from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
  34. from openhands.runtime.plugins import (
  35. JupyterRequirement,
  36. PluginRequirement,
  37. VSCodeRequirement,
  38. )
  39. from openhands.runtime.utils.edit import FileEditRuntimeMixin
  40. from openhands.utils.async_utils import call_sync_from_async
  41. STATUS_MESSAGES = {
  42. 'STATUS$STARTING_RUNTIME': 'Starting runtime...',
  43. 'STATUS$STARTING_CONTAINER': 'Starting container...',
  44. 'STATUS$PREPARING_CONTAINER': 'Preparing container...',
  45. 'STATUS$CONTAINER_STARTED': 'Container started.',
  46. 'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
  47. }
  48. def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
  49. ret = {}
  50. for key in os.environ:
  51. if key.startswith('SANDBOX_ENV_'):
  52. sandbox_key = key.removeprefix('SANDBOX_ENV_')
  53. ret[sandbox_key] = os.environ[key]
  54. if sandbox_config.enable_auto_lint:
  55. ret['ENABLE_AUTO_LINT'] = 'true'
  56. return ret
  57. class Runtime(FileEditRuntimeMixin):
  58. """The runtime is how the agent interacts with the external environment.
  59. This includes a bash sandbox, a browser, and filesystem interactions.
  60. sid is the session id, which is used to identify the current user session.
  61. """
  62. sid: str
  63. config: AppConfig
  64. initial_env_vars: dict[str, str]
  65. attach_to_existing: bool
  66. status_callback: Callable | None
  67. def __init__(
  68. self,
  69. config: AppConfig,
  70. event_stream: EventStream,
  71. sid: str = 'default',
  72. plugins: list[PluginRequirement] | None = None,
  73. env_vars: dict[str, str] | None = None,
  74. status_callback: Callable | None = None,
  75. attach_to_existing: bool = False,
  76. headless_mode: bool = False,
  77. ):
  78. self.sid = sid
  79. self.event_stream = event_stream
  80. self.event_stream.subscribe(
  81. EventStreamSubscriber.RUNTIME, self.on_event, self.sid
  82. )
  83. self.plugins = (
  84. copy.deepcopy(plugins) if plugins is not None and len(plugins) > 0 else []
  85. )
  86. # add VSCode plugin if not in headless mode
  87. if not headless_mode:
  88. self.plugins.append(VSCodeRequirement())
  89. self.status_callback = status_callback
  90. self.attach_to_existing = attach_to_existing
  91. self.config = copy.deepcopy(config)
  92. atexit.register(self.close)
  93. self.initial_env_vars = _default_env_vars(config.sandbox)
  94. if env_vars is not None:
  95. self.initial_env_vars.update(env_vars)
  96. self._vscode_enabled = any(
  97. isinstance(plugin, VSCodeRequirement) for plugin in self.plugins
  98. )
  99. # Load mixins
  100. FileEditRuntimeMixin.__init__(self)
  101. def setup_initial_env(self) -> None:
  102. if self.attach_to_existing:
  103. return
  104. logger.debug(f'Adding env vars: {self.initial_env_vars}')
  105. self.add_env_vars(self.initial_env_vars)
  106. if self.config.sandbox.runtime_startup_env_vars:
  107. self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
  108. def close(self) -> None:
  109. pass
  110. def log(self, level: str, message: str) -> None:
  111. message = f'[runtime {self.sid}] {message}'
  112. getattr(logger, level)(message, stacklevel=2)
  113. def send_status_message(self, message_id: str):
  114. """Sends a status message if the callback function was provided."""
  115. if self.status_callback:
  116. msg = STATUS_MESSAGES.get(message_id, '')
  117. self.status_callback('info', message_id, msg)
  118. def send_error_message(self, message_id: str, message: str):
  119. if self.status_callback:
  120. self.status_callback('error', message_id, message)
  121. # ====================================================================
  122. def add_env_vars(self, env_vars: dict[str, str]) -> None:
  123. # Add env vars to the IPython shell (if Jupyter is used)
  124. if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
  125. code = 'import os\n'
  126. for key, value in env_vars.items():
  127. # Note: json.dumps gives us nice escaping for free
  128. code += f'os.environ["{key}"] = {json.dumps(value)}\n'
  129. code += '\n'
  130. obs = self.run_ipython(IPythonRunCellAction(code))
  131. self.log('debug', f'Added env vars to IPython: code={code}, obs={obs}')
  132. # Add env vars to the Bash shell
  133. cmd = ''
  134. for key, value in env_vars.items():
  135. # Note: json.dumps gives us nice escaping for free
  136. cmd += f'export {key}={json.dumps(value)}; '
  137. if not cmd:
  138. return
  139. cmd = cmd.strip()
  140. logger.debug(f'Adding env var: {cmd}')
  141. obs = self.run(CmdRunAction(cmd))
  142. if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
  143. raise RuntimeError(
  144. f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
  145. )
  146. async def on_event(self, event: Event) -> None:
  147. if isinstance(event, Action):
  148. # set timeout to default if not set
  149. if event.timeout is None:
  150. event.timeout = self.config.sandbox.timeout
  151. assert event.timeout is not None
  152. try:
  153. observation: Observation = await call_sync_from_async(
  154. self.run_action, event
  155. )
  156. except Exception as e:
  157. err_id = ''
  158. if isinstance(e, ConnectionError) or isinstance(
  159. e, AgentRuntimeDisconnectedError
  160. ):
  161. err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
  162. logger.error(
  163. 'Unexpected error while running action',
  164. exc_info=True,
  165. stack_info=True,
  166. )
  167. self.log('error', f'Problematic action: {str(event)}')
  168. self.send_error_message(err_id, str(e))
  169. self.close()
  170. return
  171. observation._cause = event.id # type: ignore[attr-defined]
  172. observation.tool_call_metadata = event.tool_call_metadata
  173. # this might be unnecessary, since source should be set by the event stream when we're here
  174. source = event.source if event.source else EventSource.AGENT
  175. self.event_stream.add_event(observation, source) # type: ignore[arg-type]
  176. def clone_repo(self, github_token: str | None, selected_repository: str | None):
  177. if not github_token or not selected_repository:
  178. return
  179. url = f'https://{github_token}@github.com/{selected_repository}.git'
  180. dir_name = selected_repository.split('/')[1]
  181. # add random branch name to avoid conflicts
  182. random_str = ''.join(
  183. random.choices(string.ascii_lowercase + string.digits, k=8)
  184. )
  185. branch_name = f'openhands-workspace-{random_str}'
  186. action = CmdRunAction(
  187. command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b {branch_name}',
  188. )
  189. self.log('info', f'Cloning repo: {selected_repository}')
  190. self.run_action(action)
  191. def get_custom_microagents(self, selected_repository: str | None) -> list[str]:
  192. custom_microagents_content = []
  193. custom_microagents_dir = Path('.openhands') / 'microagents'
  194. dir_name = str(custom_microagents_dir)
  195. if selected_repository:
  196. dir_name = str(
  197. Path(selected_repository.split('/')[1]) / custom_microagents_dir
  198. )
  199. oh_instructions_header = '---\nname: openhands_instructions\nagent: CodeActAgent\ntriggers:\n- ""\n---\n'
  200. obs = self.read(FileReadAction(path='.openhands_instructions'))
  201. if isinstance(obs, ErrorObservation):
  202. self.log('error', 'Failed to read openhands_instructions')
  203. else:
  204. openhands_instructions = oh_instructions_header + obs.content
  205. self.log('info', f'openhands_instructions: {openhands_instructions}')
  206. custom_microagents_content.append(openhands_instructions)
  207. files = self.list_files(dir_name)
  208. self.log('info', f'Found {len(files)} custom microagents.')
  209. for fname in files:
  210. content = self.read(
  211. FileReadAction(path=str(custom_microagents_dir / fname))
  212. ).content
  213. custom_microagents_content.append(content)
  214. return custom_microagents_content
  215. def run_action(self, action: Action) -> Observation:
  216. """Run an action and return the resulting observation.
  217. If the action is not runnable in any runtime, a NullObservation is returned.
  218. If the action is not supported by the current runtime, an ErrorObservation is returned.
  219. """
  220. if not action.runnable:
  221. return NullObservation('')
  222. if (
  223. hasattr(action, 'confirmation_state')
  224. and action.confirmation_state
  225. == ActionConfirmationStatus.AWAITING_CONFIRMATION
  226. ):
  227. return NullObservation('')
  228. action_type = action.action # type: ignore[attr-defined]
  229. if action_type not in ACTION_TYPE_TO_CLASS:
  230. return ErrorObservation(f'Action {action_type} does not exist.')
  231. if not hasattr(self, action_type):
  232. return ErrorObservation(
  233. f'Action {action_type} is not supported in the current runtime.'
  234. )
  235. if (
  236. getattr(action, 'confirmation_state', None)
  237. == ActionConfirmationStatus.REJECTED
  238. ):
  239. return UserRejectObservation(
  240. 'Action has been rejected by the user! Waiting for further user input.'
  241. )
  242. observation = getattr(self, action_type)(action)
  243. return observation
  244. # ====================================================================
  245. # Context manager
  246. # ====================================================================
  247. def __enter__(self) -> 'Runtime':
  248. return self
  249. def __exit__(self, exc_type, exc_value, traceback) -> None:
  250. self.close()
  251. @abstractmethod
  252. async def connect(self) -> None:
  253. pass
  254. # ====================================================================
  255. # Action execution
  256. # ====================================================================
  257. @abstractmethod
  258. def run(self, action: CmdRunAction) -> Observation:
  259. pass
  260. @abstractmethod
  261. def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  262. pass
  263. @abstractmethod
  264. def read(self, action: FileReadAction) -> Observation:
  265. pass
  266. @abstractmethod
  267. def write(self, action: FileWriteAction) -> Observation:
  268. pass
  269. @abstractmethod
  270. def browse(self, action: BrowseURLAction) -> Observation:
  271. pass
  272. @abstractmethod
  273. def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  274. pass
  275. # ====================================================================
  276. # File operations
  277. # ====================================================================
  278. @abstractmethod
  279. def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
  280. raise NotImplementedError('This method is not implemented in the base class.')
  281. @abstractmethod
  282. def list_files(self, path: str | None = None) -> list[str]:
  283. """List files in the sandbox.
  284. If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
  285. """
  286. raise NotImplementedError('This method is not implemented in the base class.')
  287. @abstractmethod
  288. def copy_from(self, path: str) -> Path:
  289. """Zip all files in the sandbox and return a path in the local filesystem."""
  290. raise NotImplementedError('This method is not implemented in the base class.')
  291. # ====================================================================
  292. # VSCode
  293. # ====================================================================
  294. @property
  295. def vscode_enabled(self) -> bool:
  296. return self._vscode_enabled
  297. @property
  298. def vscode_url(self) -> str | None:
  299. raise NotImplementedError('This method is not implemented in the base class.')