| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- import asyncio
- import atexit
- import copy
- import json
- import os
- from abc import abstractmethod
- from openhands.core.config import AppConfig, SandboxConfig
- from openhands.core.logger import openhands_logger as logger
- from openhands.events import EventSource, EventStream, EventStreamSubscriber
- from openhands.events.action import (
- Action,
- ActionConfirmationStatus,
- BrowseInteractiveAction,
- BrowseURLAction,
- CmdRunAction,
- FileReadAction,
- FileWriteAction,
- IPythonRunCellAction,
- )
- from openhands.events.event import Event
- from openhands.events.observation import (
- CmdOutputObservation,
- ErrorObservation,
- NullObservation,
- Observation,
- UserRejectObservation,
- )
- from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
- from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
- def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
- ret = {}
- for key in os.environ:
- if key.startswith('SANDBOX_ENV_'):
- sandbox_key = key.removeprefix('SANDBOX_ENV_')
- ret[sandbox_key] = os.environ[key]
- if sandbox_config.enable_auto_lint:
- ret['ENABLE_AUTO_LINT'] = 'true'
- return ret
class Runtime:
    """The runtime is how the agent interacts with the external environment.

    This includes a bash sandbox, a browser, and filesystem interactions.

    sid is the session id, which is used to identify the current user session.

    Concrete runtimes provide the action handlers (``run``, ``run_ipython``,
    ``read``, ``write``, ``browse``, ``browse_interactive``) and the file
    operations (``copy_to``, ``list_files``).
    """

    # NOTE(review): this class does not use ``abc.ABCMeta``, so the
    # ``@abstractmethod`` markers below are not enforced at instantiation time;
    # they only document which methods subclasses are expected to override.

    sid: str
    config: AppConfig
    DEFAULT_ENV_VARS: dict[str, str]

    def __init__(
        self,
        config: AppConfig,
        event_stream: EventStream,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
    ):
        """Subscribe to the event stream and register cleanup at interpreter exit.

        Args:
            config: Application config; a deep copy is stored on the runtime.
            event_stream: Stream whose events are delivered to ``on_event``.
            sid: Session id of the current user session.
            plugins: Plugin requirements; ``None`` or empty becomes ``[]``.
        """
        self.sid = sid
        self.event_stream = event_stream
        self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
        self.plugins = plugins or []
        # Deep-copy so later changes to the caller's config do not leak in.
        self.config = copy.deepcopy(config)
        self.DEFAULT_ENV_VARS = _default_env_vars(config.sandbox)
        atexit.register(self.close_sync)
        logger.debug(f'Runtime `{sid}` config:\n{self.config}')

    async def ainit(self, env_vars: dict[str, str] | None = None) -> None:
        """Initialize the runtime (asynchronously).

        This method should be called after the runtime's constructor.

        Args:
            env_vars: Extra environment variables to add after the defaults.
        """
        if self.DEFAULT_ENV_VARS:
            logger.debug(f'Adding default env vars: {self.DEFAULT_ENV_VARS}')
            await self.add_env_vars(self.DEFAULT_ENV_VARS)
        if env_vars is not None:
            logger.debug(f'Adding provided env vars: {env_vars}')
            await self.add_env_vars(env_vars)

    async def close(self) -> None:
        """Release runtime resources. The base implementation does nothing."""
        pass

    def close_sync(self) -> None:
        """Run ``close()`` from synchronous code (this is the ``atexit`` hook).

        Without a running event loop in the current thread, ``close()`` is
        driven to completion via ``asyncio.run``. With a running loop it can
        only be scheduled as a task, so it may not have finished on return.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop, use asyncio.run()
            asyncio.run(self.close())
            return
        # get_running_loop() only ever returns a loop that is running, so the
        # coroutine cannot be awaited synchronously here. Keep a reference to
        # the task: the loop holds tasks weakly and an unreferenced task may be
        # garbage-collected before it finishes.
        self._close_task = loop.create_task(self.close())

    # ====================================================================

    async def add_env_vars(self, env_vars: dict[str, str]) -> None:
        """Export environment variables into the sandbox.

        The variables are set in the IPython shell (only when a
        ``JupyterRequirement`` plugin is present) and in the Bash shell.
        An empty mapping is a no-op.

        Args:
            env_vars: Mapping of variable names to values.

        Raises:
            RuntimeError: If the Bash ``export`` command does not yield a
                ``CmdOutputObservation`` with exit code 0.
        """
        if not env_vars:
            return

        # Add env vars to the IPython shell (if Jupyter is used)
        if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
            # repr() yields an exact Python string literal for both key and
            # value (JSON escaping is not Python escaping for every character).
            assignments = ''.join(
                f'os.environ[{key!r}] = {value!r}\n'
                for key, value in env_vars.items()
            )
            code = f'import os\n{assignments}\n'
            obs = await self.run_ipython(IPythonRunCellAction(code))
            logger.info(f'Added env vars to IPython: code={code}, obs={obs}')

        # Add env vars to the Bash shell
        # NOTE(review): json.dumps produces a double-quoted string, inside which
        # bash still performs `$`/backtick expansion — confirm that is intended.
        cmd = ' '.join(
            f'export {key}={json.dumps(value)};' for key, value in env_vars.items()
        )
        logger.debug(f'Adding env var: {cmd}')
        obs = await self.run(CmdRunAction(cmd))
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(
                f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
            )

    async def on_event(self, event: Event) -> None:
        """Execute an incoming action and publish its observation.

        Events that are not ``Action`` instances are ignored.
        """
        if isinstance(event, Action):
            # set timeout to default if not set
            if event.timeout is None:
                event.timeout = self.config.sandbox.timeout
            assert event.timeout is not None
            observation = await self.run_action(event)
            # Link the observation back to the action that produced it.
            observation._cause = event.id  # type: ignore[attr-defined]
            source = event.source if event.source else EventSource.AGENT
            self.event_stream.add_event(observation, source)  # type: ignore[arg-type]

    async def run_action(self, action: Action) -> Observation:
        """Run an action and return the resulting observation.

        If the action is not runnable in any runtime, a NullObservation is returned.
        If the action is not supported by the current runtime, an ErrorObservation is returned.
        An action still awaiting confirmation yields a NullObservation, and a
        rejected one yields a UserRejectObservation.
        """
        if not action.runnable:
            return NullObservation('')
        if (
            hasattr(action, 'is_confirmed')
            and action.is_confirmed == ActionConfirmationStatus.AWAITING_CONFIRMATION
        ):
            return NullObservation('')
        action_type = action.action  # type: ignore[attr-defined]
        if action_type not in ACTION_TYPE_TO_CLASS:
            return ErrorObservation(f'Action {action_type} does not exist.')
        # The action type doubles as the name of the handler method on self.
        if not hasattr(self, action_type):
            return ErrorObservation(
                f'Action {action_type} is not supported in the current runtime.'
            )
        if (
            hasattr(action, 'is_confirmed')
            and action.is_confirmed == ActionConfirmationStatus.REJECTED
        ):
            return UserRejectObservation(
                'Action has been rejected by the user! Waiting for further user input.'
            )
        observation = await getattr(self, action_type)(action)
        return observation

    # ====================================================================
    # Action execution
    # ====================================================================

    @abstractmethod
    async def run(self, action: CmdRunAction) -> Observation:
        """Handle a ``CmdRunAction`` and return its observation."""
        pass

    @abstractmethod
    async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
        """Handle an ``IPythonRunCellAction`` and return its observation."""
        pass

    @abstractmethod
    async def read(self, action: FileReadAction) -> Observation:
        """Handle a ``FileReadAction`` and return its observation."""
        pass

    @abstractmethod
    async def write(self, action: FileWriteAction) -> Observation:
        """Handle a ``FileWriteAction`` and return its observation."""
        pass

    @abstractmethod
    async def browse(self, action: BrowseURLAction) -> Observation:
        """Handle a ``BrowseURLAction`` and return its observation."""
        pass

    @abstractmethod
    async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
        """Handle a ``BrowseInteractiveAction`` and return its observation."""
        pass

    # ====================================================================
    # File operations
    # ====================================================================

    @abstractmethod
    async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
        """Copy ``host_src`` to ``sandbox_dest``; must be provided by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('This method is not implemented in the base class.')

    @abstractmethod
    async def list_files(self, path: str | None = None) -> list[str]:
        """List files in the sandbox.

        If path is None, list files in the sandbox's initial working directory (e.g., /workspace).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('This method is not implemented in the base class.')
|