Преглед изворни кода

feat(frontend): Wait for events before rendering messages (#4994)

Co-authored-by: mamoodi <mamoodiha@gmail.com>
sp.wack пре 1 година
родитељ
комит
01cacf7c33

+ 93 - 0
frontend/__tests__/hooks/use-rate.test.ts

@@ -0,0 +1,93 @@
+import { act, renderHook } from "@testing-library/react";
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
+import { useRate } from "#/utils/use-rate";
+
+describe("useRate", () => {
+  beforeEach(() => {
+    vi.useFakeTimers();
+  });
+
+  afterEach(() => {
+    vi.useRealTimers();
+  });
+
+  it("should initialize", () => {
+    const { result } = renderHook(() => useRate());
+
+    expect(result.current.items).toHaveLength(0);
+    expect(result.current.rate).toBeNull();
+    expect(result.current.lastUpdated).toBeNull();
+    expect(result.current.isUnderThreshold).toBe(true);
+  });
+
+  it("should handle the case of a single element", () => {
+    const { result } = renderHook(() => useRate());
+
+    act(() => {
+      result.current.record(123);
+    });
+
+    expect(result.current.items).toHaveLength(1);
+    expect(result.current.lastUpdated).not.toBeNull();
+  });
+
+  it("should return the difference between the last two elements", () => {
+    const { result } = renderHook(() => useRate());
+
+    vi.setSystemTime(500);
+    act(() => {
+      result.current.record(4);
+    });
+
+    vi.advanceTimersByTime(500);
+    act(() => {
+      result.current.record(9);
+    });
+
+    expect(result.current.items).toHaveLength(2);
+    expect(result.current.rate).toBe(5);
+    expect(result.current.lastUpdated).toBe(1000);
+  });
+
+  it("should update isUnderThreshold after [threshold]ms of no activity", () => {
+    const { result } = renderHook(() => useRate({ threshold: 500 }));
+
+    expect(result.current.isUnderThreshold).toBe(true);
+
+    act(() => {
+       // not sure if fake timers is buggy with intervals,
+       // but I need to call it twice to register
+      vi.advanceTimersToNextTimer();
+      vi.advanceTimersToNextTimer();
+    });
+
+    expect(result.current.isUnderThreshold).toBe(false);
+  });
+
+  it("should return an isUnderThreshold boolean", () => {
+    const { result } = renderHook(() => useRate({ threshold: 500 }));
+
+    vi.setSystemTime(500);
+    act(() => {
+      result.current.record(400);
+    });
+    act(() => {
+      result.current.record(1000);
+    });
+
+    expect(result.current.isUnderThreshold).toBe(false);
+
+    act(() => {
+      result.current.record(1500);
+    });
+
+    expect(result.current.isUnderThreshold).toBe(true);
+
+    act(() => {
+      vi.advanceTimersToNextTimer();
+      vi.advanceTimersToNextTimer();
+    });
+
+    expect(result.current.isUnderThreshold).toBe(false);
+  });
+});

+ 31 - 24
frontend/src/components/chat-interface.tsx

@@ -28,7 +28,8 @@ const isErrorMessage = (
 ): message is ErrorMessage => "error" in message;
 
 export function ChatInterface() {
-  const { send } = useWsClient();
+  const { send, isLoadingMessages } = useWsClient();
+
   const dispatch = useDispatch();
   const scrollRef = React.useRef<HTMLDivElement>(null);
   const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
@@ -101,30 +102,36 @@ export function ChatInterface() {
         onScroll={(e) => onChatBodyScroll(e.currentTarget)}
         className="flex flex-col grow overflow-y-auto overflow-x-hidden px-4 pt-4 gap-2"
       >
-        {messages.map((message, index) =>
-          isErrorMessage(message) ? (
-            <ErrorMessage
-              key={index}
-              id={message.id}
-              message={message.message}
-            />
-          ) : (
-            <ChatMessage
-              key={index}
-              type={message.sender}
-              message={message.content}
-            >
-              {message.imageUrls.length > 0 && (
-                <ImageCarousel size="small" images={message.imageUrls} />
-              )}
-              {messages.length - 1 === index &&
-                message.sender === "assistant" &&
-                curAgentState === AgentState.AWAITING_USER_CONFIRMATION && (
-                  <ConfirmationButtons />
-                )}
-            </ChatMessage>
-          ),
+        {isLoadingMessages && (
+          <div className="flex justify-center">
+            <div className="w-6 h-6 border-2 border-t-[4px] border-primary-500 rounded-full animate-spin" />
+          </div>
         )}
+        {!isLoadingMessages &&
+          messages.map((message, index) =>
+            isErrorMessage(message) ? (
+              <ErrorMessage
+                key={index}
+                id={message.id}
+                message={message.message}
+              />
+            ) : (
+              <ChatMessage
+                key={index}
+                type={message.sender}
+                message={message.content}
+              >
+                {message.imageUrls.length > 0 && (
+                  <ImageCarousel size="small" images={message.imageUrls} />
+                )}
+                {messages.length - 1 === index &&
+                  message.sender === "assistant" &&
+                  curAgentState === AgentState.AWAITING_USER_CONFIRMATION && (
+                    <ConfirmationButtons />
+                  )}
+              </ChatMessage>
+            ),
+          )}
       </div>
 
       <div className="flex flex-col gap-[6px] px-4 pb-4">

+ 13 - 1
frontend/src/context/ws-client-provider.tsx

@@ -5,6 +5,10 @@ import ActionType from "#/types/ActionType";
 import EventLogger from "#/utils/event-logger";
 import AgentState from "#/types/AgentState";
 import { handleAssistantMessage } from "#/services/actions";
+import { useRate } from "#/utils/use-rate";
+
+const isOpenHandsMessage = (event: Record<string, unknown>) =>
+  event.action === "message";
 
 const RECONNECT_RETRIES = 5;
 
@@ -17,12 +21,14 @@ export enum WsClientProviderStatus {
 
 interface UseWsClient {
   status: WsClientProviderStatus;
+  isLoadingMessages: boolean;
   events: Record<string, unknown>[];
   send: (event: Record<string, unknown>) => void;
 }
 
 const WsClientContext = React.createContext<UseWsClient>({
   status: WsClientProviderStatus.STOPPED,
+  isLoadingMessages: true,
   events: [],
   send: () => {
     throw new Error("not connected");
@@ -51,6 +57,8 @@ export function WsClientProvider({
   const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
   const [retryCount, setRetryCount] = React.useState(RECONNECT_RETRIES);
 
+  const messageRateHandler = useRate({ threshold: 500 });
+
   function send(event: Record<string, unknown>) {
     if (!wsRef.current) {
       EventLogger.error("WebSocket is not connected.");
@@ -71,6 +79,9 @@ export function WsClientProvider({
 
   function handleMessage(messageEvent: MessageEvent) {
     const event = JSON.parse(messageEvent.data);
+    if (isOpenHandsMessage(event)) {
+      messageRateHandler.record(new Date().getTime());
+    }
     setEvents((prevEvents) => [...prevEvents, event]);
     if (event.extras?.agent_state === AgentState.INIT) {
       setStatus(WsClientProviderStatus.ACTIVE);
@@ -177,10 +188,11 @@ export function WsClientProvider({
   const value = React.useMemo<UseWsClient>(
     () => ({
       status,
+      isLoadingMessages: messageRateHandler.isUnderThreshold,
       events,
       send,
     }),
-    [status, events],
+    [status, messageRateHandler.isUnderThreshold, events],
   );
 
   return (

+ 67 - 0
frontend/src/utils/use-rate.ts

@@ -0,0 +1,67 @@
+import React from "react";
+
+interface UseRateProps {
+  threshold: number;
+}
+
+const DEFAULT_CONFIG: UseRateProps = { threshold: 1000 };
+
+export const useRate = (config = DEFAULT_CONFIG) => {
+  const [items, setItems] = React.useState<number[]>([]);
+  const [rate, setRate] = React.useState<number | null>(null);
+  const [lastUpdated, setLastUpdated] = React.useState<number | null>(null);
+  const [isUnderThreshold, setIsUnderThreshold] = React.useState(true);
+
+  /**
+   * Record an entry in order to calculate the rate
+   * @param entry Entry to record
+   *
+   * @example
+   * record(new Date().getTime());
+   */
+  const record = (entry: number) => {
+    setItems((prev) => [...prev, entry]);
+    setLastUpdated(new Date().getTime());
+  };
+
+  /**
+   * Update the rate based on the last two entries (if available)
+   */
+  const updateRate = () => {
+    if (items.length > 1) {
+      const newRate = items[items.length - 1] - items[items.length - 2];
+      setRate(newRate);
+
+      if (newRate <= config.threshold) setIsUnderThreshold(true);
+      else setIsUnderThreshold(false);
+    }
+  };
+
+  React.useEffect(() => {
+    updateRate();
+  }, [items]);
+
+  React.useEffect(() => {
+    // Set up an interval to check if the time since the last update exceeds the threshold
+    // If it does, set isUnderThreshold to false, otherwise set it to true
+    // This ensures that the component can react to periods of inactivity
+    const intervalId = setInterval(() => {
+      if (lastUpdated !== null) {
+        const timeSinceLastUpdate = new Date().getTime() - lastUpdated;
+        setIsUnderThreshold(timeSinceLastUpdate <= config.threshold);
+      } else {
+        setIsUnderThreshold(false);
+      }
+    }, config.threshold);
+
+    return () => clearInterval(intervalId);
+  }, [lastUpdated, config.threshold]);
+
+  return {
+    items,
+    rate,
+    lastUpdated,
+    isUnderThreshold,
+    record,
+  };
+};