Просмотр исходного кода

Feature: Add ability to reconnect websockets (#4526)

tofarr 1 год назад
Родитель
Коммит
0e5e754420

+ 19 - 13
frontend/src/context/socket.tsx

@@ -2,9 +2,11 @@ import React from "react";
 import { Data } from "ws";
 import EventLogger from "#/utils/event-logger";
 
+const RECONNECT_RETRIES = 5;
+
 interface WebSocketClientOptions {
   token: string | null;
-  onOpen?: (event: Event) => void;
+  onOpen?: (event: Event, isNewSession: boolean) => void;
   onMessage?: (event: MessageEvent<Data>) => void;
   onError?: (event: Event) => void;
   onClose?: (event: Event) => void;
@@ -14,8 +16,8 @@ interface WebSocketContextType {
   send: (data: string | ArrayBufferLike | Blob | ArrayBufferView) => void;
   start: (options?: WebSocketClientOptions) => void;
   stop: () => void;
-  setRuntimeIsInitialized: () => void;
-  runtimeActive: boolean;
+  setRuntimeIsInitialized: (runtimeIsInitialized: boolean) => void;
+  runtimeIsInitialized: boolean;
   isConnected: boolean;
   events: Record<string, unknown>[];
 }
@@ -30,14 +32,11 @@ interface SocketProviderProps {
 
 function SocketProvider({ children }: SocketProviderProps) {
   const wsRef = React.useRef<WebSocket | null>(null);
+  const wsReconnectRetries = React.useRef<number>(RECONNECT_RETRIES);
   const [isConnected, setIsConnected] = React.useState(false);
-  const [runtimeActive, setRuntimeActive] = React.useState(false);
+  const [runtimeIsInitialized, setRuntimeIsInitialized] = React.useState(false);
   const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
 
-  const setRuntimeIsInitialized = () => {
-    setRuntimeActive(true);
-  };
-
   const start = React.useCallback((options?: WebSocketClientOptions): void => {
     if (wsRef.current) {
       EventLogger.warning(
@@ -59,7 +58,9 @@ function SocketProvider({ children }: SocketProviderProps) {
 
     ws.addEventListener("open", (event) => {
       setIsConnected(true);
-      options?.onOpen?.(event);
+      const isNewSession = sessionToken === "NO_JWT";
+      wsReconnectRetries.current = RECONNECT_RETRIES;
+      options?.onOpen?.(event, isNewSession);
     });
 
     ws.addEventListener("message", (event) => {
@@ -76,17 +77,22 @@ function SocketProvider({ children }: SocketProviderProps) {
 
     ws.addEventListener("close", (event) => {
       EventLogger.event(event, "SOCKET CLOSE");
-
       setIsConnected(false);
-      setRuntimeActive(false);
+      setRuntimeIsInitialized(false);
       wsRef.current = null;
       options?.onClose?.(event);
+      if (wsReconnectRetries.current) {
+        wsReconnectRetries.current -= 1;
+        const token = localStorage.getItem("token");
+        setTimeout(() => start({ ...(options || {}), token }), 1);
+      }
     });
 
     wsRef.current = ws;
   }, []);
 
   const stop = React.useCallback((): void => {
+    wsReconnectRetries.current = 0;
     if (wsRef.current) {
       wsRef.current.close();
       wsRef.current = null;
@@ -111,7 +117,7 @@ function SocketProvider({ children }: SocketProviderProps) {
       start,
       stop,
       setRuntimeIsInitialized,
-      runtimeActive,
+      runtimeIsInitialized,
       isConnected,
       events,
     }),
@@ -120,7 +126,7 @@ function SocketProvider({ children }: SocketProviderProps) {
       start,
       stop,
       setRuntimeIsInitialized,
-      runtimeActive,
+      runtimeIsInitialized,
       isConnected,
       events,
     ],

+ 30 - 18
frontend/src/routes/_oh.app.tsx

@@ -124,7 +124,8 @@ function App() {
   const { files, importedProjectZip } = useSelector(
     (state: RootState) => state.initalQuery,
   );
-  const { start, send, setRuntimeIsInitialized, runtimeActive } = useSocket();
+  const { start, send, setRuntimeIsInitialized, runtimeIsInitialized } =
+    useSocket();
   const { settings, token, ghToken, repo, q, lastCommit } =
     useLoaderData<typeof clientLoader>();
   const fetcher = useFetcher();
@@ -161,21 +162,32 @@ function App() {
     );
   };
 
+  const doSendInitialQuery = React.useRef<boolean>(true);
+
   const sendInitialQuery = (query: string, base64Files: string[]) => {
     const timestamp = new Date().toISOString();
     send(createChatMessage(query, base64Files, timestamp));
   };
 
-  const handleOpen = React.useCallback(() => {
-    const initEvent = {
-      action: ActionType.INIT,
-      args: settings,
-    };
-    send(JSON.stringify(initEvent));
-
-    // display query in UI, but don't send it to the server
-    if (q) addIntialQueryToChat(q, files);
-  }, [settings]);
+  const handleOpen = React.useCallback(
+    (event: Event, isNewSession: boolean) => {
+      if (!isNewSession) {
+        dispatch(clearMessages());
+        dispatch(clearTerminal());
+        dispatch(clearJupyter());
+      }
+      doSendInitialQuery.current = isNewSession;
+      const initEvent = {
+        action: ActionType.INIT,
+        args: settings,
+      };
+      send(JSON.stringify(initEvent));
+
+      // display query in UI, but don't send it to the server
+      if (q && isNewSession) addIntialQueryToChat(q, files);
+    },
+    [settings],
+  );
 
   const handleMessage = React.useCallback(
     (message: MessageEvent<WebSocket.Data>) => {
@@ -218,7 +230,7 @@ function App() {
         isAgentStateChange(parsed) &&
         parsed.extras.agent_state === AgentState.INIT
       ) {
-        setRuntimeIsInitialized();
+        setRuntimeIsInitialized(true);
 
         // handle new session
         if (!token) {
@@ -233,7 +245,7 @@ function App() {
             additionalInfo = `Files have been uploaded. Please check the /workspace for files.`;
           }
 
-          if (q) {
+          if (q && doSendInitialQuery.current) {
             if (additionalInfo) {
               sendInitialQuery(`${q}\n\n[${additionalInfo}]`, files);
             } else {
@@ -265,15 +277,15 @@ function App() {
   });
 
   React.useEffect(() => {
-    if (runtimeActive && userId && ghToken) {
+    if (runtimeIsInitialized && userId && ghToken) {
       // Export if the user is valid; this could happen mid-session, so it is handled here
       send(getGitHubTokenCommand(ghToken));
     }
-  }, [userId, ghToken, runtimeActive]);
+  }, [userId, ghToken, runtimeIsInitialized]);
 
   React.useEffect(() => {
     (async () => {
-      if (runtimeActive && importedProjectZip) {
+      if (runtimeIsInitialized && importedProjectZip) {
         // upload files action
         try {
           const blob = base64ToBlob(importedProjectZip);
@@ -287,7 +299,7 @@ function App() {
         }
       }
     })();
-  }, [runtimeActive, importedProjectZip]);
+  }, [runtimeIsInitialized, importedProjectZip]);
 
   const {
     isOpen: securityModalIsOpen,
@@ -303,7 +315,7 @@ function App() {
             className={cn(
               "w-2 h-2 rounded-full border",
               "absolute left-3 top-3",
-              runtimeActive
+              runtimeIsInitialized
                 ? "bg-green-800 border-green-500"
                 : "bg-red-800 border-red-500",
             )}

+ 7 - 6
openhands/controller/agent_controller.py

@@ -105,18 +105,19 @@ class AgentController:
         self.agent = agent
         self.headless_mode = headless_mode
 
-        # subscribe to the event stream
-        self.event_stream = event_stream
-        self.event_stream.subscribe(
-            EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
-        )
-
         # state from the previous session, state from a parent agent, or a fresh state
         self.set_initial_state(
             state=initial_state,
             max_iterations=max_iterations,
             confirmation_mode=confirmation_mode,
         )
+
+        # subscribe to the event stream
+        self.event_stream = event_stream
+        self.event_stream.subscribe(
+            EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
+        )
+
         self.max_budget_per_task = max_budget_per_task
         self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
         self.agent_configs = agent_configs if agent_configs else {}

+ 2 - 1
openhands/server/listen.py

@@ -329,11 +329,12 @@ async def websocket_endpoint(websocket: WebSocket):
             await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
             await websocket.close()
             return
+        logger.info(f'Existing session: {sid}')
     else:
         sid = str(uuid.uuid4())
         jwt_token = sign_token({'sid': sid}, config.jwt_secret)
+        logger.info(f'New session: {sid}')
 
-    logger.info(f'New session: {sid}')
     session = session_manager.add_or_restart_session(sid, websocket)
     await websocket.send_json({'token': jwt_token, 'status': 'ok'})
 

+ 3 - 3
openhands/server/session/agent_session.py

@@ -100,9 +100,9 @@ class AgentSession:
         config: AppConfig,
         agent: Agent,
         max_iterations: int,
-        max_budget_per_task: float | None = None,
-        agent_to_llm_config: dict[str, LLMConfig] | None = None,
-        agent_configs: dict[str, AgentConfig] | None = None,
+        max_budget_per_task: float | None,
+        agent_to_llm_config: dict[str, LLMConfig] | None,
+        agent_configs: dict[str, AgentConfig] | None,
     ):
         self._create_security_analyzer(config.security.security_analyzer)
         await self._create_runtime(

+ 4 - 2
openhands/server/session/session.py

@@ -30,7 +30,7 @@ class Session:
     sid: str
     websocket: WebSocket | None
     last_active_ts: int = 0
-    is_alive: bool = True
+    is_alive: bool = False
     agent_session: AgentSession
     loop: asyncio.AbstractEventLoop
 
@@ -109,6 +109,7 @@ class Session:
 
         # Create the agent session
         try:
+            self.is_alive = True
             await self.agent_session.start(
                 runtime_name=self.config.runtime,
                 config=self.config,
@@ -155,7 +156,8 @@ class Session:
     async def dispatch(self, data: dict):
         action = data.get('action', '')
         if action == ActionType.INIT:
-            await self._initialize_agent(data)
+            if not self.is_alive:
+                await self._initialize_agent(data)
             return
         event = event_from_dict(data.copy())
         # This checks if the model supports images