Răsfoiți Sursa

feat(frontend): Improve models input UI/UX in settings (#3530)

* Create helper functions

* Add map according to litellm docs

* Create ModelSelector

* Extend model selector

* use autocomplete from nextui

* Improve keys without providers

* Handle models without a provider

* Add verified section and some empty handling

* Add support for default or previously set models

* Update tests

* Lint

* Remove modifier

* Fix typescript error

* Functionality for switching to custom model

* Add verified models

* Respond to resetting to default

* Comment
sp.wack 1 an în urmă
părinte
comite
07e750f038

+ 193 - 0
frontend/src/components/modals/settings/ModelSelector.test.tsx

@@ -0,0 +1,193 @@
+import React from "react";
+import { describe, it, expect, vi } from "vitest";
+import { render, screen } from "@testing-library/react";
+import userEvent from "@testing-library/user-event";
+import { ModelSelector } from "./ModelSelector";
+
+describe("ModelSelector", () => {
+  const models = {
+    openai: {
+      separator: "/",
+      models: ["gpt-4o", "gpt-3.5-turbo"],
+    },
+    azure: {
+      separator: "/",
+      models: ["ada", "gpt-35-turbo"],
+    },
+    vertex_ai: {
+      separator: "/",
+      models: ["chat-bison", "chat-bison-32k"],
+    },
+    cohere: {
+      separator: ".",
+      models: ["command-r-v1:0"],
+    },
+  };
+
+  it("should display the provider selector", async () => {
+    const user = userEvent.setup();
+    const onModelChange = vi.fn();
+    render(<ModelSelector models={models} onModelChange={onModelChange} />);
+
+    const selector = screen.getByLabelText("Provider");
+    expect(selector).toBeInTheDocument();
+
+    await user.click(selector);
+
+    expect(screen.getByText("OpenAI")).toBeInTheDocument();
+    expect(screen.getByText("Azure")).toBeInTheDocument();
+    expect(screen.getByText("VertexAI")).toBeInTheDocument();
+    expect(screen.getByText("cohere")).toBeInTheDocument();
+  });
+
+  it("should disable the model selector if the provider is not selected", async () => {
+    const user = userEvent.setup();
+    const onModelChange = vi.fn();
+    render(<ModelSelector models={models} onModelChange={onModelChange} />);
+
+    const modelSelector = screen.getByLabelText("Model");
+    expect(modelSelector).toBeDisabled();
+
+    const providerSelector = screen.getByLabelText("Provider");
+    await user.click(providerSelector);
+
+    const vertexAI = screen.getByText("VertexAI");
+    await user.click(vertexAI);
+
+    expect(modelSelector).not.toBeDisabled();
+  });
+
+  it("should display the model selector", async () => {
+    const user = userEvent.setup();
+    const onModelChange = vi.fn();
+    render(<ModelSelector models={models} onModelChange={onModelChange} />);
+
+    const providerSelector = screen.getByLabelText("Provider");
+    await user.click(providerSelector);
+
+    const azureProvider = screen.getByText("Azure");
+    await user.click(azureProvider);
+
+    const modelSelector = screen.getByLabelText("Model");
+    await user.click(modelSelector);
+
+    expect(screen.getByText("ada")).toBeInTheDocument();
+    expect(screen.getByText("gpt-35-turbo")).toBeInTheDocument();
+
+    await user.click(providerSelector);
+    const vertexProvider = screen.getByText("VertexAI");
+    await user.click(vertexProvider);
+
+    await user.click(modelSelector);
+
+    expect(screen.getByText("chat-bison")).toBeInTheDocument();
+    expect(screen.getByText("chat-bison-32k")).toBeInTheDocument();
+  });
+
+  it("should display the actual litellm model ID as the user is making the selections", async () => {
+    const user = userEvent.setup();
+    const onModelChange = vi.fn();
+    render(<ModelSelector models={models} onModelChange={onModelChange} />);
+
+    const id = screen.getByTestId("model-id");
+    const providerSelector = screen.getByLabelText("Provider");
+    const modelSelector = screen.getByLabelText("Model");
+
+    expect(id).toHaveTextContent("No model selected");
+
+    await user.click(providerSelector);
+    await user.click(screen.getByText("Azure"));
+
+    expect(id).toHaveTextContent("azure/");
+
+    await user.click(modelSelector);
+    await user.click(screen.getByText("ada"));
+    expect(id).toHaveTextContent("azure/ada");
+
+    await user.click(providerSelector);
+    await user.click(screen.getByText("cohere"));
+    expect(id).toHaveTextContent("cohere.");
+
+    await user.click(modelSelector);
+    await user.click(screen.getByText("command-r-v1:0"));
+    expect(id).toHaveTextContent("cohere.command-r-v1:0");
+  });
+
+  it("should call onModelChange when the model is changed", async () => {
+    const user = userEvent.setup();
+    const onModelChange = vi.fn();
+    render(<ModelSelector models={models} onModelChange={onModelChange} />);
+
+    const providerSelector = screen.getByLabelText("Provider");
+    const modelSelector = screen.getByLabelText("Model");
+
+    await user.click(providerSelector);
+    await user.click(screen.getByText("Azure"));
+
+    await user.click(modelSelector);
+    await user.click(screen.getByText("ada"));
+
+    expect(onModelChange).toHaveBeenCalledTimes(1);
+    expect(onModelChange).toHaveBeenCalledWith("azure/ada");
+
+    await user.click(modelSelector);
+    await user.click(screen.getByText("gpt-35-turbo"));
+
+    expect(onModelChange).toHaveBeenCalledTimes(2);
+    expect(onModelChange).toHaveBeenCalledWith("azure/gpt-35-turbo");
+
+    await user.click(providerSelector);
+    await user.click(screen.getByText("cohere"));
+
+    await user.click(modelSelector);
+    await user.click(screen.getByText("command-r-v1:0"));
+
+    expect(onModelChange).toHaveBeenCalledTimes(3);
+    expect(onModelChange).toHaveBeenCalledWith("cohere.command-r-v1:0");
+  });
+
+  it("should clear the model ID when the provider is cleared", async () => {
+    const user = userEvent.setup();
+    const onModelChange = vi.fn();
+    render(<ModelSelector models={models} onModelChange={onModelChange} />);
+
+    const providerSelector = screen.getByLabelText("Provider");
+    const modelSelector = screen.getByLabelText("Model");
+
+    await user.click(providerSelector);
+    await user.click(screen.getByText("Azure"));
+
+    await user.click(modelSelector);
+    await user.click(screen.getByText("ada"));
+
+    expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
+
+    await user.clear(providerSelector);
+
+    expect(screen.getByTestId("model-id")).toHaveTextContent(
+      "No model selected",
+    );
+  });
+
+  it("should have a default value if passed", async () => {
+    const onModelChange = vi.fn();
+    render(
+      <ModelSelector
+        models={models}
+        onModelChange={onModelChange}
+        defaultModel="azure/ada"
+      />,
+    );
+
+    expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
+    expect(screen.getByLabelText("Provider")).toHaveValue("Azure");
+    expect(screen.getByLabelText("Model")).toHaveValue("ada");
+  });
+
+  it.todo("should disable provider if isDisabled is true");
+
+  it.todo(
+    "should display the verified models in the correct order",
+    async () => {},
+  );
+});

+ 133 - 0
frontend/src/components/modals/settings/ModelSelector.tsx

@@ -0,0 +1,133 @@
+import {
+  Autocomplete,
+  AutocompleteItem,
+  AutocompleteSection,
+} from "@nextui-org/react";
+import React from "react";
+import { mapProvider } from "#/utils/mapProvider";
+import { VERIFIED_MODELS, VERIFIED_PROVIDERS } from "#/utils/verified-models";
+import { extractModelAndProvider } from "#/utils/extractModelAndProvider";
+
+interface ModelSelectorProps {
+  isDisabled?: boolean;
+  models: Record<string, { separator: string; models: string[] }>;
+  onModelChange: (model: string) => void;
+  defaultModel?: string;
+}
+
+export function ModelSelector({
+  isDisabled,
+  models,
+  onModelChange,
+  defaultModel,
+}: ModelSelectorProps) {
+  const [litellmId, setLitellmId] = React.useState<string | null>(null);
+  const [selectedProvider, setSelectedProvider] = React.useState<string | null>(
+    null,
+  );
+  const [selectedModel, setSelectedModel] = React.useState<string | null>(null);
+
+  React.useEffect(() => {
+    if (defaultModel) {
+      // runs when resetting to defaults
+      const { provider, model } = extractModelAndProvider(defaultModel);
+
+      setLitellmId(defaultModel);
+      setSelectedProvider(provider);
+      setSelectedModel(model);
+    }
+  }, [defaultModel]);
+
+  const handleChangeProvider = (provider: string) => {
+    setSelectedProvider(provider);
+    setSelectedModel(null);
+
+    const separator = models[provider]?.separator || "";
+    setLitellmId(provider + separator);
+  };
+
+  const handleChangeModel = (model: string) => {
+    const separator = models[selectedProvider || ""]?.separator || "";
+    const fullModel = selectedProvider + separator + model;
+    setLitellmId(fullModel);
+    onModelChange(fullModel);
+    setSelectedModel(model);
+  };
+
+  const clear = () => {
+    setSelectedProvider(null);
+    setLitellmId(null);
+  };
+
+  return (
+    <div data-testid="model-selector" className="flex flex-col gap-2">
+      <span className="text-center italic text-gray-500" data-testid="model-id">
+        {litellmId?.replace("other", "") || "No model selected"}
+      </span>
+
+      <div className="flex flex-col gap-3">
+        <Autocomplete
+          isDisabled={isDisabled}
+          label="Provider"
+          placeholder="Select a provider"
+          isClearable={false}
+          onSelectionChange={(e) => {
+            if (e?.toString()) handleChangeProvider(e.toString());
+          }}
+          onInputChange={(value) => !value && clear()}
+          defaultSelectedKey={selectedProvider ?? undefined}
+          selectedKey={selectedProvider}
+        >
+          <AutocompleteSection title="Verified">
+            {Object.keys(models)
+              .filter((provider) => VERIFIED_PROVIDERS.includes(provider))
+              .map((provider) => (
+                <AutocompleteItem key={provider} value={provider}>
+                  {mapProvider(provider)}
+                </AutocompleteItem>
+              ))}
+          </AutocompleteSection>
+          <AutocompleteSection title="Others">
+            {Object.keys(models)
+              .filter((provider) => !VERIFIED_PROVIDERS.includes(provider))
+              .map((provider) => (
+                <AutocompleteItem key={provider} value={provider}>
+                  {mapProvider(provider)}
+                </AutocompleteItem>
+              ))}
+          </AutocompleteSection>
+        </Autocomplete>
+
+        <Autocomplete
+          label="Model"
+          placeholder="Select a model"
+          onSelectionChange={(e) => {
+            if (e?.toString()) handleChangeModel(e.toString());
+          }}
+          isDisabled={isDisabled || !selectedProvider}
+          selectedKey={selectedModel}
+          defaultSelectedKey={selectedModel ?? undefined}
+        >
+          <AutocompleteSection title="Verified">
+            {models[selectedProvider || ""]?.models
+              .filter((model) => VERIFIED_MODELS.includes(model))
+              .map((model) => (
+                <AutocompleteItem key={model} value={model}>
+                  {model}
+                </AutocompleteItem>
+              ))}
+          </AutocompleteSection>
+          <AutocompleteSection title="Others">
+            {models[selectedProvider || ""]?.models
+              .filter((model) => !VERIFIED_MODELS.includes(model))
+              .map((model) => (
+                <AutocompleteItem key={model} value={model}>
+                  {model}
+                </AutocompleteItem>
+              ))}
+          </AutocompleteSection>
+        </Autocomplete>
+      </div>
+    </div>
+  );
+}

+ 101 - 26
frontend/src/components/modals/settings/SettingsForm.test.tsx

@@ -6,6 +6,8 @@ import { Settings } from "#/services/settings";
 import SettingsForm from "./SettingsForm";
 
 const onModelChangeMock = vi.fn();
+const onCustomModelChangeMock = vi.fn();
+const onModelTypeChangeMock = vi.fn();
 const onAgentChangeMock = vi.fn();
 const onLanguageChangeMock = vi.fn();
 const onAPIKeyChangeMock = vi.fn();
@@ -18,7 +20,9 @@ const renderSettingsForm = (settings?: Settings) => {
       disabled={false}
       settings={
         settings || {
-          LLM_MODEL: "model1",
+          LLM_MODEL: "gpt-4o",
+          CUSTOM_LLM_MODEL: "",
+          USING_CUSTOM_MODEL: false,
           AGENT: "agent1",
           LANGUAGE: "en",
           LLM_API_KEY: "sk-...",
@@ -26,10 +30,12 @@ const renderSettingsForm = (settings?: Settings) => {
           SECURITY_ANALYZER: "analyzer1",
         }
       }
-      models={["model1", "model2", "model3"]}
+      models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
       agents={["agent1", "agent2", "agent3"]}
       securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
       onModelChange={onModelChangeMock}
+      onCustomModelChange={onCustomModelChangeMock}
+      onModelTypeChange={onModelTypeChangeMock}
       onAgentChange={onAgentChangeMock}
       onLanguageChange={onLanguageChangeMock}
       onAPIKeyChange={onAPIKeyChangeMock}
@@ -43,7 +49,8 @@ describe("SettingsForm", () => {
   it("should display the first values in the array by default", () => {
     renderSettingsForm();
 
-    const modelInput = screen.getByRole("combobox", { name: "model" });
+    const providerInput = screen.getByRole("combobox", { name: "Provider" });
+    const modelInput = screen.getByRole("combobox", { name: "Model" });
     const agentInput = screen.getByRole("combobox", { name: "agent" });
     const languageInput = screen.getByRole("combobox", { name: "language" });
     const apiKeyInput = screen.getByTestId("apikey");
@@ -52,7 +59,8 @@ describe("SettingsForm", () => {
       name: "securityanalyzer",
     });
 
-    expect(modelInput).toHaveValue("model1");
+    expect(providerInput).toHaveValue("OpenAI");
+    expect(modelInput).toHaveValue("gpt-4o");
     expect(agentInput).toHaveValue("agent1");
     expect(languageInput).toHaveValue("English");
     expect(apiKeyInput).toHaveValue("sk-...");
@@ -62,7 +70,9 @@ describe("SettingsForm", () => {
 
   it("should display the existing values if they are present", () => {
     renderSettingsForm({
-      LLM_MODEL: "model2",
+      LLM_MODEL: "gpt-3.5-turbo",
+      CUSTOM_LLM_MODEL: "",
+      USING_CUSTOM_MODEL: false,
       AGENT: "agent2",
       LANGUAGE: "es",
       LLM_API_KEY: "sk-...",
@@ -70,14 +80,16 @@ describe("SettingsForm", () => {
       SECURITY_ANALYZER: "analyzer2",
     });
 
-    const modelInput = screen.getByRole("combobox", { name: "model" });
+    const providerInput = screen.getByRole("combobox", { name: "Provider" });
+    const modelInput = screen.getByRole("combobox", { name: "Model" });
     const agentInput = screen.getByRole("combobox", { name: "agent" });
     const languageInput = screen.getByRole("combobox", { name: "language" });
     const securityAnalyzerInput = screen.getByRole("combobox", {
       name: "securityanalyzer",
     });
 
-    expect(modelInput).toHaveValue("model2");
+    expect(providerInput).toHaveValue("OpenAI");
+    expect(modelInput).toHaveValue("gpt-3.5-turbo");
     expect(agentInput).toHaveValue("agent2");
     expect(languageInput).toHaveValue("Español");
     expect(securityAnalyzerInput).toHaveValue("analyzer2");
@@ -87,18 +99,22 @@ describe("SettingsForm", () => {
     renderWithProviders(
       <SettingsForm
         settings={{
-          LLM_MODEL: "model1",
+          LLM_MODEL: "gpt-4o",
+          CUSTOM_LLM_MODEL: "",
+          USING_CUSTOM_MODEL: false,
           AGENT: "agent1",
           LANGUAGE: "en",
           LLM_API_KEY: "sk-...",
           CONFIRMATION_MODE: true,
           SECURITY_ANALYZER: "analyzer1",
         }}
-        models={["model1", "model2", "model3"]}
+        models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
         agents={["agent1", "agent2", "agent3"]}
         securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
         disabled
         onModelChange={onModelChangeMock}
+        onCustomModelChange={onCustomModelChangeMock}
+        onModelTypeChange={onModelTypeChangeMock}
         onAgentChange={onAgentChangeMock}
         onLanguageChange={onLanguageChangeMock}
         onAPIKeyChange={onAPIKeyChangeMock}
@@ -106,7 +122,9 @@ describe("SettingsForm", () => {
         onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
       />,
     );
-    const modelInput = screen.getByRole("combobox", { name: "model" });
+
+    const providerInput = screen.getByRole("combobox", { name: "Provider" });
+    const modelInput = screen.getByRole("combobox", { name: "Model" });
     const agentInput = screen.getByRole("combobox", { name: "agent" });
     const languageInput = screen.getByRole("combobox", { name: "language" });
     const confirmationModeInput = screen.getByTestId("confirmationmode");
@@ -114,6 +132,7 @@ describe("SettingsForm", () => {
       name: "securityanalyzer",
     });
 
+    expect(providerInput).toBeDisabled();
     expect(modelInput).toBeDisabled();
     expect(agentInput).toBeDisabled();
     expect(languageInput).toBeDisabled();
@@ -122,22 +141,6 @@ describe("SettingsForm", () => {
   });
 
   describe("onChange handlers", () => {
-    it("should call the onModelChange handler when the model changes", async () => {
-      renderSettingsForm();
-
-      const modelInput = screen.getByRole("combobox", { name: "model" });
-      await act(async () => {
-        await userEvent.click(modelInput);
-      });
-
-      const model3 = screen.getByText("model3");
-      await act(async () => {
-        await userEvent.click(model3);
-      });
-
-      expect(onModelChangeMock).toHaveBeenCalledWith("model3");
-    });
-
     it("should call the onAgentChange handler when the agent changes", async () => {
       const user = userEvent.setup();
       renderSettingsForm();
@@ -182,4 +185,76 @@ describe("SettingsForm", () => {
       expect(onAPIKeyChangeMock).toHaveBeenCalledWith("sk-...x");
     });
   });
+
+  describe("Setting a custom LLM model", () => {
+    it("should display the fetched models by default", () => {
+      renderSettingsForm();
+
+      const modelSelector = screen.getByTestId("model-selector");
+      expect(modelSelector).toBeInTheDocument();
+
+      const customModelInput = screen.queryByTestId("custom-model-input");
+      expect(customModelInput).not.toBeInTheDocument();
+    });
+
+    it("should switch to the custom model input when the custom model toggle is clicked", async () => {
+      const user = userEvent.setup();
+      renderSettingsForm();
+
+      const customModelToggle = screen.getByTestId("custom-model-toggle");
+      await user.click(customModelToggle);
+
+      const modelSelector = screen.queryByTestId("model-selector");
+      expect(modelSelector).not.toBeInTheDocument();
+
+      const customModelInput = screen.getByTestId("custom-model-input");
+      expect(customModelInput).toBeInTheDocument();
+    });
+
+    it("should call the onCustomModelChange handler when the custom model input changes", async () => {
+      const user = userEvent.setup();
+      renderSettingsForm();
+
+      const customModelToggle = screen.getByTestId("custom-model-toggle");
+      await user.click(customModelToggle);
+
+      const customModelInput = screen.getByTestId("custom-model-input");
+      await userEvent.type(customModelInput, "my/custom-model");
+
+      expect(onCustomModelChangeMock).toHaveBeenCalledWith("my/custom-model");
+      expect(onModelTypeChangeMock).toHaveBeenCalledWith("custom");
+    });
+
+    it("should have custom model switched if using custom model", () => {
+      renderWithProviders(
+        <SettingsForm
+          settings={{
+            LLM_MODEL: "gpt-4o",
+            CUSTOM_LLM_MODEL: "CUSTOM_MODEL",
+            USING_CUSTOM_MODEL: true,
+            AGENT: "agent1",
+            LANGUAGE: "en",
+            LLM_API_KEY: "sk-...",
+            CONFIRMATION_MODE: true,
+            SECURITY_ANALYZER: "analyzer1",
+          }}
+          models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
+          agents={["agent1", "agent2", "agent3"]}
+          securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
+          disabled
+          onModelChange={onModelChangeMock}
+          onCustomModelChange={onCustomModelChangeMock}
+          onModelTypeChange={onModelTypeChangeMock}
+          onAgentChange={onAgentChangeMock}
+          onLanguageChange={onLanguageChangeMock}
+          onAPIKeyChange={onAPIKeyChangeMock}
+          onConfirmationModeChange={onConfirmationModeChangeMock}
+          onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
+        />,
+      );
+
+      const customModelToggle = screen.getByTestId("custom-model-toggle");
+      expect(customModelToggle).toHaveAttribute("aria-checked", "true");
+    });
+  });
 });

+ 43 - 11
frontend/src/components/modals/settings/SettingsForm.tsx

@@ -6,6 +6,8 @@ import { AvailableLanguages } from "../../../i18n";
 import { I18nKey } from "../../../i18n/declaration";
 import { AutocompleteCombobox } from "./AutocompleteCombobox";
 import { Settings } from "#/services/settings";
+import { organizeModelsAndProviders } from "#/utils/organizeModelsAndProviders";
+import { ModelSelector } from "./ModelSelector";
 
 interface SettingsFormProps {
   settings: Settings;
@@ -15,6 +17,8 @@ interface SettingsFormProps {
   disabled: boolean;
 
   onModelChange: (model: string) => void;
+  onCustomModelChange: (model: string) => void;
+  onModelTypeChange: (type: "custom" | "default") => void;
   onAPIKeyChange: (apiKey: string) => void;
   onAgentChange: (agent: string) => void;
   onLanguageChange: (language: string) => void;
@@ -29,6 +33,8 @@ function SettingsForm({
   securityAnalyzers,
   disabled,
   onModelChange,
+  onCustomModelChange,
+  onModelTypeChange,
   onAPIKeyChange,
   onAgentChange,
   onLanguageChange,
@@ -38,20 +44,46 @@ function SettingsForm({
   const { t } = useTranslation();
   const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure();
   const [isAgentSelectEnabled, setIsAgentSelectEnabled] = React.useState(false);
+  const [usingCustomModel, setUsingCustomModel] = React.useState(
+    settings.USING_CUSTOM_MODEL,
+  );
+
+  const changeModelType = (type: "custom" | "default") => {
+    if (type === "custom") {
+      setUsingCustomModel(true);
+      onModelTypeChange("custom");
+    } else {
+      setUsingCustomModel(false);
+      onModelTypeChange("default");
+    }
+  };
 
   return (
     <>
-      <AutocompleteCombobox
-        ariaLabel="model"
-        items={models.map((model) => ({ value: model, label: model }))}
-        defaultKey={settings.LLM_MODEL}
-        onChange={(e) => {
-          onModelChange(e);
-        }}
-        tooltip={t(I18nKey.SETTINGS$MODEL_TOOLTIP)}
-        allowCustomValue // user can type in a custom LLM model that is not in the list
-        disabled={disabled}
-      />
+      <Switch
+        data-testid="custom-model-toggle"
+        aria-checked={usingCustomModel}
+        isSelected={usingCustomModel}
+        onValueChange={(value) => changeModelType(value ? "custom" : "default")}
+      >
+        Use custom model
+      </Switch>
+      {usingCustomModel && (
+        <Input
+          data-testid="custom-model-input"
+          label="Custom Model"
+          onValueChange={onCustomModelChange}
+          defaultValue={settings.CUSTOM_LLM_MODEL}
+        />
+      )}
+      {!usingCustomModel && (
+        <ModelSelector
+          isDisabled={disabled}
+          models={organizeModelsAndProviders(models)}
+          onModelChange={onModelChange}
+          defaultModel={settings.LLM_MODEL}
+        />
+      )}
       <Input
         label="API Key"
         isDisabled={disabled}

+ 47 - 14
frontend/src/components/modals/settings/SettingsModal.test.tsx

@@ -24,6 +24,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
   ...(await importOriginal<typeof import("#/services/settings")>()),
   getSettings: vi.fn().mockReturnValue({
     LLM_MODEL: "gpt-4o",
+    CUSTOM_LLM_MODEL: "",
+    USING_CUSTOM_MODEL: false,
     AGENT: "CodeActAgent",
     LANGUAGE: "en",
     LLM_API_KEY: "sk-...",
@@ -32,6 +34,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
   }),
   getDefaultSettings: vi.fn().mockReturnValue({
     LLM_MODEL: "gpt-4o",
+    CUSTOM_LLM_MODEL: "",
+    USING_CUSTOM_MODEL: false,
     AGENT: "CodeActAgent",
     LANGUAGE: "en",
     LLM_API_KEY: "",
@@ -46,7 +50,14 @@ vi.mock("#/services/options", async (importOriginal) => ({
   ...(await importOriginal<typeof import("#/services/options")>()),
   fetchModels: vi
     .fn()
-    .mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
+    .mockResolvedValue(
+      Promise.resolve([
+        "gpt-4o",
+        "gpt-3.5-turbo",
+        "azure/ada",
+        "cohere.command-r-v1:0",
+      ]),
+    ),
   fetchAgents: vi
     .fn()
     .mockResolvedValue(Promise.resolve(["agent1", "agent2", "agent3"])),
@@ -104,6 +115,8 @@ describe("SettingsModal", () => {
   describe("onHandleSave", () => {
     const initialSettings: Settings = {
       LLM_MODEL: "gpt-4o",
+      CUSTOM_LLM_MODEL: "",
+      USING_CUSTOM_MODEL: false,
       AGENT: "CodeActAgent",
       LANGUAGE: "en",
       LLM_API_KEY: "sk-...",
@@ -122,17 +135,22 @@ describe("SettingsModal", () => {
       await assertModelsAndAgentsFetched();
 
       const saveButton = screen.getByRole("button", { name: /save/i });
-      const modelInput = screen.getByRole("combobox", { name: "model" });
+      const providerInput = screen.getByRole("combobox", { name: "Provider" });
+      const modelInput = screen.getByRole("combobox", { name: "Model" });
 
-      await user.click(modelInput);
-      const model3 = screen.getByText("model3");
+      await user.click(providerInput);
+      const azure = screen.getByText("Azure");
+      await user.click(azure);
 
+      await user.click(modelInput);
+      const model3 = screen.getByText("ada");
       await user.click(model3);
+
       await user.click(saveButton);
 
       expect(saveSettings).toHaveBeenCalledWith({
         ...initialSettings,
-        LLM_MODEL: "model3",
+        LLM_MODEL: "azure/ada",
       });
     });
 
@@ -146,12 +164,17 @@ describe("SettingsModal", () => {
       );
 
       const saveButton = screen.getByRole("button", { name: /save/i });
-      const modelInput = screen.getByRole("combobox", { name: "model" });
+      const providerInput = screen.getByRole("combobox", { name: "Provider" });
+      const modelInput = screen.getByRole("combobox", { name: "Model" });
 
-      await user.click(modelInput);
-      const model3 = screen.getByText("model3");
+      await user.click(providerInput);
+      const openai = screen.getByText("OpenAI");
+      await user.click(openai);
 
+      await user.click(modelInput);
+      const model3 = screen.getByText("gpt-3.5-turbo");
       await user.click(model3);
+
       await user.click(saveButton);
 
       expect(startNewSessionSpy).toHaveBeenCalled();
@@ -167,12 +190,17 @@ describe("SettingsModal", () => {
       );
 
       const saveButton = screen.getByRole("button", { name: /save/i });
-      const modelInput = screen.getByRole("combobox", { name: "model" });
+      const providerInput = screen.getByRole("combobox", { name: "Provider" });
+      const modelInput = screen.getByRole("combobox", { name: "Model" });
 
-      await user.click(modelInput);
-      const model3 = screen.getByText("model3");
+      await user.click(providerInput);
+      const cohere = screen.getByText("cohere");
+      await user.click(cohere);
 
+      await user.click(modelInput);
+      const model3 = screen.getByText("command-r-v1:0");
       await user.click(model3);
+
       await user.click(saveButton);
 
       expect(toastSpy).toHaveBeenCalledTimes(4);
@@ -213,12 +241,17 @@ describe("SettingsModal", () => {
       });
 
       const saveButton = screen.getByRole("button", { name: /save/i });
-      const modelInput = screen.getByRole("combobox", { name: "model" });
+      const providerInput = screen.getByRole("combobox", { name: "Provider" });
+      const modelInput = screen.getByRole("combobox", { name: "Model" });
 
-      await user.click(modelInput);
-      const model3 = screen.getByText("model3");
+      await user.click(providerInput);
+      const cohere = screen.getByText("cohere");
+      await user.click(cohere);
 
+      await user.click(modelInput);
+      const model3 = screen.getByText("command-r-v1:0");
       await user.click(model3);
+
       await user.click(saveButton);
 
       expect(onOpenChangeMock).toHaveBeenCalledWith(false);

+ 20 - 2
frontend/src/components/modals/settings/SettingsModal.tsx

@@ -63,8 +63,10 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
   React.useEffect(() => {
     (async () => {
       try {
-        setModels(await fetchModels());
-        setAgents(await fetchAgents());
+        const fetchedModels = await fetchModels();
+        const fetchedAgents = await fetchAgents();
+        setModels(fetchedModels);
+        setAgents(fetchedAgents);
         setSecurityAnalyzers(await fetchSecurityAnalyzers());
       } catch (error) {
         toast.error("settings", t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS));
@@ -81,6 +83,20 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
     }));
   };
 
+  const handleCustomModelChange = (model: string) => {
+    setSettings((prev) => ({
+      ...prev,
+      CUSTOM_LLM_MODEL: model,
+    }));
+  };
+
+  const handleModelTypeChange = (type: "custom" | "default") => {
+    setSettings((prev) => ({
+      ...prev,
+      USING_CUSTOM_MODEL: type === "custom",
+    }));
+  };
+
   const handleAgentChange = (agent: string) => {
     setSettings((prev) => ({ ...prev, AGENT: agent }));
   };
@@ -189,6 +205,8 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
           agents={agents}
           securityAnalyzers={securityAnalyzers}
           onModelChange={handleModelChange}
+          onCustomModelChange={handleCustomModelChange}
+          onModelTypeChange={handleModelTypeChange}
           onAgentChange={handleAgentChange}
           onLanguageChange={handleLanguageChange}
           onAPIKeyChange={handleAPIKeyChange}

+ 36 - 0
frontend/src/services/session.test.ts

@@ -11,9 +11,16 @@ const setupSpy = vi.spyOn(Session, "_setupSocket").mockImplementation(() => {
 });
 
 describe("startNewSession", () => {
+  afterEach(() => {
+    sendSpy.mockClear();
+    setupSpy.mockClear();
+  });
+
   it("Should start a new session with the current settings", () => {
     const settings: Settings = {
       LLM_MODEL: "llm_value",
+      CUSTOM_LLM_MODEL: "",
+      USING_CUSTOM_MODEL: false,
       AGENT: "agent_value",
       LANGUAGE: "language_value",
       LLM_API_KEY: "sk-...",
@@ -32,4 +39,33 @@ describe("startNewSession", () => {
     expect(setupSpy).toHaveBeenCalledTimes(1);
     expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
   });
+
+  it("should start with the custom llm if set", () => {
+    const settings: Settings = {
+      LLM_MODEL: "llm_value",
+      CUSTOM_LLM_MODEL: "custom_llm_value",
+      USING_CUSTOM_MODEL: true,
+      AGENT: "agent_value",
+      LANGUAGE: "language_value",
+      LLM_API_KEY: "sk-...",
+      CONFIRMATION_MODE: true,
+      SECURITY_ANALYZER: "analyzer",
+    };
+
+    const event = {
+      action: ActionType.INIT,
+      args: settings,
+    };
+
+    saveSettings(settings);
+    Session.startNewSession();
+
+    expect(setupSpy).toHaveBeenCalledTimes(1);
+    expect(sendSpy).toHaveBeenCalledWith(
+      JSON.stringify({
+        ...event,
+        args: { ...settings, LLM_MODEL: "custom_llm_value" },
+      }),
+    );
+  });
 });

+ 9 - 1
frontend/src/services/session.ts

@@ -46,7 +46,15 @@ class Session {
 
   private static _initializeAgent = () => {
     const settings = getSettings();
-    const event = { action: ActionType.INIT, args: settings };
+    const event = {
+      action: ActionType.INIT,
+      args: {
+        ...settings,
+        LLM_MODEL: settings.USING_CUSTOM_MODEL
+          ? settings.CUSTOM_LLM_MODEL
+          : settings.LLM_MODEL,
+      },
+    };
     const eventString = JSON.stringify(event);
     Session.send(eventString);
   };

+ 23 - 0
frontend/src/services/settings.test.ts

@@ -18,6 +18,8 @@ describe("getSettings", () => {
   it("should get the stored settings", () => {
     (localStorage.getItem as Mock)
       .mockReturnValueOnce("llm_value")
+      .mockReturnValueOnce("custom_llm_value")
+      .mockReturnValueOnce("true")
       .mockReturnValueOnce("agent_value")
       .mockReturnValueOnce("language_value")
       .mockReturnValueOnce("api_key")
@@ -28,6 +30,8 @@ describe("getSettings", () => {
 
     expect(settings).toEqual({
       LLM_MODEL: "llm_value",
+      CUSTOM_LLM_MODEL: "custom_llm_value",
+      USING_CUSTOM_MODEL: true,
       AGENT: "agent_value",
       LANGUAGE: "language_value",
       LLM_API_KEY: "api_key",
@@ -43,12 +47,16 @@ describe("getSettings", () => {
       .mockReturnValueOnce(null)
       .mockReturnValueOnce(null)
       .mockReturnValueOnce(null)
+      .mockReturnValueOnce(null)
+      .mockReturnValueOnce(null)
       .mockReturnValueOnce(null);
 
     const settings = getSettings();
 
     expect(settings).toEqual({
       LLM_MODEL: DEFAULT_SETTINGS.LLM_MODEL,
+      CUSTOM_LLM_MODEL: "",
+      USING_CUSTOM_MODEL: DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
       AGENT: DEFAULT_SETTINGS.AGENT,
       LANGUAGE: DEFAULT_SETTINGS.LANGUAGE,
       LLM_API_KEY: "",
@@ -62,6 +70,8 @@ describe("saveSettings", () => {
   it("should save the settings", () => {
     const settings: Settings = {
       LLM_MODEL: "llm_value",
+      CUSTOM_LLM_MODEL: "custom_llm_value",
+      USING_CUSTOM_MODEL: true,
       AGENT: "agent_value",
       LANGUAGE: "language_value",
       LLM_API_KEY: "some_key",
@@ -72,6 +82,14 @@ describe("saveSettings", () => {
     saveSettings(settings);
 
     expect(localStorage.setItem).toHaveBeenCalledWith("LLM_MODEL", "llm_value");
+    expect(localStorage.setItem).toHaveBeenCalledWith(
+      "CUSTOM_LLM_MODEL",
+      "custom_llm_value",
+    );
+    expect(localStorage.setItem).toHaveBeenCalledWith(
+      "USING_CUSTOM_MODEL",
+      "true",
+    );
     expect(localStorage.setItem).toHaveBeenCalledWith("AGENT", "agent_value");
     expect(localStorage.setItem).toHaveBeenCalledWith(
       "LANGUAGE",
@@ -122,6 +140,8 @@ describe("getSettingsDifference", () => {
   beforeEach(() => {
     (localStorage.getItem as Mock)
       .mockReturnValueOnce("llm_value")
+      .mockReturnValueOnce("custom_llm_value")
+      .mockReturnValueOnce("false")
       .mockReturnValueOnce("agent_value")
       .mockReturnValueOnce("language_value");
   });
@@ -129,6 +149,8 @@ describe("getSettingsDifference", () => {
   it("should return updated settings", () => {
     const settings = {
       LLM_MODEL: "new_llm_value",
+      CUSTOM_LLM_MODEL: "custom_llm_value",
+      USING_CUSTOM_MODEL: true,
       AGENT: "new_agent_value",
       LANGUAGE: "language_value",
     };
@@ -136,6 +158,7 @@ describe("getSettingsDifference", () => {
     const updatedSettings = getSettingsDifference(settings);
 
     expect(updatedSettings).toEqual({
+      USING_CUSTOM_MODEL: true,
       LLM_MODEL: "new_llm_value",
       AGENT: "new_agent_value",
     });

+ 10 - 1
frontend/src/services/settings.ts

@@ -2,6 +2,8 @@ const LATEST_SETTINGS_VERSION = 1;
 
 export type Settings = {
   LLM_MODEL: string;
+  CUSTOM_LLM_MODEL: string;
+  USING_CUSTOM_MODEL: boolean;
   AGENT: string;
   LANGUAGE: string;
   LLM_API_KEY: string;
@@ -12,7 +14,9 @@ export type Settings = {
 type SettingsInput = Settings[keyof Settings];
 
 export const DEFAULT_SETTINGS: Settings = {
-  LLM_MODEL: "gpt-4o",
+  LLM_MODEL: "openai/gpt-4o",
+  CUSTOM_LLM_MODEL: "",
+  USING_CUSTOM_MODEL: false,
   AGENT: "CodeActAgent",
   LANGUAGE: "en",
   LLM_API_KEY: "",
@@ -54,6 +58,9 @@ export const getDefaultSettings = (): Settings => DEFAULT_SETTINGS;
  */
 export const getSettings = (): Settings => {
   const model = localStorage.getItem("LLM_MODEL");
+  const customModel = localStorage.getItem("CUSTOM_LLM_MODEL");
+  const usingCustomModel =
+    localStorage.getItem("USING_CUSTOM_MODEL") === "true";
   const agent = localStorage.getItem("AGENT");
   const language = localStorage.getItem("LANGUAGE");
   const apiKey = localStorage.getItem("LLM_API_KEY");
@@ -62,6 +69,8 @@ export const getSettings = (): Settings => {
 
   return {
     LLM_MODEL: model || DEFAULT_SETTINGS.LLM_MODEL,
+    CUSTOM_LLM_MODEL: customModel || DEFAULT_SETTINGS.CUSTOM_LLM_MODEL,
+    USING_CUSTOM_MODEL: usingCustomModel || DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
     AGENT: agent || DEFAULT_SETTINGS.AGENT,
     LANGUAGE: language || DEFAULT_SETTINGS.LANGUAGE,
     LLM_API_KEY: apiKey || DEFAULT_SETTINGS.LLM_API_KEY,

+ 62 - 0
frontend/src/utils/extractModelAndProvider.test.ts

@@ -0,0 +1,62 @@
+import { describe, it, expect } from "vitest";
+import { extractModelAndProvider } from "./extractModelAndProvider";
+
+describe("extractModelAndProvider", () => {
+  it("should work", () => {
+    expect(extractModelAndProvider("azure/ada")).toEqual({
+      provider: "azure",
+      model: "ada",
+      separator: "/",
+    });
+
+    expect(
+      extractModelAndProvider("azure/standard/1024-x-1024/dall-e-2"),
+    ).toEqual({
+      provider: "azure",
+      model: "standard/1024-x-1024/dall-e-2",
+      separator: "/",
+    });
+
+    expect(extractModelAndProvider("vertex_ai_beta/chat-bison")).toEqual({
+      provider: "vertex_ai_beta",
+      model: "chat-bison",
+      separator: "/",
+    });
+
+    expect(extractModelAndProvider("cohere.command-r-v1:0")).toEqual({
+      provider: "cohere",
+      model: "command-r-v1:0",
+      separator: ".",
+    });
+
+    expect(
+      extractModelAndProvider(
+        "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
+      ),
+    ).toEqual({
+      provider: "cloudflare",
+      model: "@cf/mistral/mistral-7b-instruct-v0.1",
+      separator: "/",
+    });
+
+    expect(extractModelAndProvider("together-ai-21.1b-41b")).toEqual({
+      provider: "",
+      model: "together-ai-21.1b-41b",
+      separator: "",
+    });
+  });
+
+  it("should add provider for popular models", () => {
+    expect(extractModelAndProvider("gpt-3.5-turbo")).toEqual({
+      provider: "openai",
+      model: "gpt-3.5-turbo",
+      separator: "/",
+    });
+
+    expect(extractModelAndProvider("gpt-4o")).toEqual({
+      provider: "openai",
+      model: "gpt-4o",
+      separator: "/",
+    });
+  });
+});

+ 49 - 0
frontend/src/utils/extractModelAndProvider.ts

@@ -0,0 +1,49 @@
+import { isNumber } from "./isNumber";
+import { VERIFIED_OPENAI_MODELS } from "./verified-models";
+
+/**
+ * Checks if the split array is actually a version number.
+ * @param split The split array of the model string
+ * @returns Boolean indicating if the split is actually a version number
+ *
+ * @example
+ * const split = ["gpt-3", "5-turbo"] // incorrectly split from "gpt-3.5-turbo"
+ * splitIsActuallyVersion(split) // returns true
+ */
+const splitIsActuallyVersion = (split: string[]) =>
+  split[1] && split[1][0] && isNumber(split[1][0]);
+
+/**
+ * Given a model string, extract the provider and model name. Currently the supported separators are "/" and "."
+ * @param model The model string
+ * @returns An object containing the provider, model name, and separator
+ *
+ * @example
+ * extractModelAndProvider("azure/ada")
+ * // returns { provider: "azure", model: "ada", separator: "/" }
+ *
+ * extractModelAndProvider("cohere.command-r-v1:0")
+ * // returns { provider: "cohere", model: "command-r-v1:0", separator: "." }
+ */
+export const extractModelAndProvider = (model: string) => {
+  let separator = "/";
+  let split = model.split(separator);
+  if (split.length === 1) {
+    // no "/" separator found, try with "."
+    separator = ".";
+    split = model.split(separator);
+    if (splitIsActuallyVersion(split)) {
+      split = [split.join(separator)]; // undo the split
+    }
+  }
+  if (split.length === 1) {
+    // no "/" or "." separator found
+    if (VERIFIED_OPENAI_MODELS.includes(split[0])) {
+      return { provider: "openai", model: split[0], separator: "/" };
+    }
+    // return as model only
+    return { provider: "", model, separator: "" };
+  }
+  const [provider, ...modelId] = split;
+  return { provider, model: modelId.join(separator), separator };
+};

+ 9 - 0
frontend/src/utils/isNumber.test.ts

@@ -0,0 +1,9 @@
+import { test, expect } from "vitest";
+import { isNumber } from "./isNumber";
+
+test("isNumber", () => {
+  expect(isNumber(1)).toBe(true);
+  expect(isNumber(0)).toBe(true);
+  expect(isNumber("3")).toBe(true);
+  expect(isNumber("0")).toBe(true);
+});

+ 2 - 0
frontend/src/utils/isNumber.ts

@@ -0,0 +1,2 @@
+export const isNumber = (value: string | number): boolean =>
+  !Number.isNaN(Number(value));

+ 27 - 0
frontend/src/utils/mapProvider.test.ts

@@ -0,0 +1,27 @@
+import { test, expect } from "vitest";
+import { mapProvider } from "./mapProvider";
+
+test("mapProvider", () => {
+  expect(mapProvider("azure")).toBe("Azure");
+  expect(mapProvider("azure_ai")).toBe("Azure AI Studio");
+  expect(mapProvider("vertex_ai")).toBe("VertexAI");
+  expect(mapProvider("palm")).toBe("PaLM");
+  expect(mapProvider("gemini")).toBe("Gemini");
+  expect(mapProvider("anthropic")).toBe("Anthropic");
+  expect(mapProvider("sagemaker")).toBe("AWS SageMaker");
+  expect(mapProvider("bedrock")).toBe("AWS Bedrock");
+  expect(mapProvider("mistral")).toBe("Mistral AI");
+  expect(mapProvider("anyscale")).toBe("Anyscale");
+  expect(mapProvider("databricks")).toBe("Databricks");
+  expect(mapProvider("ollama")).toBe("Ollama");
+  expect(mapProvider("perlexity")).toBe("Perplexity AI");
+  expect(mapProvider("friendliai")).toBe("FriendliAI");
+  expect(mapProvider("groq")).toBe("Groq");
+  expect(mapProvider("fireworks_ai")).toBe("Fireworks AI");
+  expect(mapProvider("cloudflare")).toBe("Cloudflare Workers AI");
+  expect(mapProvider("deepinfra")).toBe("DeepInfra");
+  expect(mapProvider("ai21")).toBe("AI21");
+  expect(mapProvider("replicate")).toBe("Replicate");
+  expect(mapProvider("voyage")).toBe("Voyage AI");
+  expect(mapProvider("openrouter")).toBe("OpenRouter");
+});

+ 30 - 0
frontend/src/utils/mapProvider.ts

@@ -0,0 +1,30 @@
+export const MAP_PROVIDER = {
+  openai: "OpenAI",
+  azure: "Azure",
+  azure_ai: "Azure AI Studio",
+  vertex_ai: "VertexAI",
+  palm: "PaLM",
+  gemini: "Gemini",
+  anthropic: "Anthropic",
+  sagemaker: "AWS SageMaker",
+  bedrock: "AWS Bedrock",
+  mistral: "Mistral AI",
+  anyscale: "Anyscale",
+  databricks: "Databricks",
+  ollama: "Ollama",
+  perlexity: "Perplexity AI",
+  friendliai: "FriendliAI",
+  groq: "Groq",
+  fireworks_ai: "Fireworks AI",
+  cloudflare: "Cloudflare Workers AI",
+  deepinfra: "DeepInfra",
+  ai21: "AI21",
+  replicate: "Replicate",
+  voyage: "Voyage AI",
+  openrouter: "OpenRouter",
+};
+
+export const mapProvider = (provider: string) =>
+  Object.keys(MAP_PROVIDER).includes(provider)
+    ? MAP_PROVIDER[provider as keyof typeof MAP_PROVIDER]
+    : provider;

+ 51 - 0
frontend/src/utils/organizeModelsAndProviders.test.ts

@@ -0,0 +1,51 @@
+import { test } from "vitest";
+import { organizeModelsAndProviders } from "./organizeModelsAndProviders";
+
+test("organizeModelsAndProviders", () => {
+  const models = [
+    "azure/ada",
+    "azure/gpt-35-turbo",
+    "azure/gpt-3-turbo",
+    "azure/standard/1024-x-1024/dall-e-2",
+    "vertex_ai_beta/chat-bison",
+    "vertex_ai_beta/chat-bison-32k",
+    "sagemaker/meta-textgeneration-llama-2-13b",
+    "cohere.command-r-v1:0",
+    "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
+    "gpt-4o",
+    "together-ai-21.1b-41b",
+    "gpt-3.5-turbo",
+  ];
+
+  const object = organizeModelsAndProviders(models);
+
+  expect(object).toEqual({
+    azure: {
+      separator: "/",
+      models: [
+        "ada",
+        "gpt-35-turbo",
+        "gpt-3-turbo",
+        "standard/1024-x-1024/dall-e-2",
+      ],
+    },
+    vertex_ai_beta: {
+      separator: "/",
+      models: ["chat-bison", "chat-bison-32k"],
+    },
+    sagemaker: { separator: "/", models: ["meta-textgeneration-llama-2-13b"] },
+    cohere: { separator: ".", models: ["command-r-v1:0"] },
+    cloudflare: {
+      separator: "/",
+      models: ["@cf/mistral/mistral-7b-instruct-v0.1"],
+    },
+    openai: {
+      separator: "/",
+      models: ["gpt-4o", "gpt-3.5-turbo"],
+    },
+    other: {
+      separator: "",
+      models: ["together-ai-21.1b-41b"],
+    },
+  });
+});

+ 42 - 0
frontend/src/utils/organizeModelsAndProviders.ts

@@ -0,0 +1,42 @@
+import { extractModelAndProvider } from "./extractModelAndProvider";
+
+/**
+ * Given a list of models, organize them by provider
+ * @param models The list of models
+ * @returns An object containing the provider and models
+ *
+ * @example
+ * const models = [
+ *  "azure/ada",
+ *  "azure/gpt-35-turbo",
+ *  "cohere.command-r-v1:0",
+ * ];
+ *
+ * organizeModelsAndProviders(models);
+ * // returns {
+ * //   azure: {
+ * //     separator: "/",
+ * //     models: ["ada", "gpt-35-turbo"],
+ * //   },
+ * //   cohere: {
+ * //     separator: ".",
+ * //     models: ["command-r-v1:0"],
+ * //   },
+ * // }
+ */
+export const organizeModelsAndProviders = (models: string[]) => {
+  const object: Record<string, { separator: string; models: string[] }> = {};
+  models.forEach((model) => {
+    const {
+      separator,
+      provider,
+      model: modelId,
+    } = extractModelAndProvider(model);
+    const key = provider || "other";
+    if (!object[key]) {
+      object[key] = { separator, models: [] };
+    }
+    object[key].models.push(modelId);
+  });
+  return object;
+};

+ 14 - 0
frontend/src/utils/verified-models.ts

@@ -0,0 +1,14 @@
+// Here are the list of verified models and providers that we know work well with OpenHands.
+export const VERIFIED_PROVIDERS = ["openai", "azure", "anthropic"];
+export const VERIFIED_MODELS = ["gpt-4o", "claude-3-5-sonnet-20240620-v1:0"];
+
+// LiteLLM does not return OpenAI models with the provider, so we list them here to set them ourselves for consistency
+// (e.g., they return `gpt-4o` instead of `openai/gpt-4o`)
+export const VERIFIED_OPENAI_MODELS = [
+  "gpt-4o",
+  "gpt-4o-mini",
+  "gpt-4-turbo",
+  "gpt-4",
+  "gpt-4-32k",
+  "gpt-3.5-turbo",
+];