From 3089dd4cb6b9ba1bc8132ef50967bd70b6a3f33b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:57 +0530 Subject: [PATCH] refactor(model-connections): simplify connection settings UI --- .../model-connections-mutation.atoms.ts | 12 +- .../components/new-chat/model-selector.tsx | 4 +- .../settings/model-connections-settings.tsx | 234 +++++++----------- .../types/model-connections.types.ts | 8 +- .../hooks/use-automation-eligible-models.ts | 2 +- .../lib/apis/model-connections-api.service.ts | 31 ++- 6 files changed, 126 insertions(+), 165 deletions(-) diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 76289e60d..101bad1b5 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -4,8 +4,10 @@ import type { ConnectionCreateRequest, ConnectionUpdateRequest, ModelCreateRequest, + ModelRead, ModelRoles, ModelUpdateRequest, + VerifyConnectionResponse, } from "@/contracts/types/model-connections.types"; import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; @@ -67,7 +69,7 @@ export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "verify"], mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), - onSuccess: (result) => { + onSuccess: (result: VerifyConnectionResponse) => { if (result.ok) { toast.success("Connection verified"); } else { @@ -90,11 +92,9 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "discover"], mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: (models) => { + onSuccess: (models: ModelRead[]) => { toast.success( - models.length - ? `${models.length} models discovered` - : "No models found for this connection" + models.length ? `${models.length} models discovered` : "No models found for this connection" ); invalidateModelConnections(searchSpaceId); }, @@ -132,7 +132,7 @@ export const testModelMutationAtom = atomWithMutation((get) => { return { mutationKey: ["models", "test"], mutationFn: (id: number) => modelConnectionsApiService.testModel(id), - onSuccess: (result) => { + onSuccess: (result: VerifyConnectionResponse) => { if (result.ok) toast.success("Model test succeeded"); else toast.error(result.message || "Model test failed"); invalidateModelConnections(searchSpaceId); diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 7c912afbb..4744da617 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -56,7 +56,7 @@ function modelName(model: ModelRead) { function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Hosted"; - return connection.native_provider || connection.protocol; + return connection.litellm_provider || connection.protocol; } function flattenChatModels(connections: ConnectionRead[]) { @@ -67,7 +67,7 @@ function flattenChatModels(connections: ConnectionRead[]) { ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, })) ); } diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 59873408f..29501abda 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -2,7 +2,7 @@ import { useAtom, useAtomValue } from "jotai"; import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; -import { useMemo, useState } from "react"; +import { useState } from "react"; import { addManualModelMutationAtom, createModelConnectionMutationAtom, @@ -35,60 +35,32 @@ import type { ConnectionRead, ModelRead, } from "@/contracts/types/model-connections.types"; -import { isCloud } from "@/lib/env-config"; import { getProviderIcon } from "@/lib/provider-icons"; -type Preset = { - id: string; - label: string; - protocol: ConnectionProtocol; - nativeProvider?: string; - baseUrl?: string; - local?: boolean; -}; - -const PRESETS: Preset[] = [ - { id: "custom", label: "OpenAI-compatible (any URL)", protocol: "OPENAI_COMPATIBLE" }, - { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, - { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, - { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, +const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [ { - id: "ollama", + value: "OPENAI_COMPATIBLE", + label: "OpenAI-compatible", + description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.", + }, + { + value: "ANTHROPIC", + label: "Anthropic", + description: "Use for Claude endpoints that require Anthropic headers.", + }, + { + value: "OLLAMA", label: "Ollama", - protocol: "OLLAMA", - baseUrl: "http://host.docker.internal:11434", - local: true, - }, - { - id: "lmstudio", - label: "LM Studio", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:1234/v1", - local: true, - }, - { - id: "llamacpp", - label: "llama.cpp", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8080/v1", - local: true, - }, - { - id: "localai", - label: "LocalAI", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8080/v1", - local: true, - }, - { - id: "vllm", - label: "vLLM", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8000/v1", - local: true, + description: "Use for Ollama's native API.", }, ]; +function defaultLitellmProvider(protocol: ConnectionProtocol) { + if (protocol === "OLLAMA") return "ollama_chat"; + if (protocol === "ANTHROPIC") return "anthropic"; + return "openai"; +} + // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. const URL_SUGGESTIONS = [ @@ -135,9 +107,9 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.native_provider || connection.protocol, + connectionName: connection.litellm_provider || connection.protocol, connectionId: connection.id, - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, })) ); } @@ -156,7 +128,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); - const providerLabel = connection.native_provider || connection.protocol; + const providerLabel = connection.litellm_provider || connection.protocol; const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); function saveAllowlist() { @@ -200,11 +172,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { > Test - @@ -212,8 +180,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { {connection.last_status && connection.last_status !== "OK" ? (

- {connection.last_error || "Could not list models."} Chat may still work — add model - IDs manually below. + {connection.last_error || "Could not list models."} Chat may still work — add model IDs + manually below.

) : null} @@ -236,8 +204,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {

- Leave empty to discover all models. Recommended for providers with large catalogs - (e.g. OpenRouter). + Leave empty to discover all models. Recommended for providers with large catalogs (e.g. + OpenRouter).

) : null} @@ -314,20 +282,14 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const createConnection = useAtomValue(createModelConnectionMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const visiblePresets = useMemo( - () => PRESETS.filter((preset) => !(isCloud() && preset.local)), - [] - ); - const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "custom"); - const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; - const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); + const [protocol, setProtocol] = useState("OPENAI_COMPATIBLE"); + const [baseUrl, setBaseUrl] = useState(""); const [apiKey, setApiKey] = useState(""); - // Native providers carry their endpoint inside LiteLLM, so Base URL is hidden - // by default and only revealed for power users who want to override it. - const [showCustomEndpoint, setShowCustomEndpoint] = useState(false); - - const isNative = preset?.protocol === "NATIVE"; - const requiresUrl = !isNative; + const [litellmProvider, setLitellmProvider] = useState(""); + const [showAdvancedProvider, setShowAdvancedProvider] = useState(false); + const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol); + const protocolDefaultProvider = defaultLitellmProvider(protocol); + const isOllama = protocol === "OLLAMA"; const allConnections = [...globalConnections, ...connections]; const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); @@ -335,21 +297,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const visionModels = enabledModels.filter((model) => capability(model, "vision")); const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); - function onPresetChange(value: string) { - setPresetId(value); - const next = visiblePresets.find((item) => item.id === value); - // Native providers use LiteLLM's built-in endpoint; everything else needs - // (and may prefill) a Base URL. - setBaseUrl(next?.protocol === "NATIVE" ? "" : (next?.baseUrl ?? "")); - setShowCustomEndpoint(false); - } - function handleCreate() { - if (!preset) return; + const explicitProvider = litellmProvider.trim(); createConnection.mutate( { - protocol: preset.protocol, - native_provider: preset.nativeProvider, + protocol, + litellm_provider: explicitProvider ? explicitProvider : null, base_url: baseUrl || null, api_key: apiKey || null, scope: "SEARCH_SPACE", @@ -384,90 +337,89 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- - setProtocol(value as ConnectionProtocol)} + > - {visiblePresets.map((item) => ( - - - {getProviderIcon(item.nativeProvider || item.protocol, { - className: "size-4", - })} - {item.label} - + {PROTOCOL_OPTIONS.map((item) => ( + + {item.label} ))}
- - {isNative && !showCustomEndpoint ? ( -
-
- Uses provider default -
- -
- ) : ( - <> - setBaseUrl(event.target.value)} - placeholder="https://api.example.com/v1" - list="model-conn-url-suggestions" - /> - - {URL_SUGGESTIONS.map((url) => ( - - - )} + + setBaseUrl(event.target.value)} + placeholder={ + isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1" + } + list="model-conn-url-suggestions" + /> + + {URL_SUGGESTIONS.map((url) => ( +
- + setApiKey(event.target.value)} - placeholder={preset?.local ? "Optional for local models" : "API key"} + placeholder={isOllama ? "Optional for Ollama" : "API key"} type="password" />
- {preset?.local ? ( +

- Local URLs are tested from the backend container. Use host.docker.internal instead of - localhost. + {selectedProtocol?.description} Base URL is explicit and editable; no provider presets + are required. Local URLs are tested from the backend container, so use + host.docker.internal instead of localhost.

- ) : isNative ? ( -

- Just paste an API key — {preset?.label} routes through its native endpoint - automatically. After adding, hit Discover (or add model IDs manually). -

- ) : preset?.protocol === "OPENAI_COMPATIBLE" ? ( -

- Enter any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM Studio…). - After adding, hit Discover to list models. -

- ) : null} +
+ + {showAdvancedProvider ? ( +
+ + setLitellmProvider(event.target.value)} + placeholder={protocolDefaultProvider} + /> +

+ Leave empty to use the protocol default. Set this for more accurate LiteLLM + capabilities/costs, for example openrouter, groq, gemini, or azure. +

+
+ ) : null} +
+
{connections.map((connection) => ( diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index dcc875251..7a37799c4 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -1,6 +1,6 @@ import { z } from "zod"; -export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "NATIVE"]); +export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]); export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); @@ -32,7 +32,7 @@ export const modelRead = z.object({ export const connectionRead = z.object({ id: z.number(), protocol: z.union([connectionProtocolEnum, z.string()]), - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), @@ -49,7 +49,7 @@ export const connectionRead = z.object({ export const connectionCreateRequest = z.object({ protocol: connectionProtocolEnum, - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), @@ -59,7 +59,7 @@ export const connectionCreateRequest = z.object({ }); export const connectionUpdateRequest = z.object({ - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).optional(), diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index e75235c56..f8b264162 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -51,7 +51,7 @@ function buildKind( id: model.id, name: model.display_name || model.model_id, modelName: model.model_id, - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, isBYOK, }); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 7d0f0f59c..92eec8e61 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -1,11 +1,13 @@ import { type ConnectionCreateRequest, + type ConnectionRead, type ConnectionUpdateRequest, connectionCreateRequest, connectionListResponse, connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelRead, type ModelRoles, type ModelUpdateRequest, modelCreateRequest, @@ -13,24 +15,25 @@ import { modelRead, modelRoles, modelUpdateRequest, + type VerifyConnectionResponse, verifyConnectionResponse, } from "@/contracts/types/model-connections.types"; import { ValidationError } from "../error"; import { baseApiService } from "./base-api.service"; class ModelConnectionsApiService { - getGlobalConnections = async () => { + getGlobalConnections = async (): Promise => { return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); }; - getConnections = async (searchSpaceId: number) => { + getConnections = async (searchSpaceId: number): Promise => { return baseApiService.get( `/api/v1/model-connections?search_space_id=${searchSpaceId}`, connectionListResponse ); }; - createConnection = async (request: ConnectionCreateRequest) => { + createConnection = async (request: ConnectionCreateRequest): Promise => { const parsed = connectionCreateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -40,7 +43,10 @@ class ModelConnectionsApiService { }); }; - updateConnection = async (id: number, request: ConnectionUpdateRequest) => { + updateConnection = async ( + id: number, + request: ConnectionUpdateRequest + ): Promise => { const parsed = connectionUpdateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -54,15 +60,18 @@ class ModelConnectionsApiService { return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); }; - verifyConnection = async (id: number) => { + verifyConnection = async (id: number): Promise => { return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); }; - discoverModels = async (id: number) => { + discoverModels = async (id: number): Promise => { return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); }; - addManualModel = async (connectionId: number, request: ModelCreateRequest) => { + addManualModel = async ( + connectionId: number, + request: ModelCreateRequest + ): Promise => { const parsed = modelCreateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -72,7 +81,7 @@ class ModelConnectionsApiService { }); }; - updateModel = async (id: number, request: ModelUpdateRequest) => { + updateModel = async (id: number, request: ModelUpdateRequest): Promise => { const parsed = modelUpdateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -82,15 +91,15 @@ class ModelConnectionsApiService { }); }; - testModel = async (id: number) => { + testModel = async (id: number): Promise => { return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); }; - getModelRoles = async (searchSpaceId: number) => { + getModelRoles = async (searchSpaceId: number): Promise => { return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); }; - updateModelRoles = async (searchSpaceId: number, roles: ModelRoles) => { + updateModelRoles = async (searchSpaceId: number, roles: ModelRoles): Promise => { return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { body: roles, });