runtime.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
import asyncio
import atexit
import copy
import json
import os
import shlex
from abc import abstractmethod

from openhands.core.config import AppConfig, SandboxConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
    Action,
    ActionConfirmationStatus,
    BrowseInteractiveAction,
    BrowseURLAction,
    CmdRunAction,
    FileReadAction,
    FileWriteAction,
    IPythonRunCellAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
    CmdOutputObservation,
    ErrorObservation,
    NullObservation,
    Observation,
    UserRejectObservation,
)
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
  30. def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
  31. ret = {}
  32. for key in os.environ:
  33. if key.startswith('SANDBOX_ENV_'):
  34. sandbox_key = key.removeprefix('SANDBOX_ENV_')
  35. ret[sandbox_key] = os.environ[key]
  36. if sandbox_config.enable_auto_lint:
  37. ret['ENABLE_AUTO_LINT'] = 'true'
  38. return ret
  39. class Runtime:
  40. """The runtime is how the agent interacts with the external environment.
  41. This includes a bash sandbox, a browser, and filesystem interactions.
  42. sid is the session id, which is used to identify the current user session.
  43. """
  44. sid: str
  45. config: AppConfig
  46. DEFAULT_ENV_VARS: dict[str, str]
  47. def __init__(
  48. self,
  49. config: AppConfig,
  50. event_stream: EventStream,
  51. sid: str = 'default',
  52. plugins: list[PluginRequirement] | None = None,
  53. ):
  54. self.sid = sid
  55. self.event_stream = event_stream
  56. self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
  57. self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
  58. self.config = copy.deepcopy(config)
  59. self.DEFAULT_ENV_VARS = _default_env_vars(config.sandbox)
  60. atexit.register(self.close_sync)
  61. logger.debug(f'Runtime `{sid}` config:\n{self.config}')
  62. async def ainit(self, env_vars: dict[str, str] | None = None) -> None:
  63. """
  64. Initialize the runtime (asynchronously).
  65. This method should be called after the runtime's constructor.
  66. """
  67. if self.DEFAULT_ENV_VARS:
  68. logger.debug(f'Adding default env vars: {self.DEFAULT_ENV_VARS}')
  69. await self.add_env_vars(self.DEFAULT_ENV_VARS)
  70. if env_vars is not None:
  71. logger.debug(f'Adding provided env vars: {env_vars}')
  72. await self.add_env_vars(env_vars)
  73. async def close(self) -> None:
  74. pass
  75. def close_sync(self) -> None:
  76. try:
  77. loop = asyncio.get_running_loop()
  78. except RuntimeError:
  79. # No running event loop, use asyncio.run()
  80. asyncio.run(self.close())
  81. else:
  82. # There is a running event loop, create a task
  83. if loop.is_running():
  84. loop.create_task(self.close())
  85. else:
  86. loop.run_until_complete(self.close())
  87. # ====================================================================
  88. async def add_env_vars(self, env_vars: dict[str, str]) -> None:
  89. # Add env vars to the IPython shell (if Jupyter is used)
  90. if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
  91. code = 'import os\n'
  92. for key, value in env_vars.items():
  93. # Note: json.dumps gives us nice escaping for free
  94. code += f'os.environ["{key}"] = {json.dumps(value)}\n'
  95. code += '\n'
  96. obs = await self.run_ipython(IPythonRunCellAction(code))
  97. logger.info(f'Added env vars to IPython: code={code}, obs={obs}')
  98. # Add env vars to the Bash shell
  99. cmd = ''
  100. for key, value in env_vars.items():
  101. # Note: json.dumps gives us nice escaping for free
  102. cmd += f'export {key}={json.dumps(value)}; '
  103. if not cmd:
  104. return
  105. cmd = cmd.strip()
  106. logger.debug(f'Adding env var: {cmd}')
  107. obs = await self.run(CmdRunAction(cmd))
  108. if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
  109. raise RuntimeError(
  110. f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
  111. )
  112. async def on_event(self, event: Event) -> None:
  113. if isinstance(event, Action):
  114. # set timeout to default if not set
  115. if event.timeout is None:
  116. event.timeout = self.config.sandbox.timeout
  117. assert event.timeout is not None
  118. observation = await self.run_action(event)
  119. observation._cause = event.id # type: ignore[attr-defined]
  120. source = event.source if event.source else EventSource.AGENT
  121. self.event_stream.add_event(observation, source) # type: ignore[arg-type]
  122. async def run_action(self, action: Action) -> Observation:
  123. """Run an action and return the resulting observation.
  124. If the action is not runnable in any runtime, a NullObservation is returned.
  125. If the action is not supported by the current runtime, an ErrorObservation is returned.
  126. """
  127. if not action.runnable:
  128. return NullObservation('')
  129. if (
  130. hasattr(action, 'is_confirmed')
  131. and action.is_confirmed == ActionConfirmationStatus.AWAITING_CONFIRMATION
  132. ):
  133. return NullObservation('')
  134. action_type = action.action # type: ignore[attr-defined]
  135. if action_type not in ACTION_TYPE_TO_CLASS:
  136. return ErrorObservation(f'Action {action_type} does not exist.')
  137. if not hasattr(self, action_type):
  138. return ErrorObservation(
  139. f'Action {action_type} is not supported in the current runtime.'
  140. )
  141. if (
  142. hasattr(action, 'is_confirmed')
  143. and action.is_confirmed == ActionConfirmationStatus.REJECTED
  144. ):
  145. return UserRejectObservation(
  146. 'Action has been rejected by the user! Waiting for further user input.'
  147. )
  148. observation = await getattr(self, action_type)(action)
  149. return observation
  150. # ====================================================================
  151. # Action execution
  152. # ====================================================================
  153. @abstractmethod
  154. async def run(self, action: CmdRunAction) -> Observation:
  155. pass
  156. @abstractmethod
  157. async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
  158. pass
  159. @abstractmethod
  160. async def read(self, action: FileReadAction) -> Observation:
  161. pass
  162. @abstractmethod
  163. async def write(self, action: FileWriteAction) -> Observation:
  164. pass
  165. @abstractmethod
  166. async def browse(self, action: BrowseURLAction) -> Observation:
  167. pass
  168. @abstractmethod
  169. async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
  170. pass
  171. # ====================================================================
  172. # File operations
  173. # ====================================================================
  174. @abstractmethod
  175. async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
  176. raise NotImplementedError('This method is not implemented in the base class.')
  177. @abstractmethod
  178. async def list_files(self, path: str | None = None) -> list[str]:
  179. """List files in the sandbox.
  180. If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
  181. """
  182. raise NotImplementedError('This method is not implemented in the base class.')