| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- import asyncio
- from typing import Callable, Optional
- from openhands.controller import AgentController
- from openhands.controller.agent import Agent
- from openhands.controller.state.state import State
- from openhands.core.config import AgentConfig, AppConfig, LLMConfig
- from openhands.core.logger import openhands_logger as logger
- from openhands.core.schema.agent import AgentState
- from openhands.events.action import ChangeAgentStateAction
- from openhands.events.event import EventSource
- from openhands.events.stream import EventStream
- from openhands.runtime import get_runtime_cls
- from openhands.runtime.base import Runtime, RuntimeUnavailableError
- from openhands.security import SecurityAnalyzer, options
- from openhands.storage.files import FileStore
- from openhands.utils.async_utils import call_async_from_sync
class AgentSession:
    """Represents a session with an Agent.

    Owns the event stream for one session id and, once started, the runtime,
    the agent controller and (optionally) a security analyzer.

    Attributes:
        sid: The session ID.
        event_stream: The EventStream created for this session.
        file_store: The FileStore backing the event stream and used to save /
            restore agent state.
        controller: The AgentController instance for controlling the agent
            (None until the session has been started successfully).
        runtime: The Runtime instance (None until the session is started).
        security_analyzer: Optional SecurityAnalyzer attached to the event stream.
    """

    sid: str
    event_stream: EventStream
    file_store: FileStore
    controller: AgentController | None = None
    runtime: Runtime | None = None
    security_analyzer: SecurityAnalyzer | None = None
    # Set by close() so that repeated close() calls are no-ops.
    _closed: bool = False
    # NOTE(review): never assigned inside this class; confirm whether external
    # code sets or reads it before removing.
    loop: asyncio.AbstractEventLoop | None = None

    def __init__(
        self,
        sid: str,
        file_store: FileStore,
        status_callback: Optional[Callable] = None,
    ):
        """Initializes a new instance of the Session class

        Parameters:
        - sid: The session ID
        - file_store: Instance of the FileStore
        - status_callback: Optional callable forwarded to the runtime and the
          agent controller; also invoked here as
          (level, message_id, detail) when the runtime fails to connect
        """
        self.sid = sid
        self.event_stream = EventStream(sid, file_store)
        self.file_store = file_store
        self._status_callback = status_callback

    async def start(
        self,
        runtime_name: str,
        config: AppConfig,
        agent: Agent,
        max_iterations: int,
        max_budget_per_task: float | None = None,
        agent_to_llm_config: dict[str, LLMConfig] | None = None,
        agent_configs: dict[str, AgentConfig] | None = None,
        github_token: str | None = None,
        selected_repository: str | None = None,
    ):
        """Starts the Agent session in the background.

        Setup (security analyzer, runtime, controller) and the agent step loop
        run on a worker thread with its own event loop; this coroutine returns
        as soon as that work has been scheduled.

        Parameters:
        - runtime_name: The name of the runtime associated with the session
        - config: The application configuration
        - agent: The agent to run
        - max_iterations: Iteration limit forwarded to the AgentController
        - max_budget_per_task: Optional budget limit forwarded to the AgentController
        - agent_to_llm_config: Optional mapping (str -> LLMConfig) forwarded to the AgentController
        - agent_configs: Optional mapping (str -> AgentConfig) forwarded to the AgentController
        - github_token: Optional token passed to the runtime's clone_repo
        - selected_repository: Optional repository passed to the runtime's clone_repo
          and used to look up custom microagents

        Raises:
        - RuntimeError: if this session already has a controller or runtime
        """
        if self.controller or self.runtime:
            raise RuntimeError(
                'Session already started. You need to close this session and start a new one.'
            )

        # _start_thread drives _start with asyncio.run(), which cannot be called
        # from a thread that already has a running loop, so hand it to the
        # default executor. The returned future is intentionally not awaited;
        # _start_thread logs its own failures.
        asyncio.get_running_loop().run_in_executor(
            None,
            self._start_thread,
            runtime_name,
            config,
            agent,
            max_iterations,
            max_budget_per_task,
            agent_to_llm_config,
            agent_configs,
            github_token,
            selected_repository,
        )

    def _start_thread(self, *args):
        """Runs `_start(*args)` to completion on a fresh event loop in this thread.

        This is the executor target used by `start`. Because nothing awaits the
        executor future, every failure is logged here instead of being raised.
        """
        try:
            asyncio.run(self._start(*args), debug=True)
        except Exception as e:
            # Log the actual exception instance (not the exception class) so the
            # failure reason is visible.
            logger.error(f'Error starting session: {e}', exc_info=True)
        logger.debug('Session Finished')

    async def _start(
        self,
        runtime_name: str,
        config: AppConfig,
        agent: Agent,
        max_iterations: int,
        max_budget_per_task: float | None = None,
        agent_to_llm_config: dict[str, LLMConfig] | None = None,
        agent_configs: dict[str, AgentConfig] | None = None,
        github_token: str | None = None,
        selected_repository: str | None = None,
    ):
        """Builds the session components and runs the agent step loop.

        Parameters are the same as for `start`. Returns once the controller's
        step loop task finishes, or early if the runtime could not connect.
        """
        self._create_security_analyzer(config.security.security_analyzer)
        runtime_connected = await self._create_runtime(
            runtime_name=runtime_name,
            config=config,
            agent=agent,
            github_token=github_token,
            selected_repository=selected_repository,
        )
        if not runtime_connected:
            # The failure was already logged and reported via the status
            # callback; do not build a controller on top of a dead runtime.
            return

        self._create_controller(
            agent,
            config.security.confirmation_mode,
            max_iterations,
            max_budget_per_task=max_budget_per_task,
            agent_to_llm_config=agent_to_llm_config,
            agent_configs=agent_configs,
        )
        self.event_stream.add_event(
            ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
        )
        if self.controller:
            self.controller.agent_task = self.controller.start_step_loop()
            await self.controller.agent_task  # type: ignore

    def close(self):
        """Closes the Agent session.

        Safe to call more than once; only the first call does any work.
        """
        if self._closed:
            return
        self._closed = True
        call_async_from_sync(self._close)

    async def _close(self):
        """Saves controller state and releases the controller, runtime and analyzer.

        The runtime and security analyzer are released even if saving state or
        closing the controller raises; that exception still propagates.
        """
        try:
            if self.controller is not None:
                end_state = self.controller.get_state()
                end_state.save_to_session(self.sid, self.file_store)
                await self.controller.close()
        finally:
            if self.runtime is not None:
                self.runtime.close()
            if self.security_analyzer is not None:
                await self.security_analyzer.close()

    async def stop_agent_loop_for_error(self):
        """Moves the agent into the ERROR state, if a controller exists."""
        if self.controller is not None:
            await self.controller.set_agent_state_to(AgentState.ERROR)

    def _create_security_analyzer(self, security_analyzer: str | None):
        """Creates a SecurityAnalyzer instance that will be used to analyze the agent actions

        Parameters:
        - security_analyzer: The name of the security analyzer to use; when
          falsy, no analyzer is created. Unknown names fall back to the base
          SecurityAnalyzer class.
        """
        if security_analyzer:
            logger.debug(f'Using security analyzer: {security_analyzer}')
            self.security_analyzer = options.SecurityAnalyzers.get(
                security_analyzer, SecurityAnalyzer
            )(self.event_stream)

    async def _create_runtime(
        self,
        runtime_name: str,
        config: AppConfig,
        agent: Agent,
        github_token: str | None = None,
        selected_repository: str | None = None,
    ) -> bool:
        """Creates and connects a runtime instance

        Parameters:
        - runtime_name: The name of the runtime associated with the session
        - config: The application configuration passed to the runtime
        - agent: The agent whose sandbox plugins and prompt manager are used
        - github_token: Optional token passed to the runtime's clone_repo
        - selected_repository: Optional repository passed to clone_repo and to
          get_custom_microagents

        Returns:
        - True if the runtime connected and was initialized, False if connecting
          raised RuntimeUnavailableError (already logged and reported through
          the status callback). On failure self.runtime stays set so close()
          can still release it.

        Raises:
        - RuntimeError: if a runtime has already been created for this session
        """
        if self.runtime is not None:
            raise RuntimeError('Runtime already created')

        logger.debug(f'Initializing runtime `{runtime_name}` now...')
        runtime_cls = get_runtime_cls(runtime_name)
        self.runtime = runtime_cls(
            config=config,
            event_stream=self.event_stream,
            sid=self.sid,
            plugins=agent.sandbox_plugins,
            status_callback=self._status_callback,
            headless_mode=False,
        )

        try:
            await self.runtime.connect()
        except RuntimeUnavailableError as e:
            logger.error(f'Runtime initialization failed: {e}', exc_info=True)
            if self._status_callback:
                self._status_callback(
                    'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e)
                )
            return False

        self.runtime.clone_repo(github_token, selected_repository)
        if agent.prompt_manager:
            agent.prompt_manager.load_microagent_files(
                self.runtime.get_custom_microagents(selected_repository)
            )

        logger.debug(
            f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
        )
        return True

    def _create_controller(
        self,
        agent: Agent,
        confirmation_mode: bool,
        max_iterations: int,
        max_budget_per_task: float | None = None,
        agent_to_llm_config: dict[str, LLMConfig] | None = None,
        agent_configs: dict[str, AgentConfig] | None = None,
    ):
        """Creates an AgentController instance

        Also tries to restore a previously saved State for this session id;
        failure to restore is logged at debug level and otherwise ignored.

        Parameters:
        - agent: The agent the controller drives
        - confirmation_mode: Whether to use confirmation mode
        - max_iterations: Iteration limit forwarded to the AgentController
        - max_budget_per_task: Optional budget limit forwarded to the AgentController
        - agent_to_llm_config: Optional mapping (str -> LLMConfig) forwarded to the AgentController
        - agent_configs: Optional mapping (str -> AgentConfig) forwarded to the AgentController

        Raises:
        - RuntimeError: if a controller already exists or no runtime has been created
        """
        if self.controller is not None:
            raise RuntimeError('Controller already created')
        if self.runtime is None:
            raise RuntimeError(
                'Runtime must be initialized before the agent controller'
            )

        msg = (
            '\n--------------------------------- OpenHands Configuration ---------------------------------\n'
            f'LLM: {agent.llm.config.model}\n'
            f'Base URL: {agent.llm.config.base_url}\n'
        )

        if agent.llm.config.draft_editor:
            msg += (
                f'Draft editor LLM (for file editing): {agent.llm.config.draft_editor.model}\n'
                f'Draft editor LLM (for file editing) Base URL: {agent.llm.config.draft_editor.base_url}\n'
            )

        msg += (
            f'Agent: {agent.name}\n'
            f'Runtime: {self.runtime.__class__.__name__}\n'
            f'Plugins: {agent.sandbox_plugins}\n'
            '-------------------------------------------------------------------------------------------'
        )
        logger.debug(msg)

        self.controller = AgentController(
            sid=self.sid,
            event_stream=self.event_stream,
            agent=agent,
            max_iterations=int(max_iterations),
            max_budget_per_task=max_budget_per_task,
            agent_to_llm_config=agent_to_llm_config,
            agent_configs=agent_configs,
            confirmation_mode=confirmation_mode,
            headless_mode=False,
            status_callback=self._status_callback,
        )
        try:
            agent_state = State.restore_from_session(self.sid, self.file_store)
            self.controller.set_initial_state(
                agent_state, max_iterations, confirmation_mode
            )
            logger.debug(f'Restored agent state from session, sid: {self.sid}')
        except Exception as e:
            # Best effort: a fresh session has no saved state to restore.
            logger.debug(f'State could not be restored: {e}')
        logger.debug('Agent controller initialized.')
|