// ModelSelector.test.tsx

import { describe, it, expect } from "vitest";
import { render, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { ModelSelector } from "#/components/modals/settings/ModelSelector";
  5. describe("ModelSelector", () => {
  6. const models = {
  7. openai: {
  8. separator: "/",
  9. models: ["gpt-4o", "gpt-4o-mini"],
  10. },
  11. azure: {
  12. separator: "/",
  13. models: ["ada", "gpt-35-turbo"],
  14. },
  15. vertex_ai: {
  16. separator: "/",
  17. models: ["chat-bison", "chat-bison-32k"],
  18. },
  19. cohere: {
  20. separator: ".",
  21. models: ["command-r-v1:0"],
  22. },
  23. };
  24. it("should display the provider selector", async () => {
  25. const user = userEvent.setup();
  26. render(<ModelSelector models={models} />);
  27. const selector = screen.getByLabelText("LLM Provider");
  28. expect(selector).toBeInTheDocument();
  29. await user.click(selector);
  30. expect(screen.getByText("OpenAI")).toBeInTheDocument();
  31. expect(screen.getByText("Azure")).toBeInTheDocument();
  32. expect(screen.getByText("VertexAI")).toBeInTheDocument();
  33. expect(screen.getByText("cohere")).toBeInTheDocument();
  34. });
  35. it("should disable the model selector if the provider is not selected", async () => {
  36. const user = userEvent.setup();
  37. render(<ModelSelector models={models} />);
  38. const modelSelector = screen.getByLabelText("LLM Model");
  39. expect(modelSelector).toBeDisabled();
  40. const providerSelector = screen.getByLabelText("LLM Provider");
  41. await user.click(providerSelector);
  42. const vertexAI = screen.getByText("VertexAI");
  43. await user.click(vertexAI);
  44. expect(modelSelector).not.toBeDisabled();
  45. });
  46. it("should display the model selector", async () => {
  47. const user = userEvent.setup();
  48. render(<ModelSelector models={models} />);
  49. const providerSelector = screen.getByLabelText("LLM Provider");
  50. await user.click(providerSelector);
  51. const azureProvider = screen.getByText("Azure");
  52. await user.click(azureProvider);
  53. const modelSelector = screen.getByLabelText("LLM Model");
  54. await user.click(modelSelector);
  55. expect(screen.getByText("ada")).toBeInTheDocument();
  56. expect(screen.getByText("gpt-35-turbo")).toBeInTheDocument();
  57. await user.click(providerSelector);
  58. const vertexProvider = screen.getByText("VertexAI");
  59. await user.click(vertexProvider);
  60. await user.click(modelSelector);
  61. expect(screen.getByText("chat-bison")).toBeInTheDocument();
  62. expect(screen.getByText("chat-bison-32k")).toBeInTheDocument();
  63. });
  64. it("should call onModelChange when the model is changed", async () => {
  65. const user = userEvent.setup();
  66. render(<ModelSelector models={models} />);
  67. const providerSelector = screen.getByLabelText("LLM Provider");
  68. const modelSelector = screen.getByLabelText("LLM Model");
  69. await user.click(providerSelector);
  70. await user.click(screen.getByText("Azure"));
  71. await user.click(modelSelector);
  72. await user.click(screen.getByText("ada"));
  73. await user.click(modelSelector);
  74. await user.click(screen.getByText("gpt-35-turbo"));
  75. await user.click(providerSelector);
  76. await user.click(screen.getByText("cohere"));
  77. await user.click(modelSelector);
  78. await user.click(screen.getByText("command-r-v1:0"));
  79. });
  80. it("should have a default value if passed", async () => {
  81. render(<ModelSelector models={models} currentModel="azure/ada" />);
  82. expect(screen.getByLabelText("LLM Provider")).toHaveValue("Azure");
  83. expect(screen.getByLabelText("LLM Model")).toHaveValue("ada");
  84. });
  85. it.todo("should disable provider if isDisabled is true");
  86. it.todo(
  87. "should display the verified models in the correct order",
  88. async () => {},
  89. );
  90. });