Explorar el Código

Allow attaching to existing sessions without reinitializing the runtime (#4329)

Co-authored-by: tofarr <tofarr@gmail.com>
Robert Brennan hace 1 año
padre
commit
63ff69fd97

+ 0 - 12
frontend/src/services/actions.ts

@@ -7,13 +7,11 @@ import {
   appendSecurityAnalyzerInput,
 } from "#/state/securityAnalyzerSlice";
 import { setCurStatusMessage } from "#/state/statusSlice";
-import { setRootTask } from "#/state/taskSlice";
 import store from "#/store";
 import ActionType from "#/types/ActionType";
 import { ActionMessage, StatusMessage } from "#/types/Message";
 import { SocketMessage } from "#/types/ResponseType";
 import { handleObservationMessage } from "./observations";
-import { getRootTask } from "./taskService";
 
 const messageActions = {
   [ActionType.BROWSE]: (message: ActionMessage) => {
@@ -75,16 +73,6 @@ const messageActions = {
       store.dispatch(appendJupyterInput(message.args.code));
     }
   },
-  [ActionType.ADD_TASK]: () => {
-    getRootTask().then((fetchedRootTask) =>
-      store.dispatch(setRootTask(fetchedRootTask)),
-    );
-  },
-  [ActionType.MODIFY_TASK]: () => {
-    getRootTask().then((fetchedRootTask) =>
-      store.dispatch(setRootTask(fetchedRootTask)),
-    );
-  },
 };
 
 function getRiskText(risk: ActionSecurityRisk) {

+ 0 - 21
frontend/src/services/taskService.ts

@@ -1,21 +0,0 @@
-import { request } from "./api";
-
-export type Task = {
-  id: string;
-  goal: string;
-  subtasks: Task[];
-  state: TaskState;
-};
-
-export enum TaskState {
-  OPEN_STATE = "open",
-  COMPLETED_STATE = "completed",
-  ABANDONED_STATE = "abandoned",
-  IN_PROGRESS_STATE = "in_progress",
-  VERIFIED_STATE = "verified",
-}
-
-export async function getRootTask(): Promise<Task | undefined> {
-  const res = await request("/api/root_task");
-  return res as Task;
-}

+ 0 - 23
frontend/src/state/taskSlice.ts

@@ -1,23 +0,0 @@
-import { createSlice } from "@reduxjs/toolkit";
-import { Task, TaskState } from "#/services/taskService";
-
-export const taskSlice = createSlice({
-  name: "task",
-  initialState: {
-    task: {
-      id: "",
-      goal: "",
-      subtasks: [],
-      state: TaskState.OPEN_STATE,
-    } as Task,
-  },
-  reducers: {
-    setRootTask: (state, action) => {
-      state.task = action.payload as Task;
-    },
-  },
-});
-
-export const { setRootTask } = taskSlice.actions;
-
-export default taskSlice.reducer;

+ 0 - 2
frontend/src/store.ts

@@ -6,7 +6,6 @@ import codeReducer from "./state/codeSlice";
 import fileStateReducer from "./state/file-state-slice";
 import initialQueryReducer from "./state/initial-query-slice";
 import commandReducer from "./state/commandSlice";
-import taskReducer from "./state/taskSlice";
 import jupyterReducer from "./state/jupyterSlice";
 import securityAnalyzerReducer from "./state/securityAnalyzerSlice";
 import statusReducer from "./state/statusSlice";
@@ -18,7 +17,6 @@ export const rootReducer = combineReducers({
   chat: chatReducer,
   code: codeReducer,
   cmd: commandReducer,
-  task: taskReducer,
   agent: agentReducer,
   jupyter: jupyterReducer,
   securityAnalyzer: securityAnalyzerReducer,

+ 8 - 0
openhands/events/stream.py

@@ -21,6 +21,14 @@ class EventStreamSubscriber(str, Enum):
     TEST = 'test'
 
 
+def session_exists(sid: str, file_store: FileStore) -> bool:
+    try:
+        file_store.list(f'sessions/{sid}')
+        return True
+    except FileNotFoundError:
+        return False
+
+
 class EventStream:
     sid: str
     file_store: FileStore

+ 39 - 18
openhands/runtime/client/runtime.py

@@ -1,7 +1,6 @@
 import os
 import tempfile
 import threading
-import uuid
 from typing import Callable
 from zipfile import ZipFile
 
@@ -104,7 +103,6 @@ class LogBuffer:
 class EventStreamRuntime(Runtime):
     """This runtime will subscribe the event stream.
     When receive an event, it will send the event to runtime-client which run inside the docker environment.
-    From the sid also an instance_id is generated in combination with a UID.
 
     Args:
         config (AppConfig): The application configuration.
@@ -114,7 +112,7 @@ class EventStreamRuntime(Runtime):
         env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
     """
 
-    container_name_prefix = 'openhands-sandbox-'
+    container_name_prefix = 'openhands-runtime-'
 
     def __init__(
         self,
@@ -124,27 +122,24 @@ class EventStreamRuntime(Runtime):
         plugins: list[PluginRequirement] | None = None,
         env_vars: dict[str, str] | None = None,
         status_message_callback: Callable | None = None,
+        attach_to_existing: bool = False,
     ):
         self.config = config
         self._host_port = 30000  # initial dummy value
         self._container_port = 30001  # initial dummy value
         self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
         self.session = requests.Session()
-        self.instance_id = (
-            sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
-        )
         self.status_message_callback = status_message_callback
 
         self.send_status_message('STATUS$STARTING_RUNTIME')
         self.docker_client: docker.DockerClient = self._init_docker_client()
         self.base_container_image = self.config.sandbox.base_container_image
         self.runtime_container_image = self.config.sandbox.runtime_container_image
-        self.container_name = self.container_name_prefix + self.instance_id
+        self.container_name = self.container_name_prefix + sid
         self.container = None
         self.action_semaphore = threading.Semaphore(1)  # Ensure one action at a time
 
         self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
-        logger.debug(f'EventStreamRuntime `{self.instance_id}`')
 
         # Buffer for container logs
         self.log_buffer: LogBuffer | None = None
@@ -170,15 +165,25 @@ class EventStreamRuntime(Runtime):
                 extra_deps=self.config.sandbox.runtime_extra_deps,
                 force_rebuild=self.config.sandbox.force_rebuild_runtime,
             )
-        self.container = self._init_container(
-            sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox,  # e.g. /workspace
-            mount_dir=self.config.workspace_mount_path,  # e.g. /opt/openhands/_test_workspace
-            plugins=plugins,
-        )
+
+        if not attach_to_existing:
+            self._init_container(
+                sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox,  # e.g. /workspace
+                mount_dir=self.config.workspace_mount_path,  # e.g. /opt/openhands/_test_workspace
+                plugins=plugins,
+            )
+        else:
+            self._attach_to_container()
 
         # will initialize both the event stream and the env vars
         super().__init__(
-            config, event_stream, sid, plugins, env_vars, status_message_callback
+            config,
+            event_stream,
+            sid,
+            plugins,
+            env_vars,
+            status_message_callback,
+            attach_to_existing,
         )
 
         logger.info('Waiting for client to become ready...')
@@ -272,7 +277,7 @@ class EventStreamRuntime(Runtime):
             else:
                 browsergym_arg = ''
 
-            container = self.docker_client.containers.run(
+            self.container = self.docker_client.containers.run(
                 self.runtime_container_image,
                 command=(
                     f'/openhands/micromamba/bin/micromamba run -n openhands '
@@ -292,18 +297,34 @@ class EventStreamRuntime(Runtime):
                 environment=environment,
                 volumes=volumes,
             )
-            self.log_buffer = LogBuffer(container)
+            self.log_buffer = LogBuffer(self.container)
             logger.info(f'Container started. Server url: {self.api_url}')
             self.send_status_message('STATUS$CONTAINER_STARTED')
-            return container
         except Exception as e:
             logger.error(
-                f'Error: Instance {self.instance_id} FAILED to start container!\n'
+                f'Error: Instance {self.container_name} FAILED to start container!\n'
             )
             logger.exception(e)
             self.close(close_client=False)
             raise e
 
+    def _attach_to_container(self):
+        container = self.docker_client.containers.get(self.container_name)
+        self.log_buffer = LogBuffer(container)
+        self.container = container
+        self._container_port = 0
+        for port in container.attrs['NetworkSettings']['Ports']:
+            self._container_port = int(port.split('/')[0])
+            break
+        self._host_port = self._container_port
+        self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
+        logger.info(
+            'attached to container:',
+            self.container_name,
+            self._container_port,
+            self.api_url,
+        )
+
     def _refresh_logs(self):
         logger.debug('Getting container logs...')
 

+ 18 - 7
openhands/runtime/remote/runtime.py

@@ -51,6 +51,7 @@ class RemoteRuntime(Runtime):
         plugins: list[PluginRequirement] | None = None,
         env_vars: dict[str, str] | None = None,
         status_message_callback: Optional[Callable] = None,
+        attach_to_existing: bool = False,
     ):
         self.config = config
         self.status_message_callback = status_message_callback
@@ -75,21 +76,31 @@ class RemoteRuntime(Runtime):
         self.runtime_id: str | None = None
         self.runtime_url: str | None = None
 
-        self.instance_id = sid
+        self.sid = sid
 
-        self._start_or_attach_to_runtime(plugins)
+        self._start_or_attach_to_runtime(plugins, attach_to_existing)
 
         # Initialize the eventstream and env vars
         super().__init__(
-            config, event_stream, sid, plugins, env_vars, status_message_callback
+            config,
+            event_stream,
+            sid,
+            plugins,
+            env_vars,
+            status_message_callback,
+            attach_to_existing,
         )
         self._wait_until_alive()
         self.setup_initial_env()
 
-    def _start_or_attach_to_runtime(self, plugins: list[PluginRequirement] | None):
+    def _start_or_attach_to_runtime(
+        self, plugins: list[PluginRequirement] | None, attach_to_existing: bool = False
+    ):
         existing_runtime = self._check_existing_runtime()
         if existing_runtime:
             logger.info(f'Using existing runtime with ID: {self.runtime_id}')
+        elif attach_to_existing:
+            raise RuntimeError('Could not find existing runtime to attach to.')
         else:
             self.send_status_message('STATUS$STARTING_CONTAINER')
             if self.config.sandbox.runtime_container_image is None:
@@ -117,7 +128,7 @@ class RemoteRuntime(Runtime):
             response = send_request_with_retry(
                 self.session,
                 'GET',
-                f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.instance_id}',
+                f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}',
                 timeout=5,
             )
         except Exception as e:
@@ -146,7 +157,7 @@ class RemoteRuntime(Runtime):
             return False
 
     def _build_runtime(self):
-        logger.debug(f'RemoteRuntime `{self.instance_id}` config:\n{self.config}')
+        logger.debug(f'RemoteRuntime `{self.sid}` config:\n{self.config}')
         response = send_request_with_retry(
             self.session,
             'GET',
@@ -209,7 +220,7 @@ class RemoteRuntime(Runtime):
             ),
             'working_dir': '/openhands/code/',
             'environment': {'DEBUG': 'true'} if self.config.debug else {},
-            'runtime_id': self.instance_id,
+            'runtime_id': self.sid,
         }
 
         # Start the sandbox using the /start endpoint

+ 5 - 0
openhands/runtime/runtime.py

@@ -52,6 +52,7 @@ class Runtime:
     sid: str
     config: AppConfig
     initial_env_vars: dict[str, str]
+    attach_to_existing: bool
 
     def __init__(
         self,
@@ -61,12 +62,14 @@ class Runtime:
         plugins: list[PluginRequirement] | None = None,
         env_vars: dict[str, str] | None = None,
         status_message_callback: Callable | None = None,
+        attach_to_existing: bool = False,
     ):
         self.sid = sid
         self.event_stream = event_stream
         self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
         self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
         self.status_message_callback = status_message_callback
+        self.attach_to_existing = attach_to_existing
 
         self.config = copy.deepcopy(config)
         atexit.register(self.close)
@@ -76,6 +79,8 @@ class Runtime:
             self.initial_env_vars.update(env_vars)
 
     def setup_initial_env(self) -> None:
+        if self.attach_to_existing:
+            return
         logger.debug(f'Adding env vars: {self.initial_env_vars}')
         self.add_env_vars(self.initial_env_vars)
         if self.config.sandbox.runtime_startup_env_vars:

+ 14 - 60
openhands/server/listen.py

@@ -25,7 +25,6 @@ from fastapi import (
     FastAPI,
     HTTPException,
     Request,
-    Response,
     UploadFile,
     WebSocket,
     status,
@@ -40,7 +39,6 @@ import openhands.agenthub  # noqa F401 (we import this to get the agents registe
 from openhands.controller.agent import Agent
 from openhands.core.config import LLMConfig, load_app_config
 from openhands.core.logger import openhands_logger as logger
-from openhands.core.schema import AgentState  # Add this import
 from openhands.events.action import (
     ChangeAgentStateAction,
     FileReadAction,
@@ -213,8 +211,10 @@ async def attach_session(request: Request, call_next):
             content={'error': 'Invalid token'},
         )
 
-    request.state.session = session_manager.get_session(request.state.sid)
-    if request.state.session is None:
+    request.state.conversation = session_manager.attach_to_conversation(
+        request.state.sid
+    )
+    if request.state.conversation is None:
         return JSONResponse(
             status_code=status.HTTP_404_NOT_FOUND,
             content={'error': 'Session not found'},
@@ -434,12 +434,13 @@ async def list_files(request: Request, path: str | None = None):
     Raises:
         HTTPException: If there's an error listing the files.
     """
-    if not request.state.session.agent_session.runtime:
+    if not request.state.conversation.runtime:
         return JSONResponse(
             status_code=status.HTTP_404_NOT_FOUND,
             content={'error': 'Runtime not yet initialized'},
         )
-    runtime: Runtime = request.state.session.agent_session.runtime
+
+    runtime: Runtime = request.state.conversation.runtime
     file_list = await sync_from_async(runtime.list_files, path)
     if path:
         file_list = [os.path.join(path, f) for f in file_list]
@@ -485,7 +486,7 @@ async def select_file(file: str, request: Request):
     Raises:
         HTTPException: If there's an error opening the file.
     """
-    runtime: Runtime = request.state.session.agent_session.runtime
+    runtime: Runtime = request.state.conversation.runtime
 
     file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
     read_action = FileReadAction(file)
@@ -567,7 +568,7 @@ async def upload_file(request: Request, files: list[UploadFile]):
                     tmp_file.write(file_contents)
                     tmp_file.flush()
 
-                runtime: Runtime = request.state.session.agent_session.runtime
+                runtime: Runtime = request.state.conversation.runtime
                 runtime.copy_to(
                     tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
                 )
@@ -635,35 +636,6 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel):
         )
 
 
-@app.get('/api/root_task')
-def get_root_task(request: Request):
-    """Retrieve the root task of the current agent session.
-
-    To get the root_task:
-    ```sh
-    curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
-    ```
-
-    Args:
-        request (Request): The incoming request object.
-
-    Returns:
-        dict: The root task data if available.
-
-    Raises:
-        HTTPException: If the root task is not available.
-    """
-    controller = request.state.session.agent_session.controller
-    if controller is not None:
-        state = controller.get_state()
-        if state:
-            return JSONResponse(
-                status_code=status.HTTP_200_OK,
-                content=state.root_task.to_dict(),
-            )
-    return Response(status_code=status.HTTP_204_NO_CONTENT)
-
-
 @app.get('/api/defaults')
 async def appconfig_defaults():
     """Retrieve the default configuration settings.
@@ -700,22 +672,6 @@ async def save_file(request: Request):
             - 500 error if there's an unexpected error during the save operation.
     """
     try:
-        # Get the agent's current state
-        controller = request.state.session.agent_session.controller
-        agent_state = controller.get_agent_state()
-
-        # Check if the agent is in an allowed state for editing
-        if agent_state not in [
-            AgentState.INIT,
-            AgentState.PAUSED,
-            AgentState.FINISHED,
-            AgentState.AWAITING_USER_INPUT,
-        ]:
-            raise HTTPException(
-                status_code=403,
-                detail='Code editing is only allowed when the agent is paused, finished, or awaiting user input',
-            )
-
         # Extract file path and content from the request
         data = await request.json()
         file_path = data.get('filePath')
@@ -726,7 +682,7 @@ async def save_file(request: Request):
             raise HTTPException(status_code=400, detail='Missing filePath or content')
 
         # Save the file to the agent's runtime file store
-        runtime: Runtime = request.state.session.agent_session.runtime
+        runtime: Runtime = request.state.conversation.runtime
         file_path = os.path.join(
             runtime.config.workspace_mount_path_in_sandbox, file_path
         )
@@ -768,13 +724,11 @@ async def security_api(request: Request):
     Raises:
         HTTPException: If the security analyzer is not initialized.
     """
-    if not request.state.session.agent_session.security_analyzer:
+    if not request.state.conversation.security_analyzer:
         raise HTTPException(status_code=404, detail='Security analyzer not initialized')
 
-    return (
-        await request.state.session.agent_session.security_analyzer.handle_api_request(
-            request
-        )
+    return await request.state.conversation.security_analyzer.handle_api_request(
+        request
     )
 
 
@@ -782,7 +736,7 @@ async def security_api(request: Request):
 async def zip_current_workspace(request: Request):
     try:
         logger.info('Zipping workspace')
-        runtime: Runtime = request.state.session.agent_session.runtime
+        runtime: Runtime = request.state.conversation.runtime
 
         path = runtime.config.workspace_mount_path_in_sandbox
         zip_file_bytes = runtime.copy_from(path)

+ 36 - 0
openhands/server/session/conversation.py

@@ -0,0 +1,36 @@
+from openhands.core.config import AppConfig
+from openhands.events.stream import EventStream
+from openhands.runtime import get_runtime_cls
+from openhands.runtime.runtime import Runtime
+from openhands.security import SecurityAnalyzer, options
+from openhands.storage.files import FileStore
+
+
+class Conversation:
+    sid: str
+    file_store: FileStore
+    event_stream: EventStream
+    runtime: Runtime
+
+    def __init__(
+        self,
+        sid: str,
+        file_store: FileStore,
+        config: AppConfig,
+    ):
+        self.sid = sid
+        self.config = config
+        self.file_store = file_store
+        self.event_stream = EventStream(sid, file_store)
+        if config.security.security_analyzer:
+            self.security_analyzer = options.SecurityAnalyzers.get(
+                config.security.security_analyzer, SecurityAnalyzer
+            )(self.event_stream)
+
+        runtime_cls = get_runtime_cls(self.config.runtime)
+        self.runtime = runtime_cls(
+            config=config,
+            event_stream=self.event_stream,
+            sid=self.sid,
+            attach_to_existing=True,
+        )

+ 7 - 0
openhands/server/session/manager.py

@@ -6,7 +6,9 @@ from fastapi import WebSocket
 
 from openhands.core.config import AppConfig
 from openhands.core.logger import openhands_logger as logger
+from openhands.events.stream import session_exists
 from openhands.runtime.utils.shutdown_listener import should_continue
+from openhands.server.session.conversation import Conversation
 from openhands.server.session.session import Session
 from openhands.storage.files import FileStore
 
@@ -44,6 +46,11 @@ class SessionManager:
             return None
         return self._sessions.get(sid)
 
+    def attach_to_conversation(self, sid: str) -> Conversation | None:
+        if not session_exists(sid, self.file_store):
+            return None
+        return Conversation(sid, file_store=self.file_store, config=self.config)
+
     async def send(self, sid: str, data: dict[str, object]) -> bool:
         """Sends data to the client."""
         session = self.get_session(sid)