Просмотр исходного кода

Refactor websocket (#4879)

Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
tofarr 1 год назад
Родитель
Сommit
a1a9d2f175

+ 7 - 7
frontend/__tests__/components/chat/chat-interface.test.tsx

@@ -16,14 +16,14 @@ describe("Empty state", () => {
     send: vi.fn(),
   }));
 
-  const { useSocket: useSocketMock } = vi.hoisted(() => ({
-    useSocket: vi.fn(() => ({ send: sendMock, runtimeActive: true })),
+  const { useWsClient: useWsClientMock } = vi.hoisted(() => ({
+    useWsClient: vi.fn(() => ({ send: sendMock, runtimeActive: true })),
   }));
 
   beforeAll(() => {
     vi.mock("#/context/socket", async (importActual) => ({
-      ...(await importActual<typeof import("#/context/socket")>()),
-      useSocket: useSocketMock,
+      ...(await importActual<typeof import("#/context/ws-client-provider")>()),
+      useWsClient: useWsClientMock,
     }));
   });
 
@@ -77,7 +77,7 @@ describe("Empty state", () => {
     "should load the a user message to the input when selecting",
     async () => {
       // this is to test that the message is in the UI before the socket is called
-      useSocketMock.mockImplementation(() => ({
+      useWsClientMock.mockImplementation(() => ({
         send: sendMock,
         runtimeActive: false, // mock an inactive runtime setup
       }));
@@ -106,7 +106,7 @@ describe("Empty state", () => {
   it.fails(
     "should send the message to the socket only if the runtime is active",
     async () => {
-      useSocketMock.mockImplementation(() => ({
+      useWsClientMock.mockImplementation(() => ({
         send: sendMock,
         runtimeActive: false, // mock an inactive runtime setup
       }));
@@ -123,7 +123,7 @@ describe("Empty state", () => {
       await user.click(displayedSuggestions[0]);
       expect(sendMock).not.toHaveBeenCalled();
 
-      useSocketMock.mockImplementation(() => ({
+      useWsClientMock.mockImplementation(() => ({
         send: sendMock,
         runtimeActive: true, // mock an active runtime setup
       }));

+ 16 - 4
frontend/__tests__/hooks/use-terminal.test.tsx

@@ -2,8 +2,9 @@ import { beforeAll, describe, expect, it, vi } from "vitest";
 import { render } from "@testing-library/react";
 import { afterEach } from "node:test";
 import { useTerminal } from "#/hooks/useTerminal";
-import { SocketProvider } from "#/context/socket";
 import { Command } from "#/state/commandSlice";
+import { WsClientProvider } from "#/context/ws-client-provider";
+import { ReactNode } from "react";
 
 interface TestTerminalComponentProps {
   commands: Command[];
@@ -18,6 +19,17 @@ function TestTerminalComponent({
   return <div ref={ref} />;
 }
 
+interface WrapperProps {
+  children: ReactNode;
+}
+
+
+function Wrapper({children}: WrapperProps) {
+  return (
+    <WsClientProvider enabled={true} token="NO_JWT" ghToken="NO_GITHUB" settings={null}>{children}</WsClientProvider>
+  )
+}
+
 describe("useTerminal", () => {
   const mockTerminal = vi.hoisted(() => ({
     loadAddon: vi.fn(),
@@ -50,7 +62,7 @@ describe("useTerminal", () => {
 
   it("should render", () => {
     render(<TestTerminalComponent commands={[]} secrets={[]} />, {
-      wrapper: SocketProvider,
+      wrapper: Wrapper,
     });
   });
 
@@ -61,7 +73,7 @@ describe("useTerminal", () => {
     ];
 
     render(<TestTerminalComponent commands={commands} secrets={[]} />, {
-      wrapper: SocketProvider,
+      wrapper: Wrapper,
     });
 
     expect(mockTerminal.writeln).toHaveBeenNthCalledWith(1, "echo hello");
@@ -85,7 +97,7 @@ describe("useTerminal", () => {
         secrets={[secret, anotherSecret]}
       />,
       {
-        wrapper: SocketProvider,
+        wrapper: Wrapper,
       },
     );
 

+ 2 - 2
frontend/src/components/AgentControlBar.tsx

@@ -6,7 +6,7 @@ import PlayIcon from "#/assets/play";
 import { generateAgentStateChangeEvent } from "#/services/agentStateService";
 import { RootState } from "#/store";
 import AgentState from "#/types/AgentState";
-import { useSocket } from "#/context/socket";
+import { useWsClient } from "#/context/ws-client-provider";
 
 const IgnoreTaskStateMap: Record<string, AgentState[]> = {
   [AgentState.PAUSED]: [
@@ -72,7 +72,7 @@ function ActionButton({
 }
 
 function AgentControlBar() {
-  const { send } = useSocket();
+  const { send } = useWsClient();
   const { curAgentState } = useSelector((state: RootState) => state.agent);
 
   const handleAction = (action: AgentState) => {

+ 2 - 2
frontend/src/components/chat-interface.tsx

@@ -1,7 +1,6 @@
 import { useDispatch, useSelector } from "react-redux";
 import React from "react";
 import posthog from "posthog-js";
-import { useSocket } from "#/context/socket";
 import { convertImageToBase64 } from "#/utils/convert-image-to-base-64";
 import { ChatMessage } from "./chat-message";
 import { FeedbackActions } from "./feedback-actions";
@@ -22,13 +21,14 @@ import { ScrollToBottomButton } from "./scroll-to-bottom-button";
 import { Suggestions } from "./suggestions";
 import { SUGGESTIONS } from "#/utils/suggestions";
 import BuildIt from "#/icons/build-it.svg?react";
+import { useWsClient } from "#/context/ws-client-provider";
 
 const isErrorMessage = (
   message: Message | ErrorMessage,
 ): message is ErrorMessage => "error" in message;
 
 export function ChatInterface() {
-  const { send } = useSocket();
+  const { send } = useWsClient();
   const dispatch = useDispatch();
   const scrollRef = React.useRef<HTMLDivElement>(null);
   const { scrollDomToBottom, onChatBodyScroll, hitBottom } =

+ 2 - 2
frontend/src/components/chat/ConfirmationButtons.tsx

@@ -5,7 +5,7 @@ import RejectIcon from "#/assets/reject";
 import { I18nKey } from "#/i18n/declaration";
 import AgentState from "#/types/AgentState";
 import { generateAgentStateChangeEvent } from "#/services/agentStateService";
-import { useSocket } from "#/context/socket";
+import { useWsClient } from "#/context/ws-client-provider";
 
 interface ActionTooltipProps {
   type: "confirm" | "reject";
@@ -37,7 +37,7 @@ function ActionTooltip({ type, onClick }: ActionTooltipProps) {
 
 function ConfirmationButtons() {
   const { t } = useTranslation();
-  const { send } = useSocket();
+  const { send } = useWsClient();
 
   const handleStateChange = (state: AgentState) => {
     const event = generateAgentStateChangeEvent(state);

+ 188 - 0
frontend/src/components/event-handler.tsx

@@ -0,0 +1,188 @@
+import React from "react";
+import {
+  useFetcher,
+  useLoaderData,
+  useRouteLoaderData,
+} from "@remix-run/react";
+import { useDispatch, useSelector } from "react-redux";
+import toast from "react-hot-toast";
+
+import posthog from "posthog-js";
+import {
+  useWsClient,
+  WsClientProviderStatus,
+} from "#/context/ws-client-provider";
+import { ErrorObservation } from "#/types/core/observations";
+import { addErrorMessage, addUserMessage } from "#/state/chatSlice";
+import { handleAssistantMessage } from "#/services/actions";
+import {
+  getCloneRepoCommand,
+  getGitHubTokenCommand,
+} from "#/services/terminalService";
+import {
+  clearFiles,
+  clearSelectedRepository,
+  setImportedProjectZip,
+} from "#/state/initial-query-slice";
+import { clientLoader as appClientLoader } from "#/routes/_oh.app";
+import store, { RootState } from "#/store";
+import { createChatMessage } from "#/services/chatService";
+import { clientLoader as rootClientLoader } from "#/routes/_oh";
+import { isGitHubErrorReponse } from "#/api/github";
+import OpenHands from "#/api/open-hands";
+import { base64ToBlob } from "#/utils/base64-to-blob";
+import { setCurrentAgentState } from "#/state/agentSlice";
+import AgentState from "#/types/AgentState";
+import { getSettings } from "#/services/settings";
+
+interface ServerError {
+  error: boolean | string;
+  message: string;
+  [key: string]: unknown;
+}
+
+const isServerError = (data: object): data is ServerError => "error" in data;
+
+const isErrorObservation = (data: object): data is ErrorObservation =>
+  "observation" in data && data.observation === "error";
+
+export function EventHandler({ children }: React.PropsWithChildren) {
+  const { events, status, send } = useWsClient();
+  const statusRef = React.useRef<WsClientProviderStatus | null>(null);
+  const runtimeActive = status === WsClientProviderStatus.ACTIVE;
+  const fetcher = useFetcher();
+  const dispatch = useDispatch();
+  const { files, importedProjectZip } = useSelector(
+    (state: RootState) => state.initalQuery,
+  );
+  const { ghToken, repo } = useLoaderData<typeof appClientLoader>();
+  const initialQueryRef = React.useRef<string | null>(
+    store.getState().initalQuery.initialQuery,
+  );
+
+  const sendInitialQuery = (query: string, base64Files: string[]) => {
+    const timestamp = new Date().toISOString();
+    send(createChatMessage(query, base64Files, timestamp));
+  };
+  const data = useRouteLoaderData<typeof rootClientLoader>("routes/_oh");
+  const userId = React.useMemo(() => {
+    if (data?.user && !isGitHubErrorReponse(data.user)) return data.user.id;
+    return null;
+  }, [data?.user]);
+  const userSettings = getSettings();
+
+  React.useEffect(() => {
+    if (!events.length) {
+      return;
+    }
+    const event = events[events.length - 1];
+    if (event.token) {
+      fetcher.submit({ token: event.token as string }, { method: "post" });
+      return;
+    }
+
+    if (isServerError(event)) {
+      if (event.error_code === 401) {
+        toast.error("Session expired.");
+        fetcher.submit({}, { method: "POST", action: "/end-session" });
+        return;
+      }
+
+      if (typeof event.error === "string") {
+        toast.error(event.error);
+      } else {
+        toast.error(event.message);
+      }
+      return;
+    }
+
+    if (isErrorObservation(event)) {
+      dispatch(
+        addErrorMessage({
+          id: event.extras?.error_id,
+          message: event.message,
+        }),
+      );
+      return;
+    }
+    handleAssistantMessage(event);
+  }, [events.length]);
+
+  React.useEffect(() => {
+    if (statusRef.current === status) {
+      return; // This is a check because of strict mode - if the status did not change, don't do anything
+    }
+    statusRef.current = status;
+    const initialQuery = initialQueryRef.current;
+
+    if (status === WsClientProviderStatus.ACTIVE) {
+      let additionalInfo = "";
+      if (ghToken && repo) {
+        send(getCloneRepoCommand(ghToken, repo));
+        additionalInfo = `Repository ${repo} has been cloned to /workspace. Please check the /workspace for files.`;
+        dispatch(clearSelectedRepository()); // reset selected repository; maybe better to move this to '/'?
+      }
+      // if there's an uploaded project zip, add it to the chat
+      else if (importedProjectZip) {
+        additionalInfo = `Files have been uploaded. Please check the /workspace for files.`;
+      }
+
+      if (initialQuery) {
+        if (additionalInfo) {
+          sendInitialQuery(`${initialQuery}\n\n[${additionalInfo}]`, files);
+        } else {
+          sendInitialQuery(initialQuery, files);
+        }
+        dispatch(clearFiles()); // reset selected files
+        initialQueryRef.current = null;
+      }
+    }
+
+    if (status === WsClientProviderStatus.OPENING && initialQuery) {
+      dispatch(
+        addUserMessage({
+          content: initialQuery,
+          imageUrls: files,
+          timestamp: new Date().toISOString(),
+        }),
+      );
+    }
+
+    if (status === WsClientProviderStatus.STOPPED) {
+      store.dispatch(setCurrentAgentState(AgentState.STOPPED));
+    }
+  }, [status]);
+
+  React.useEffect(() => {
+    if (runtimeActive && userId && ghToken) {
+      // Export if the user valid, this could happen mid-session so it is handled here
+      send(getGitHubTokenCommand(ghToken));
+    }
+  }, [userId, ghToken, runtimeActive]);
+
+  React.useEffect(() => {
+    (async () => {
+      if (runtimeActive && importedProjectZip) {
+        // upload files action
+        try {
+          const blob = base64ToBlob(importedProjectZip);
+          const file = new File([blob], "imported-project.zip", {
+            type: blob.type,
+          });
+          await OpenHands.uploadFiles([file]);
+          dispatch(setImportedProjectZip(null));
+        } catch (error) {
+          toast.error("Failed to upload project files.");
+        }
+      }
+    })();
+  }, [runtimeActive, importedProjectZip]);
+
+  React.useEffect(() => {
+    if (userSettings.LLM_API_KEY) {
+      posthog.capture("user_activated");
+    }
+  }, [userSettings.LLM_API_KEY]);
+
+  return children;
+}

+ 2 - 2
frontend/src/components/project-menu/ProjectMenuCard.tsx

@@ -6,13 +6,13 @@ import EllipsisH from "#/icons/ellipsis-h.svg?react";
 import { ModalBackdrop } from "../modals/modal-backdrop";
 import { ConnectToGitHubModal } from "../modals/connect-to-github-modal";
 import { addUserMessage } from "#/state/chatSlice";
-import { useSocket } from "#/context/socket";
 import { createChatMessage } from "#/services/chatService";
 import { ProjectMenuCardContextMenu } from "./project.menu-card-context-menu";
 import { ProjectMenuDetailsPlaceholder } from "./project-menu-details-placeholder";
 import { ProjectMenuDetails } from "./project-menu-details";
 import { downloadWorkspace } from "#/utils/download-workspace";
 import { LoadingSpinner } from "../modals/LoadingProject";
+import { useWsClient } from "#/context/ws-client-provider";
 
 interface ProjectMenuCardProps {
   isConnectedToGitHub: boolean;
@@ -27,7 +27,7 @@ export function ProjectMenuCard({
   isConnectedToGitHub,
   githubData,
 }: ProjectMenuCardProps) {
-  const { send } = useSocket();
+  const { send } = useWsClient();
   const dispatch = useDispatch();
 
   const [contextMenuIsOpen, setContextMenuIsOpen] = React.useState(false);

+ 0 - 146
frontend/src/context/socket.tsx

@@ -1,146 +0,0 @@
-import React from "react";
-import { Data } from "ws";
-import posthog from "posthog-js";
-import EventLogger from "#/utils/event-logger";
-
-interface WebSocketClientOptions {
-  token: string | null;
-  onOpen?: (event: Event) => void;
-  onMessage?: (event: MessageEvent<Data>) => void;
-  onError?: (event: Event) => void;
-  onClose?: (event: Event) => void;
-}
-
-interface WebSocketContextType {
-  send: (data: string | ArrayBufferLike | Blob | ArrayBufferView) => void;
-  start: (options?: WebSocketClientOptions) => void;
-  stop: () => void;
-  setRuntimeIsInitialized: () => void;
-  runtimeActive: boolean;
-  isConnected: boolean;
-  events: Record<string, unknown>[];
-}
-
-const SocketContext = React.createContext<WebSocketContextType | undefined>(
-  undefined,
-);
-
-interface SocketProviderProps {
-  children: React.ReactNode;
-}
-
-function SocketProvider({ children }: SocketProviderProps) {
-  const wsRef = React.useRef<WebSocket | null>(null);
-  const [isConnected, setIsConnected] = React.useState(false);
-  const [runtimeActive, setRuntimeActive] = React.useState(false);
-  const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
-
-  const setRuntimeIsInitialized = () => {
-    setRuntimeActive(true);
-  };
-
-  const start = React.useCallback((options?: WebSocketClientOptions): void => {
-    if (wsRef.current) {
-      EventLogger.warning(
-        "WebSocket connection is already established, but a new one is starting anyways.",
-      );
-    }
-
-    const baseUrl =
-      import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
-    const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
-    const sessionToken = options?.token || "NO_JWT"; // not allowed to be empty or duplicated
-    const ghToken = localStorage.getItem("ghToken") || "NO_GITHUB";
-
-    const ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [
-      "openhands",
-      sessionToken,
-      ghToken,
-    ]);
-
-    ws.addEventListener("open", (event) => {
-      posthog.capture("socket_opened");
-      setIsConnected(true);
-      options?.onOpen?.(event);
-    });
-
-    ws.addEventListener("message", (event) => {
-      EventLogger.message(event);
-
-      setEvents((prevEvents) => [...prevEvents, JSON.parse(event.data)]);
-      options?.onMessage?.(event);
-    });
-
-    ws.addEventListener("error", (event) => {
-      posthog.capture("socket_error");
-      EventLogger.event(event, "SOCKET ERROR");
-      options?.onError?.(event);
-    });
-
-    ws.addEventListener("close", (event) => {
-      posthog.capture("socket_closed");
-      EventLogger.event(event, "SOCKET CLOSE");
-
-      setIsConnected(false);
-      setRuntimeActive(false);
-      wsRef.current = null;
-      options?.onClose?.(event);
-    });
-
-    wsRef.current = ws;
-  }, []);
-
-  const stop = React.useCallback((): void => {
-    if (wsRef.current) {
-      wsRef.current.close();
-      wsRef.current = null;
-    }
-  }, []);
-
-  const send = React.useCallback(
-    (data: string | ArrayBufferLike | Blob | ArrayBufferView) => {
-      if (!wsRef.current) {
-        EventLogger.error("WebSocket is not connected.");
-        return;
-      }
-      setEvents((prevEvents) => [...prevEvents, JSON.parse(data.toString())]);
-      wsRef.current.send(data);
-    },
-    [],
-  );
-
-  const value = React.useMemo(
-    () => ({
-      send,
-      start,
-      stop,
-      setRuntimeIsInitialized,
-      runtimeActive,
-      isConnected,
-      events,
-    }),
-    [
-      send,
-      start,
-      stop,
-      setRuntimeIsInitialized,
-      runtimeActive,
-      isConnected,
-      events,
-    ],
-  );
-
-  return (
-    <SocketContext.Provider value={value}>{children}</SocketContext.Provider>
-  );
-}
-
-function useSocket() {
-  const context = React.useContext(SocketContext);
-  if (context === undefined) {
-    throw new Error("useSocket must be used within a SocketProvider");
-  }
-  return context;
-}
-
-export { SocketProvider, useSocket };

+ 175 - 0
frontend/src/context/ws-client-provider.tsx

@@ -0,0 +1,175 @@
+import posthog from "posthog-js";
+import React from "react";
+import { Settings } from "#/services/settings";
+import ActionType from "#/types/ActionType";
+import EventLogger from "#/utils/event-logger";
+import AgentState from "#/types/AgentState";
+
+export enum WsClientProviderStatus {
+  STOPPED,
+  OPENING,
+  ACTIVE,
+  ERROR,
+}
+
+interface UseWsClient {
+  status: WsClientProviderStatus;
+  events: Record<string, unknown>[];
+  send: (event: Record<string, unknown>) => void;
+}
+
+const WsClientContext = React.createContext<UseWsClient>({
+  status: WsClientProviderStatus.STOPPED,
+  events: [],
+  send: () => {
+    throw new Error("not connected");
+  },
+});
+
+interface WsClientProviderProps {
+  enabled: boolean;
+  token: string | null;
+  ghToken: string | null;
+  settings: Settings | null;
+}
+
+export function WsClientProvider({
+  enabled,
+  token,
+  ghToken,
+  settings,
+  children,
+}: React.PropsWithChildren<WsClientProviderProps>) {
+  const wsRef = React.useRef<WebSocket | null>(null);
+  const tokenRef = React.useRef<string | null>(token);
+  const ghTokenRef = React.useRef<string | null>(ghToken);
+  const closeRef = React.useRef<ReturnType<typeof setTimeout> | null>(null);
+  const [status, setStatus] = React.useState(WsClientProviderStatus.STOPPED);
+  const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
+
+  function send(event: Record<string, unknown>) {
+    if (!wsRef.current) {
+      EventLogger.error("WebSocket is not connected.");
+      return;
+    }
+    wsRef.current.send(JSON.stringify(event));
+  }
+
+  function handleOpen() {
+    setStatus(WsClientProviderStatus.OPENING);
+    const initEvent = {
+      action: ActionType.INIT,
+      args: settings,
+    };
+    send(initEvent);
+  }
+
+  function handleMessage(messageEvent: MessageEvent) {
+    const event = JSON.parse(messageEvent.data);
+    setEvents((prevEvents) => [...prevEvents, event]);
+    if (event.extras?.agent_state === AgentState.INIT) {
+      setStatus(WsClientProviderStatus.ACTIVE);
+    }
+    if (
+      status !== WsClientProviderStatus.ACTIVE &&
+      event?.observation === "error"
+    ) {
+      setStatus(WsClientProviderStatus.ERROR);
+    }
+  }
+
+  function handleClose() {
+    setStatus(WsClientProviderStatus.STOPPED);
+    setEvents([]);
+    wsRef.current = null;
+  }
+
+  function handleError(event: Event) {
+    posthog.capture("socket_error");
+    EventLogger.event(event, "SOCKET ERROR");
+    setStatus(WsClientProviderStatus.ERROR);
+  }
+
+  // Connect websocket
+  React.useEffect(() => {
+    let ws = wsRef.current;
+
+    // If disabled close any existing websockets...
+    if (!enabled) {
+      if (ws) {
+        ws.close();
+      }
+      wsRef.current = null;
+      return () => {};
+    }
+
+    // If there is no websocket or the tokens have changed or the current websocket is closed,
+    // create a new one
+    if (
+      !ws ||
+      (tokenRef.current && token !== tokenRef.current) ||
+      ghToken !== ghTokenRef.current ||
+      ws.readyState === WebSocket.CLOSED ||
+      ws.readyState === WebSocket.CLOSING
+    ) {
+      ws?.close();
+      const baseUrl =
+        import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
+      const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
+      ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [
+        "openhands",
+        token || "NO_JWT",
+        ghToken || "NO_GITHUB",
+      ]);
+    }
+    ws.addEventListener("open", handleOpen);
+    ws.addEventListener("message", handleMessage);
+    ws.addEventListener("error", handleError);
+    ws.addEventListener("close", handleClose);
+    wsRef.current = ws;
+    tokenRef.current = token;
+    ghTokenRef.current = ghToken;
+
+    return () => {
+      ws.removeEventListener("open", handleOpen);
+      ws.removeEventListener("message", handleMessage);
+      ws.removeEventListener("error", handleError);
+      ws.removeEventListener("close", handleClose);
+    };
+  }, [enabled, token, ghToken]);
+
+  // Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
+  // before actually closing the socket and cancel the operation if the component gets remounted.
+  React.useEffect(() => {
+    const timeout = closeRef.current;
+    if (timeout != null) {
+      clearTimeout(timeout);
+    }
+
+    return () => {
+      closeRef.current = setTimeout(() => {
+        wsRef.current?.close();
+      }, 100);
+    };
+  }, []);
+
+  const value = React.useMemo<UseWsClient>(
+    () => ({
+      status,
+      events,
+      send,
+    }),
+    [status, events],
+  );
+
+  return (
+    <WsClientContext.Provider value={value}>
+      {children}
+    </WsClientContext.Provider>
+  );
+}
+
+export function useWsClient() {
+  const context = React.useContext(WsClientContext);
+  return context;
+}

+ 4 - 7
frontend/src/entry.client.tsx

@@ -10,7 +10,6 @@ import React, { startTransition, StrictMode } from "react";
 import { hydrateRoot } from "react-dom/client";
 import { Provider } from "react-redux";
 import posthog from "posthog-js";
-import { SocketProvider } from "./context/socket";
 import "./i18n";
 import store from "./store";
 
@@ -43,12 +42,10 @@ prepareApp().then(() =>
     hydrateRoot(
       document,
       <StrictMode>
-        <SocketProvider>
-          <Provider store={store}>
-            <RemixBrowser />
-            <PosthogInit />
-          </Provider>
-        </SocketProvider>
+        <Provider store={store}>
+          <RemixBrowser />
+          <PosthogInit />
+        </Provider>
       </StrictMode>,
     );
   }),

+ 2 - 2
frontend/src/hooks/useTerminal.ts

@@ -4,7 +4,7 @@ import React from "react";
 import { Command } from "#/state/commandSlice";
 import { getTerminalCommand } from "#/services/terminalService";
 import { parseTerminalOutput } from "#/utils/parseTerminalOutput";
-import { useSocket } from "#/context/socket";
+import { useWsClient } from "#/context/ws-client-provider";
 
 /*
   NOTE: Tests for this hook are indirectly covered by the tests for the XTermTerminal component.
@@ -15,7 +15,7 @@ export const useTerminal = (
   commands: Command[] = [],
   secrets: string[] = [],
 ) => {
-  const { send } = useSocket();
+  const { send } = useWsClient();
   const terminal = React.useRef<Terminal | null>(null);
   const fitAddon = React.useRef<FitAddon | null>(null);
   const ref = React.useRef<HTMLDivElement>(null);

+ 62 - 249
frontend/src/routes/_oh.app.tsx

@@ -2,72 +2,29 @@ import { useDisclosure } from "@nextui-org/react";
 import React from "react";
 import {
   Outlet,
-  useFetcher,
   useLoaderData,
   json,
   ClientActionFunctionArgs,
-  useRouteLoaderData,
 } from "@remix-run/react";
-import { useDispatch, useSelector } from "react-redux";
-import WebSocket from "ws";
-import toast from "react-hot-toast";
-import posthog from "posthog-js";
+import { useDispatch } from "react-redux";
 import { getSettings } from "#/services/settings";
 import Security from "../components/modals/security/Security";
 import { Controls } from "#/components/controls";
-import store, { RootState } from "#/store";
+import store from "#/store";
 import { Container } from "#/components/container";
-import ActionType from "#/types/ActionType";
-import { handleAssistantMessage } from "#/services/actions";
-import {
-  addErrorMessage,
-  addUserMessage,
-  clearMessages,
-} from "#/state/chatSlice";
-import { useSocket } from "#/context/socket";
-import {
-  getGitHubTokenCommand,
-  getCloneRepoCommand,
-} from "#/services/terminalService";
+import { clearMessages } from "#/state/chatSlice";
 import { clearTerminal } from "#/state/commandSlice";
 import { useEffectOnce } from "#/utils/use-effect-once";
 import CodeIcon from "#/icons/code.svg?react";
 import GlobeIcon from "#/icons/globe.svg?react";
 import ListIcon from "#/icons/list-type-number.svg?react";
-import { createChatMessage } from "#/services/chatService";
-import {
-  clearFiles,
-  clearInitialQuery,
-  clearSelectedRepository,
-  setImportedProjectZip,
-} from "#/state/initial-query-slice";
+import { clearInitialQuery } from "#/state/initial-query-slice";
 import { isGitHubErrorReponse, retrieveLatestGitHubCommit } from "#/api/github";
-import OpenHands from "#/api/open-hands";
-import AgentState from "#/types/AgentState";
-import { base64ToBlob } from "#/utils/base64-to-blob";
-import { clientLoader as rootClientLoader } from "#/routes/_oh";
 import { clearJupyter } from "#/state/jupyterSlice";
 import { FilesProvider } from "#/context/files";
-import { ErrorObservation } from "#/types/core/observations";
 import { ChatInterface } from "#/components/chat-interface";
-
-interface ServerError {
-  error: boolean | string;
-  message: string;
-  [key: string]: unknown;
-}
-
-const isServerError = (data: object): data is ServerError => "error" in data;
-
-const isErrorObservation = (data: object): data is ErrorObservation =>
-  "observation" in data && data.observation === "error";
-
-const isAgentStateChange = (
-  data: object,
-): data is { extras: { agent_state: AgentState } } =>
-  "extras" in data &&
-  data.extras instanceof Object &&
-  "agent_state" in data.extras;
+import { WsClientProvider } from "#/context/ws-client-provider";
+import { EventHandler } from "#/components/event-handler";
 
 export const clientLoader = async () => {
   const ghToken = localStorage.getItem("ghToken");
@@ -117,179 +74,26 @@ export const clientAction = async ({ request }: ClientActionFunctionArgs) => {
 
 function App() {
   const dispatch = useDispatch();
-  const { files, importedProjectZip } = useSelector(
-    (state: RootState) => state.initalQuery,
-  );
-  const { start, send, setRuntimeIsInitialized, runtimeActive } = useSocket();
-  const { settings, token, ghToken, repo, q, lastCommit } =
+  const { settings, token, ghToken, lastCommit } =
     useLoaderData<typeof clientLoader>();
-  const fetcher = useFetcher();
-  const data = useRouteLoaderData<typeof rootClientLoader>("routes/_oh");
 
   const secrets = React.useMemo(
     () => [ghToken, token].filter((secret) => secret !== null),
     [ghToken, token],
   );
 
-  // To avoid re-rendering the component when the user object changes, we memoize the user ID.
-  // We use this to ensure the github token is valid before exporting it to the terminal.
-  const userId = React.useMemo(() => {
-    if (data?.user && !isGitHubErrorReponse(data.user)) return data.user.id;
-    return null;
-  }, [data?.user]);
-
   const Terminal = React.useMemo(
     () => React.lazy(() => import("../components/terminal/Terminal")),
     [],
   );
 
-  const addIntialQueryToChat = (
-    query: string,
-    base64Files: string[],
-    timestamp = new Date().toISOString(),
-  ) => {
-    dispatch(
-      addUserMessage({
-        content: query,
-        imageUrls: base64Files,
-        timestamp,
-      }),
-    );
-  };
-
-  const sendInitialQuery = (query: string, base64Files: string[]) => {
-    const timestamp = new Date().toISOString();
-    send(createChatMessage(query, base64Files, timestamp));
-
-    const userSettings = getSettings();
-    if (userSettings.LLM_API_KEY) {
-      posthog.capture("user_activated");
-    }
-  };
-
-  const handleOpen = React.useCallback(() => {
-    const initEvent = {
-      action: ActionType.INIT,
-      args: settings,
-    };
-    send(JSON.stringify(initEvent));
-
-    // display query in UI, but don't send it to the server
-    if (q) addIntialQueryToChat(q, files);
-  }, [settings]);
-
-  const handleMessage = React.useCallback(
-    (message: MessageEvent<WebSocket.Data>) => {
-      // set token received from the server
-      const parsed = JSON.parse(message.data.toString());
-      if ("token" in parsed) {
-        fetcher.submit({ token: parsed.token }, { method: "post" });
-        return;
-      }
-
-      if (isServerError(parsed)) {
-        if (parsed.error_code === 401) {
-          toast.error("Session expired.");
-          fetcher.submit({}, { method: "POST", action: "/end-session" });
-          return;
-        }
-
-        if (typeof parsed.error === "string") {
-          toast.error(parsed.error);
-        } else {
-          toast.error(parsed.message);
-        }
-
-        return;
-      }
-      if (isErrorObservation(parsed)) {
-        dispatch(
-          addErrorMessage({
-            id: parsed.extras?.error_id,
-            message: parsed.message,
-          }),
-        );
-        return;
-      }
-
-      handleAssistantMessage(message.data.toString());
-
-      // handle first time connection
-      if (
-        isAgentStateChange(parsed) &&
-        parsed.extras.agent_state === AgentState.INIT
-      ) {
-        setRuntimeIsInitialized();
-
-        // handle new session
-        if (!token) {
-          let additionalInfo = "";
-          if (ghToken && repo) {
-            send(getCloneRepoCommand(ghToken, repo));
-            additionalInfo = `Repository ${repo} has been cloned to /workspace. Please check the /workspace for files.`;
-            dispatch(clearSelectedRepository()); // reset selected repository; maybe better to move this to '/'?
-          }
-          // if there's an uploaded project zip, add it to the chat
-          else if (importedProjectZip) {
-            additionalInfo = `Files have been uploaded. Please check the /workspace for files.`;
-          }
-
-          if (q) {
-            if (additionalInfo) {
-              sendInitialQuery(`${q}\n\n[${additionalInfo}]`, files);
-            } else {
-              sendInitialQuery(q, files);
-            }
-            dispatch(clearFiles()); // reset selected files
-          }
-        }
-      }
-    },
-    [token, ghToken, repo, q, files],
-  );
-
-  const startSocketConnection = React.useCallback(() => {
-    start({
-      token,
-      onOpen: handleOpen,
-      onMessage: handleMessage,
-    });
-  }, [token, handleOpen, handleMessage]);
-
   useEffectOnce(() => {
-    // clear and restart the socket connection
     dispatch(clearMessages());
     dispatch(clearTerminal());
     dispatch(clearJupyter());
     dispatch(clearInitialQuery()); // Clear initial query when navigating to /app
-    startSocketConnection();
   });
 
-  React.useEffect(() => {
-    if (runtimeActive && userId && ghToken) {
-      // Export if the user valid, this could happen mid-session so it is handled here
-      send(getGitHubTokenCommand(ghToken));
-    }
-  }, [userId, ghToken, runtimeActive]);
-
-  React.useEffect(() => {
-    (async () => {
-      if (runtimeActive && importedProjectZip) {
-        // upload files action
-        try {
-          const blob = base64ToBlob(importedProjectZip);
-          const file = new File([blob], "imported-project.zip", {
-            type: blob.type,
-          });
-          await OpenHands.uploadFiles([file]);
-          dispatch(setImportedProjectZip(null));
-        } catch (error) {
-          toast.error("Failed to upload project files.");
-        }
-      }
-    })();
-  }, [runtimeActive, importedProjectZip]);
-
   const {
     isOpen: securityModalIsOpen,
     onOpen: onSecurityModalOpen,
@@ -297,53 +101,62 @@ function App() {
   } = useDisclosure();
 
   return (
-    <div className="flex flex-col h-full gap-3">
-      <div className="flex h-full overflow-auto gap-3">
-        <Container className="w-[390px] max-h-full relative">
-          <ChatInterface />
-        </Container>
-
-        <div className="flex flex-col grow gap-3">
-          <Container
-            className="h-2/3"
-            labels={[
-              { label: "Workspace", to: "", icon: <CodeIcon /> },
-              { label: "Jupyter", to: "jupyter", icon: <ListIcon /> },
-              {
-                label: "Browser",
-                to: "browser",
-                icon: <GlobeIcon />,
-                isBeta: true,
-              },
-            ]}
-          >
-            <FilesProvider>
-              <Outlet />
-            </FilesProvider>
-          </Container>
-          {/* Terminal uses some API that is not compatible in a server-environment. For this reason, we lazy load it to ensure
-           * that it loads only in the client-side. */}
-          <Container className="h-1/3 overflow-scroll" label="Terminal">
-            <React.Suspense fallback={<div className="h-full" />}>
-              <Terminal secrets={secrets} />
-            </React.Suspense>
-          </Container>
+    <WsClientProvider
+      enabled
+      token={token}
+      ghToken={ghToken}
+      settings={settings}
+    >
+      <EventHandler>
+        <div className="flex flex-col h-full gap-3">
+          <div className="flex h-full overflow-auto gap-3">
+            <Container className="w-[390px] max-h-full relative">
+              <ChatInterface />
+            </Container>
+
+            <div className="flex flex-col grow gap-3">
+              <Container
+                className="h-2/3"
+                labels={[
+                  { label: "Workspace", to: "", icon: <CodeIcon /> },
+                  { label: "Jupyter", to: "jupyter", icon: <ListIcon /> },
+                  {
+                    label: "Browser",
+                    to: "browser",
+                    icon: <GlobeIcon />,
+                    isBeta: true,
+                  },
+                ]}
+              >
+                <FilesProvider>
+                  <Outlet />
+                </FilesProvider>
+              </Container>
+              {/* Terminal uses some API that is not compatible in a server-environment. For this reason, we lazy load it to ensure
+               * that it loads only in the client-side. */}
+              <Container className="h-1/3 overflow-scroll" label="Terminal">
+                <React.Suspense fallback={<div className="h-full" />}>
+                  <Terminal secrets={secrets} />
+                </React.Suspense>
+              </Container>
+            </div>
+          </div>
+
+          <div className="h-[60px]">
+            <Controls
+              setSecurityOpen={onSecurityModalOpen}
+              showSecurityLock={!!settings.SECURITY_ANALYZER}
+              lastCommitData={lastCommit}
+            />
+          </div>
+          <Security
+            isOpen={securityModalIsOpen}
+            onOpenChange={onSecurityModalOpenChange}
+            securityAnalyzer={settings.SECURITY_ANALYZER}
+          />
         </div>
-      </div>
-
-      <div className="h-[60px]">
-        <Controls
-          setSecurityOpen={onSecurityModalOpen}
-          showSecurityLock={!!settings.SECURITY_ANALYZER}
-          lastCommitData={lastCommit}
-        />
-      </div>
-      <Security
-        isOpen={securityModalIsOpen}
-        onOpenChange={onSecurityModalOpenChange}
-        securityAnalyzer={settings.SECURITY_ANALYZER}
-      />
-    </div>
+      </EventHandler>
+    </WsClientProvider>
   );
 }
 

+ 3 - 15
frontend/src/routes/_oh.tsx

@@ -21,7 +21,6 @@ import { DangerModal } from "#/components/modals/confirmation-modals/danger-moda
 import { LoadingSpinner } from "#/components/modals/LoadingProject";
 import { ModalBackdrop } from "#/components/modals/modal-backdrop";
 import { UserActions } from "#/components/user-actions";
-import { useSocket } from "#/context/socket";
 import i18n from "#/i18n";
 import { getSettings, settingsAreUpToDate } from "#/services/settings";
 import AllHandsLogo from "#/assets/branding/all-hands-logo.svg?react";
@@ -135,7 +134,6 @@ type SettingsFormData = {
 };
 
 export default function MainApp() {
-  const { stop, isConnected } = useSocket();
   const navigation = useNavigation();
   const location = useLocation();
   const {
@@ -202,14 +200,6 @@ export default function MainApp() {
     }
   }, [user]);
 
-  React.useEffect(() => {
-    if (location.pathname === "/") {
-      // If the user is on the home page, we should stop the socket connection.
-      // This is relevant when the user redirects here for whatever reason.
-      if (isConnected) stop();
-    }
-  }, [location.pathname]);
-
   const handleUserLogout = () => {
     logoutFetcher.submit(
       {},
@@ -313,11 +303,9 @@ export default function MainApp() {
             <p className="text-xs text-[#A3A3A3]">
               To continue, connect an OpenAI, Anthropic, or other LLM account
             </p>
-            {isConnected && (
-              <p className="text-xs text-danger">
-                Changing settings during an active session will end the session
-              </p>
-            )}
+            <p className="text-xs text-danger">
+              Changing settings during an active session will end the session
+            </p>
             <SettingsForm
               settings={settings}
               models={settingsFormData.models}

+ 13 - 18
frontend/src/services/actions.ts

@@ -12,8 +12,11 @@ import {
 import { setCurStatusMessage } from "#/state/statusSlice";
 import store from "#/store";
 import ActionType from "#/types/ActionType";
-import { ActionMessage, StatusMessage } from "#/types/Message";
-import { SocketMessage } from "#/types/ResponseType";
+import {
+  ActionMessage,
+  ObservationMessage,
+  StatusMessage,
+} from "#/types/Message";
 import { handleObservationMessage } from "./observations";
 
 const messageActions = {
@@ -138,22 +141,14 @@ export function handleStatusMessage(message: StatusMessage) {
   }
 }
 
-export function handleAssistantMessage(data: string | SocketMessage) {
-  let socketMessage: SocketMessage;
-
-  if (typeof data === "string") {
-    socketMessage = JSON.parse(data) as SocketMessage;
-  } else {
-    socketMessage = data;
-  }
-
-  if ("action" in socketMessage) {
-    handleActionMessage(socketMessage);
-  } else if ("observation" in socketMessage) {
-    handleObservationMessage(socketMessage);
-  } else if ("status_update" in socketMessage) {
-    handleStatusMessage(socketMessage);
+export function handleAssistantMessage(message: Record<string, unknown>) {
+  if (message.action) {
+    handleActionMessage(message as unknown as ActionMessage);
+  } else if (message.observation) {
+    handleObservationMessage(message as unknown as ObservationMessage);
+  } else if (message.status_update) {
+    handleStatusMessage(message as unknown as StatusMessage);
   } else {
-    console.error("Unknown message type", socketMessage);
+    console.error("Unknown message type", message);
   }
 }

+ 4 - 5
frontend/src/services/agentStateService.ts

@@ -1,8 +1,7 @@
 import ActionType from "#/types/ActionType";
 import AgentState from "#/types/AgentState";
 
-export const generateAgentStateChangeEvent = (state: AgentState) =>
-  JSON.stringify({
-    action: ActionType.CHANGE_AGENT_STATE,
-    args: { agent_state: state },
-  });
+export const generateAgentStateChangeEvent = (state: AgentState) => ({
+  action: ActionType.CHANGE_AGENT_STATE,
+  args: { agent_state: state },
+});

+ 1 - 1
frontend/src/services/chatService.ts

@@ -9,5 +9,5 @@ export function createChatMessage(
     action: ActionType.MESSAGE,
     args: { content: message, images_urls, timestamp },
   };
-  return JSON.stringify(event);
+  return event;
 }

+ 1 - 1
frontend/src/services/terminalService.ts

@@ -2,7 +2,7 @@ import ActionType from "#/types/ActionType";
 
 export function getTerminalCommand(command: string, hidden: boolean = false) {
   const event = { action: ActionType.RUN, args: { command, hidden } };
-  return JSON.stringify(event);
+  return event;
 }
 
 export function getGitHubTokenCommand(gitHubToken: string) {

+ 1 - 0
frontend/src/utils/verified-models.ts

@@ -20,6 +20,7 @@ export const VERIFIED_ANTHROPIC_MODELS = [
   "claude-2",
   "claude-2.1",
   "claude-3-5-sonnet-20240620",
+  "claude-3-5-sonnet-20241022",
   "claude-3-haiku-20240307",
   "claude-3-opus-20240229",
   "claude-3-sonnet-20240229",

+ 2 - 2
frontend/test-utils.tsx

@@ -6,7 +6,7 @@ import { configureStore } from "@reduxjs/toolkit";
 // eslint-disable-next-line import/no-extraneous-dependencies
 import { RenderOptions, render } from "@testing-library/react";
 import { AppStore, RootState, rootReducer } from "./src/store";
-import { SocketProvider } from "#/context/socket";
+import { WsClientProvider } from "#/context/ws-client-provider";
 
 const setupStore = (preloadedState?: Partial<RootState>): AppStore =>
   configureStore({
@@ -35,7 +35,7 @@ export function renderWithProviders(
   function Wrapper({ children }: PropsWithChildren<object>): JSX.Element {
     return (
       <Provider store={store}>
-        <SocketProvider>{children}</SocketProvider>
+        <WsClientProvider enabled={true} token={null} ghToken={null} settings={null}>{children}</WsClientProvider>
       </Provider>
     );
   }