From 212c8af6825e2a190084789c1ff1d00f637e8a3d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 18:18:03 +0530 Subject: [PATCH] refactor(model-selector, connection-settings): streamline model name handling and enhance connection settings dialog with improved state management for enabled models --- .../components/new-chat/model-selector.tsx | 12 +- .../connection-settings-dialog.tsx | 152 ++++++++++++------ 2 files changed, 107 insertions(+), 57 deletions(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 59e396717..545301493 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -42,10 +42,6 @@ type ChatModel = ModelRead & { provider: string; }; -function modelName(model: ModelRead) { - return (model.display_name || model.model_id).replace(/\s+\(free\)$/i, ""); -} - function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Global"; return providerDisplay(connection.provider).name; @@ -69,6 +65,14 @@ function isFreeGlobalModel(model: ChatModel) { return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; } +function modelName(model: ChatModel) { + const name = model.display_name || model.model_id; + if (model.connectionScope === "GLOBAL") { + return name.replace(/\s+\(free\)$/i, ""); + } + return name; +} + function groupedModels(models: ChatModel[]) { return models.reduce>((groups, model) => { const key = model.connectionLabel; diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index a5b8e7403..1f16c3bd0 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -1,13 +1,12 @@ import { useAtomValue } from "jotai"; import { Eye, EyeOff, Settings } from "lucide-react"; -import { useState } from "react"; +import { useMemo, useState } from "react"; import { addManualModelMutationAtom, bulkUpdateModelsMutationAtom, discoverConnectionModelsMutationAtom, testPreviewModelMutationAtom, updateModelConnectionMutationAtom, - updateModelMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { Button } from "@/components/ui/button"; import { @@ -36,6 +35,14 @@ interface ConnectionSettingsDialogProps { providerLabel: string; } +function enabledModelIds(models: SelectableModel[]) { + return new Set( + models + .filter((model) => typeof model.id === "number" && model.enabled) + .map((model) => Number(model.id)) + ); +} + export function ConnectionSettingsDialog({ connection, providerLabel, @@ -44,7 +51,6 @@ export function ConnectionSettingsDialog({ const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); const allowlist = Array.isArray(connection.extra?.model_ids) @@ -56,6 +62,9 @@ export function ConnectionSettingsDialog({ const [showApiKey, setShowApiKey] = useState(false); const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [isSavingConnectionSettings, setIsSavingConnectionSettings] = useState(false); + const [draftEnabledModelIds, setDraftEnabledModelIds] = useState(() => + enabledModelIds(connection.models) + ); const isLocal = connection.provider === "ollama_chat" || @@ -64,6 +73,19 @@ export function ConnectionSettingsDialog({ const hasConnectionChanges = baseUrlDraft.trim() !== (connection.base_url ?? "") || apiKeyDraft.trim() !== (connection.api_key ?? ""); + const draftModels = useMemo( + () => + connection.models.map((model) => + typeof model.id === "number" + ? { ...model, enabled: draftEnabledModelIds.has(model.id) } + : model + ), + [connection.models, draftEnabledModelIds] + ); + const hasModelChanges = connection.models.some( + (model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id) !== model.enabled + ); + const canUpdate = hasConnectionChanges || hasModelChanges; function handleOpenChange(open: boolean) { setIsOpen(open); @@ -73,10 +95,35 @@ export function ConnectionSettingsDialog({ setShowApiKey(false); setAllowlistText(allowlist.join(", ")); setIsSavingConnectionSettings(false); + setDraftEnabledModelIds(enabledModelIds(connection.models)); } } - function saveConnectionSettings() { + async function saveModelChanges() { + const toEnable = connection.models + .filter((model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id)) + .filter((model) => !model.enabled) + .map((model) => Number(model.id)); + const toDisable = connection.models + .filter((model) => typeof model.id === "number" && !draftEnabledModelIds.has(model.id)) + .filter((model) => model.enabled) + .map((model) => Number(model.id)); + + if (toEnable.length > 0) { + await bulkUpdateModels.mutateAsync({ + connectionId: connection.id, + data: { model_ids: toEnable, enabled: true }, + }); + } + if (toDisable.length > 0) { + await bulkUpdateModels.mutateAsync({ + connectionId: connection.id, + data: { model_ids: toDisable, enabled: false }, + }); + } + } + + async function saveConnectionSettings() { if (isSavingConnectionSettings) return; const data: ConnectionUpdateRequest = { @@ -90,49 +137,35 @@ export function ConnectionSettingsDialog({ ? (data.api_key ?? null) : (connection.api_key ?? null); - const enabledModels = connection.models.filter((model) => model.enabled); + const enabledModels = draftModels.filter((model) => model.enabled); const testModel = enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; setIsSavingConnectionSettings(true); - if (!testModel) { - updateConnection.mutate( - { id: connection.id, data }, - { - onSuccess: () => setApiKeyDraft(""), - onSettled: () => setIsSavingConnectionSettings(false), + try { + if (hasConnectionChanges) { + if (testModel) { + const result = await testPreviewModel.mutateAsync({ + provider: connection.provider, + base_url: data.base_url, + api_key: apiKeyForTest, + scope: "SEARCH_SPACE", + search_space_id: connection.search_space_id, + extra: connection.extra ?? {}, + enabled: connection.enabled, + models: [], + model_id: testModel.model_id, + }); + if (!result.ok) return; } - ); - return; - } - - testPreviewModel.mutate( - { - provider: connection.provider, - base_url: data.base_url, - api_key: apiKeyForTest, - scope: "SEARCH_SPACE", - search_space_id: connection.search_space_id, - extra: connection.extra ?? {}, - enabled: connection.enabled, - models: [], - model_id: testModel.model_id, - }, - { - onSuccess: (result) => { - if (!result.ok) { - setIsSavingConnectionSettings(false); - return; - } - updateConnection.mutate( - { id: connection.id, data }, - { - onSuccess: () => setApiKeyDraft(""), - onSettled: () => setIsSavingConnectionSettings(false), - } - ); - }, - onError: () => setIsSavingConnectionSettings(false), + await updateConnection.mutateAsync({ id: connection.id, data }); + setApiKeyDraft(""); } - ); + + if (hasModelChanges) { + await saveModelChanges(); + } + } finally { + setIsSavingConnectionSettings(false); + } } function saveAllowlist() { @@ -148,9 +181,15 @@ export function ConnectionSettingsDialog({ function handleToggleModel(model: SelectableModel, enabled: boolean) { if (typeof model.id !== "number") return; - updateModel.mutate({ - id: model.id, - data: { enabled }, + const modelId = model.id; + setDraftEnabledModelIds((current) => { + const next = new Set(current); + if (enabled) { + next.add(modelId); + } else { + next.delete(modelId); + } + return next; }); } @@ -159,9 +198,16 @@ export function ConnectionSettingsDialog({ .map((model) => model.id) .filter((id): id is number => typeof id === "number"); if (modelIds.length === 0) return; - bulkUpdateModels.mutate({ - connectionId: connection.id, - data: { model_ids: modelIds, enabled }, + setDraftEnabledModelIds((current) => { + const next = new Set(current); + for (const id of modelIds) { + if (enabled) { + next.add(id); + } else { + next.delete(id); + } + } + return next; }); } @@ -252,11 +298,11 @@ export function ConnectionSettingsDialog({ discoverModels.mutate(connection.id)} onAddManual={(modelId) => @@ -274,7 +320,7 @@ export function ConnectionSettingsDialog({