From 610ff063d6afab5b573cf5aab8657771f3a7bca6 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:51 +0530 Subject: [PATCH] refactor(model-connections): update frontend for provider-based models --- .../[search_space_id]/client-layout.tsx | 2 +- .../[search_space_id]/onboard/page.tsx | 2 +- .../model-connections-query.atoms.ts | 7 + .../components/assistant-ui/thread.tsx | 33 ++-- .../components/new-chat/model-selector.tsx | 13 +- .../settings/model-connections-settings.tsx | 182 ++++++++++-------- .../types/model-connections.types.ts | 44 +++-- .../hooks/use-automation-eligible-models.ts | 11 +- .../lib/apis/model-connections-api.service.ts | 6 + surfsense_web/lib/query-client/cache-keys.ts | 1 + 10 files changed, 177 insertions(+), 124 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index c7e05fe99..2b16a038a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -49,7 +49,7 @@ export function DashboardClientLayout({ const firstGlobalChatModel = useMemo(() => { for (const connection of globalConnections) { - const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + const model = connection.models.find((item) => item.enabled && item.supports_chat); if (model) return model; } return null; diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index 9cf429a3a..c6fc1c7a2 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -33,7 +33,7 @@ export default function OnboardPage() { const firstGlobalChatModel = useMemo(() => { for (const connection of globalConnections) { - const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + const model = connection.models.find((item) => item.enabled && item.supports_chat); if (model) return model; } return null; diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts index 617ffe124..87f31ce9b 100644 --- a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts @@ -11,6 +11,13 @@ export const globalModelConnectionsAtom = atomWithQuery(() => ({ queryFn: () => modelConnectionsApiService.getGlobalConnections(), })); +export const modelProvidersAtom = atomWithQuery(() => ({ + queryKey: cacheKeys.modelConnections.providers(), + enabled: !!getBearerToken(), + staleTime: 60 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getModelProviders(), +})); + export const modelConnectionsAtom = atomWithQuery((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 5796109f0..722ebb476 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -48,10 +48,10 @@ import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dial import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; @@ -976,9 +976,9 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false if (url) setPendingScreenImages((prev) => [...prev, url]); }, [electronAPI, setPendingScreenImages]); - const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences } = useAtomValue(llmPreferencesAtom); + const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom); + const { data: modelConnections } = useAtomValue(modelConnectionsAtom); + const { data: modelRoles } = useAtomValue(modelRolesAtom); const { data: agentTools } = useAtomValue(agentToolsAtom); const disabledTools = useAtomValue(disabledToolsAtom); @@ -1065,15 +1065,18 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false }, [hydrateDisabled]); const hasModelConfigured = useMemo(() => { - if (!preferences) return false; - const agentLlmId = preferences.agent_llm_id; - if (agentLlmId === null || agentLlmId === undefined) return false; - - if (agentLlmId <= 0) { - return globalConfigs?.some((c) => c.id === agentLlmId) ?? false; + const chatModelId = modelRoles?.chat_model_id ?? 0; + if (chatModelId === 0) { + return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => + connection.models.some((model) => model.enabled && Boolean(model.supports_chat)) + ); } - return userConfigs?.some((c) => c.id === agentLlmId) ?? false; - }, [preferences, globalConfigs, userConfigs]); + return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => + connection.models.some( + (model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat) + ) + ); + }, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]); const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser; diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 4744da617..6850096d6 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -56,18 +56,18 @@ function modelName(model: ModelRead) { function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Hosted"; - return connection.litellm_provider || connection.protocol; + return connection.provider; } function flattenChatModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models - .filter((model) => model.enabled && Boolean(model.capabilities?.chat)) + .filter((model) => model.enabled && Boolean(model.supports_chat)) .map((model) => ({ ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, })) ); } @@ -184,9 +184,14 @@ export function ModelSelector({ {modelName(model)}
{model.model_id}
+ {model.max_input_tokens ? ( +
+ {model.max_input_tokens.toLocaleString()} context +
+ ) : null}
- {!model.capabilities?.vision ? ( + {!model.supports_image_input ? ( No image diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 29501abda..0e541548b 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,11 +1,12 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; +import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, createModelConnectionMutationAtom, + deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, testModelMutationAtom, updateModelConnectionMutationAtom, @@ -16,6 +17,7 @@ import { import { globalModelConnectionsAtom, modelConnectionsAtom, + modelProvidersAtom, modelRolesAtom, } from "@/atoms/model-connections/model-connections-query.atoms"; import { Badge } from "@/components/ui/badge"; @@ -30,37 +32,9 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import type { - ConnectionProtocol, - ConnectionRead, - ModelRead, -} from "@/contracts/types/model-connections.types"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; import { getProviderIcon } from "@/lib/provider-icons"; -const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [ - { - value: "OPENAI_COMPATIBLE", - label: "OpenAI-compatible", - description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.", - }, - { - value: "ANTHROPIC", - label: "Anthropic", - description: "Use for Claude endpoints that require Anthropic headers.", - }, - { - value: "OLLAMA", - label: "Ollama", - description: "Use for Ollama's native API.", - }, -]; - -function defaultLitellmProvider(protocol: ConnectionProtocol) { - if (protocol === "OLLAMA") return "ollama_chat"; - if (protocol === "ANTHROPIC") return "anthropic"; - return "openai"; -} - // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. const URL_SUGGESTIONS = [ @@ -82,9 +56,19 @@ function modelLabel(model: ModelRead) { } function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { - return Boolean(model.capabilities?.[key]); + if (key === "chat") return Boolean(model.supports_chat); + if (key === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); } +type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; + +const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ + { key: "chat", label: "Chat" }, + { key: "vision", label: "Vision" }, + { key: "image_gen", label: "Image" }, +]; + function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { return ( @@ -107,9 +91,9 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.litellm_provider || connection.protocol, + connectionName: connection.provider, connectionId: connection.id, - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, })) ); } @@ -118,6 +102,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); const testModel = useAtomValue(testModelMutationAtom); @@ -127,9 +112,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { : []; const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); + const [modelFilter, setModelFilter] = useState(null); - const providerLabel = connection.litellm_provider || connection.protocol; - const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); + const providerLabel = connection.provider; + const isLocal = + connection.provider === "ollama_chat" || + connection.provider === "lm_studio" || + !connection.base_url?.startsWith("https"); + const filteredModels = modelFilter + ? connection.models.filter((model) => capability(model, modelFilter)) + : connection.models; function saveAllowlist() { const ids = allowlistText @@ -151,6 +143,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { ); } + function deleteCurrentConnection() { + const confirmed = window.confirm( + `Delete the ${providerLabel} connection and all of its models? This cannot be undone.` + ); + if (!confirmed) return; + deleteConnection.mutate(connection.id); + } + return (
@@ -175,6 +175,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { +
@@ -232,8 +240,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
+ {connection.models.length > 0 ? ( +
+ Filter models + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = connection.models.filter((model) => capability(model, filter.key)).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} +
- {connection.models.map((model) => ( + {filteredModels.length === 0 && modelFilter ? ( +
+ No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} + {filteredModels.map((model) => (
{["chat", "vision", "image_gen"] - .filter((key) => Boolean(model.capabilities?.[key])) - .join(", ") || "No verified capabilities"} + .filter((key) => capability(model, key as "chat" | "vision" | "image_gen")) + .join(", ") || "No discovered capabilities"}
@@ -278,18 +316,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); + const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const [protocol, setProtocol] = useState("OPENAI_COMPATIBLE"); + const [provider, setProvider] = useState("openai_compatible"); const [baseUrl, setBaseUrl] = useState(""); const [apiKey, setApiKey] = useState(""); - const [litellmProvider, setLitellmProvider] = useState(""); - const [showAdvancedProvider, setShowAdvancedProvider] = useState(false); - const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol); - const protocolDefaultProvider = defaultLitellmProvider(protocol); - const isOllama = protocol === "OLLAMA"; + const selectedProvider = providers.find((item) => item.provider === provider); + const isOllama = provider === "ollama_chat"; const allConnections = [...globalConnections, ...connections]; const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); @@ -298,11 +334,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); function handleCreate() { - const explicitProvider = litellmProvider.trim(); createConnection.mutate( { - protocol, - litellm_provider: explicitProvider ? explicitProvider : null, + provider, base_url: baseUrl || null, api_key: apiKey || null, scope: "SEARCH_SPACE", @@ -337,18 +371,22 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- + setLitellmProvider(event.target.value)} - placeholder={protocolDefaultProvider} - /> -

- Leave empty to use the protocol default. Set this for more accurate LiteLLM - capabilities/costs, for example openrouter, groq, gemini, or azure. -

-
- ) : null} -
diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index 7a37799c4..a34687d74 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -1,26 +1,19 @@ import { z } from "zod"; -export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]); export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); -export const modelCapabilities = z.object({ - chat: z.boolean().optional(), - vision: z.boolean().optional(), - image_gen: z.boolean().optional(), - embedding: z.boolean().optional(), - tools: z.boolean().optional(), -}); - export const modelRead = z.object({ id: z.number(), connection_id: z.number(), model_id: z.string(), display_name: z.string().nullable().optional(), source: z.union([modelSourceEnum, z.string()]), - capabilities: z.record(z.string(), z.any()).default({}), - capabilities_declared: z.record(z.string(), z.any()).default({}), - capabilities_verified: z.record(z.string(), z.any()).default({}), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), capabilities_override: z.record(z.string(), z.any()).default({}), embedding_dimension: z.number().nullable().optional(), enabled: z.boolean(), @@ -31,8 +24,7 @@ export const modelRead = z.object({ export const connectionRead = z.object({ id: z.number(), - protocol: z.union([connectionProtocolEnum, z.string()]), - litellm_provider: z.string().nullable().optional(), + provider: z.string(), base_url: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), @@ -48,8 +40,7 @@ export const connectionRead = z.object({ }); export const connectionCreateRequest = z.object({ - protocol: connectionProtocolEnum, - litellm_provider: z.string().nullable().optional(), + provider: z.string().min(1), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), @@ -59,7 +50,7 @@ export const connectionCreateRequest = z.object({ }); export const connectionUpdateRequest = z.object({ - litellm_provider: z.string().nullable().optional(), + provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).optional(), @@ -74,6 +65,11 @@ export const modelCreateRequest = z.object({ export const modelUpdateRequest = z.object({ display_name: z.string().nullable().optional(), enabled: z.boolean().optional(), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), capabilities_override: z.record(z.string(), z.any()).optional(), }); @@ -89,10 +85,21 @@ export const modelRoles = z.object({ image_gen_model_id: z.number().nullable().optional(), }); +export const modelProviderRead = z.object({ + provider: z.string(), + transport: z.string(), + discovery: z.string(), + default_base_url: z.string().nullable().optional(), + base_url_required: z.boolean(), + auth_style: z.string(), + local_only: z.boolean().default(false), +}); + +export const modelProviderListResponse = z.array(modelProviderRead); + export const connectionListResponse = z.array(connectionRead); export const modelListResponse = z.array(modelRead); -export type ConnectionProtocol = z.infer; export type ConnectionScope = z.infer; export type ModelRead = z.infer; export type ConnectionRead = z.infer; @@ -102,3 +109,4 @@ export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; export type ModelRoles = z.infer; export type VerifyConnectionResponse = z.infer; +export type ModelProviderRead = z.infer; diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index f8b264162..fd3ad3a6a 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -47,11 +47,16 @@ function buildKind( capability: "chat" | "image_gen" | "vision", prefId: number | null | undefined ): EligibleModelKind { + const supportsCapability = (model: ModelRead) => { + if (capability === "chat") return Boolean(model.supports_chat); + if (capability === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); + }; const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({ id: model.id, name: model.display_name || model.model_id, modelName: model.model_id, - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, isBYOK, }); @@ -60,7 +65,7 @@ function buildKind( .filter( (model) => model.enabled && - Boolean(model.capabilities?.[capability]) && + supportsCapability(model) && String(model.billing_tier ?? "").toLowerCase() === "premium" ) .map((model) => toOption(connection, model, false)) @@ -68,7 +73,7 @@ function buildKind( const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) => connection.models - .filter((model) => model.enabled && Boolean(model.capabilities?.[capability])) + .filter((model) => model.enabled && supportsCapability(model)) .map((model) => toOption(connection, model, true)) ); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 92eec8e61..12ad8e0d2 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -7,10 +7,12 @@ import { connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelProviderRead, type ModelRead, type ModelRoles, type ModelUpdateRequest, modelCreateRequest, + modelProviderListResponse, modelListResponse, modelRead, modelRoles, @@ -26,6 +28,10 @@ class ModelConnectionsApiService { return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); }; + getModelProviders = async (): Promise => { + return baseApiService.get(`/api/v1/model-providers`, modelProviderListResponse); + }; + getConnections = async (searchSpaceId: number): Promise => { return baseApiService.get( `/api/v1/model-connections?search_space_id=${searchSpaceId}`, diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 558a73f95..5a3f0fb84 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -47,6 +47,7 @@ export const cacheKeys = { modelConnections: { all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, global: () => ["model-connections", "global"] as const, + providers: () => ["model-connections", "providers"] as const, roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, }, imageGenConfigs: {