From ced1bb85edc778e710ff133fe5265df089a40a11 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 09:43:56 +0530 Subject: [PATCH] feat(model-connections): implement bulk model update endpoint and related schema changes --- .../app/routes/model_connections_routes.py | 31 +- surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 6 + .../model-connections-mutation.atoms.ts | 12 + .../settings/model-connections-settings.tsx | 626 +++++++++++++----- .../types/model-connections.types.ts | 7 + .../lib/apis/model-connections-api.service.ts | 23 +- 7 files changed, 538 insertions(+), 168 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index ecb86711e..730c68565 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -1,7 +1,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -25,6 +25,7 @@ from app.schemas import ( ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, ) @@ -62,6 +63,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None id=conn.id, provider=conn.provider, base_url=conn.base_url, + api_key=conn.api_key, extra=conn.extra or {}, scope=conn.scope, search_space_id=conn.search_space_id, @@ -351,6 +353,33 @@ async def add_manual_model( return _model_read(model) +@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead]) +async def bulk_update_models( + connection_id: int, + data: ModelsBulkUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + + model_ids = set(data.model_ids) + await session.execute( + update(Model) + .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) + .values(enabled=data.enabled) + ) + await session.commit() + session.expire_all() + + result = await session.execute( + select(Model) + .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) + .order_by(Model.id) + ) + return [_model_read(model) for model in result.scalars().all()] + + @router.put("/models/{model_id}", response_model=ModelRead) async def update_model( model_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 8ac7c5bbb..55e712f12 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -53,6 +53,7 @@ from .model_connections import ( ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, ) diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index c081a193d..0b03c7fab 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -32,6 +32,7 @@ class ConnectionRead(BaseModel): id: int provider: str base_url: str | None = None + api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str search_space_id: int | None = None @@ -87,6 +88,11 @@ class ModelUpdate(BaseModel): capabilities_override: dict[str, Any] | None = None +class ModelsBulkUpdate(BaseModel): + model_ids: list[int] = Field(..., min_length=1, max_length=1000) + enabled: bool + + class ModelProviderRead(BaseModel): provider: str transport: str diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 101bad1b5..fee3b95ba 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -6,6 +6,7 @@ import type { ModelCreateRequest, ModelRead, ModelRoles, + ModelsBulkUpdateRequest, ModelUpdateRequest, VerifyConnectionResponse, } from "@/contracts/types/model-connections.types"; @@ -127,6 +128,17 @@ export const updateModelMutationAtom = atomWithMutation((get) => { }; }); +export const bulkUpdateModelsMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "bulk-update"], + mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelsBulkUpdateRequest }) => + modelConnectionsApiService.bulkUpdateModels(connectionId, data), + onSuccess: () => invalidateModelConnections(searchSpaceId), + onError: (error: Error) => toast.error(error.message || "Failed to update models"), + }; +}); + export const testModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 0e541548b..9112cfe64 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,14 +1,24 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react"; +import { + Check, + CheckCircle2, + ChevronsUpDown, + Eye, + EyeOff, + RefreshCcw, + Settings, + Trash2, + XCircle, +} from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, + bulkUpdateModelsMutationAtom, createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, - testModelMutationAtom, updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, @@ -20,11 +30,41 @@ import { modelProvidersAtom, modelRolesAtom, } from "@/atoms/model-connections/model-connections-query.atoms"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Checkbox } from "@/components/ui/checkbox"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -32,8 +72,14 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { Separator } from "@/components/ui/separator"; +import type { + ConnectionRead, + ConnectionUpdateRequest, + ModelRead, +} from "@/contracts/types/model-connections.types"; import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. @@ -69,6 +115,67 @@ const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] { key: "image_gen", label: "Image" }, ]; +function UrlSuggestionCombobox({ + value, + onChange, + placeholder, +}: { + value: string; + onChange: (value: string) => void; + placeholder: string; +}) { + const [open, setOpen] = useState(false); + + return ( + + + + + + + + + + Use the custom URL you typed + + + {URL_SUGGESTIONS.map((url) => ( + { + onChange(url); + setOpen(false); + }} + > + + {url} + + ))} + + + + + + ); +} + function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { return ( @@ -105,11 +212,15 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); - const testModel = useAtomValue(testModelMutationAtom); + const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); const allowlist = Array.isArray(connection.extra?.model_ids) ? (connection.extra.model_ids as string[]) : []; + const [isSettingsOpen, setIsSettingsOpen] = useState(false); + const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); + const [apiKeyDraft, setApiKeyDraft] = useState(""); + const [showApiKey, setShowApiKey] = useState(false); const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); const [modelFilter, setModelFilter] = useState(null); @@ -122,6 +233,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const filteredModels = modelFilter ? connection.models.filter((model) => capability(model, modelFilter)) : connection.models; + const allFilteredModelsEnabled = + filteredModels.length > 0 && filteredModels.every((model) => model.enabled); + const hasConnectionChanges = + baseUrlDraft.trim() !== (connection.base_url ?? "") || + apiKeyDraft.trim() !== (connection.api_key ?? ""); + + function handleSettingsOpenChange(open: boolean) { + setIsSettingsOpen(open); + if (open) { + setBaseUrlDraft(connection.base_url ?? ""); + setApiKeyDraft(connection.api_key ?? ""); + setShowApiKey(false); + setAllowlistText(allowlist.join(", ")); + } + } + + function saveConnectionSettings() { + const data: ConnectionUpdateRequest = { + base_url: baseUrlDraft.trim() || null, + }; + + if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { + data.api_key = apiKeyDraft.trim() || null; + } + + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + } function saveAllowlist() { const ids = allowlistText @@ -144,170 +287,321 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { } function deleteCurrentConnection() { - const confirmed = window.confirm( - `Delete the ${providerLabel} connection and all of its models? This cannot be undone.` - ); - if (!confirmed) return; deleteConnection.mutate(connection.id); } + function toggleFilteredModels() { + const nextEnabled = !allFilteredModelsEnabled; + const modelIds = filteredModels + .filter((model) => model.enabled !== nextEnabled) + .map((model) => model.id); + + if (modelIds.length === 0) return; + + bulkUpdateModels.mutate({ + connectionId: connection.id, + data: { model_ids: modelIds, enabled: nextEnabled }, + }); + } + return ( -
-
-
-
+
+
+
+
{getProviderIcon(providerLabel, { className: "size-4" })} - {providerLabel} + {providerLabel} + {connection.scope === "GLOBAL" ? ( + + Default + + ) : null}
-
+
{connection.base_url || "Provider default endpoint"}
-
+
- - - -
-
- - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work — add model IDs - manually below. -

- ) : null} - - {!isLocal ? ( -
- -
- setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - -
-

- Leave empty to discover all models. Recommended for providers with large catalogs (e.g. - OpenRouter). -

-
- ) : null} - -
- setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder="Add a model ID manually (for providers without /models)" - /> - -
- - {connection.models.length > 0 ? ( -
- Filter models - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = connection.models.filter((model) => capability(model, filter.key)).length; - const isActive = modelFilter === filter.key; - - return ( - - ); - })} -
- ) : null} + + + +
+ {getProviderIcon(providerLabel, { className: "size-5" })} +
+ + Configure {providerLabel} + + + Manage credentials and choose which models are available from this provider. + +
+
+
-
- {filteredModels.length === 0 && modelFilter ? ( -
- No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "} - models found on this connection. -
- ) : null} - {filteredModels.map((model) => ( -
-
-
- {getProviderIcon(providerLabel, { className: "size-4" })} - {modelLabel(model)} - {model.source === "MANUAL" ? ( - - manual - - ) : null} +
+
+
+ + +

+ Leave empty to use the provider default endpoint. +

+
+ +
+ +
+ setApiKeyDraft(event.target.value)} + placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
+
+ + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large + catalogs. +

+
+ ) : null} + + + +
+
+
+
Models
+

+ Select models to make available for this provider. +

+
+
+ + +
+
+ +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder="Add a model ID manually" + /> + +
+ + {connection.models.length > 0 ? ( +
+ + Filter models + + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = connection.models.filter((model) => + capability(model, filter.key) + ).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} + +
+ {connection.models.length === 0 ? ( +
+ No models yet. Use the refresh button to discover models or add one + manually. +
+ ) : null} + {filteredModels.length === 0 && modelFilter ? ( +
+ No{" "} + {MODEL_CAPABILITY_FILTERS.find( + (filter) => filter.key === modelFilter + )?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} +
+ {filteredModels.map((model) => ( +
+ + updateModel.mutate({ + id: model.id, + data: { enabled: checked === true }, + }) + } + disabled={updateModel.isPending} + /> +
+
+ {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => + capability(model, key as "chat" | "vision" | "image_gen") + ) + .join(", ") || "No discovered capabilities"} +
+
+
+ ))} +
+
+
+ + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work; add + model IDs manually if discovery is unavailable. +

+ ) : null} +
-
- {["chat", "vision", "image_gen"] - .filter((key) => capability(model, key as "chat" | "vision" | "image_gen")) - .join(", ") || "No discovered capabilities"} -
-
-
- + + + + + + + + + -
-
- ))} + + + + Delete this provider? + + {providerLabel} and all of + its models will be removed from this search space. This cannot be undone. + + + + Cancel + + Delete + + + + +
); @@ -394,19 +688,13 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- setBaseUrl(event.target.value)} + onChange={setBaseUrl} placeholder={ isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1" } - list="model-conn-url-suggestions" /> - - {URL_SUGGESTIONS.map((url) => ( -
@@ -425,7 +713,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num Boolean(selectedProvider?.base_url_required && !baseUrl.trim()) } > - Add + Add
@@ -439,11 +727,17 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num

-
- {connections.map((connection) => ( - - ))} -
+ {connections.length > 0 ? ( +
+ +

Available Providers

+
+ {connections.map((connection) => ( + + ))} +
+
+ ) : null} diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index a34687d74..c75f4c90a 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -26,6 +26,7 @@ export const connectionRead = z.object({ id: z.number(), provider: z.string(), base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), search_space_id: z.number().nullable().optional(), @@ -73,6 +74,11 @@ export const modelUpdateRequest = z.object({ capabilities_override: z.record(z.string(), z.any()).optional(), }); +export const modelsBulkUpdateRequest = z.object({ + model_ids: z.array(z.number()).min(1).max(1000), + enabled: z.boolean(), +}); + export const verifyConnectionResponse = z.object({ status: z.string(), ok: z.boolean(), @@ -107,6 +113,7 @@ export type ConnectionCreateRequest = z.infer; export type ConnectionUpdateRequest = z.infer; export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; +export type ModelsBulkUpdateRequest = z.infer; export type ModelRoles = z.infer; export type VerifyConnectionResponse = z.infer; export type ModelProviderRead = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 12ad8e0d2..bd5aa1309 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -10,12 +10,14 @@ import { type ModelProviderRead, type ModelRead, type ModelRoles, + type ModelsBulkUpdateRequest, type ModelUpdateRequest, modelCreateRequest, - modelProviderListResponse, modelListResponse, + modelProviderListResponse, modelRead, modelRoles, + modelsBulkUpdateRequest, modelUpdateRequest, type VerifyConnectionResponse, verifyConnectionResponse, @@ -97,6 +99,25 @@ class ModelConnectionsApiService { }); }; + bulkUpdateModels = async ( + connectionId: number, + request: ModelsBulkUpdateRequest + ): Promise => { + const parsed = modelsBulkUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.request( + `/api/v1/model-connections/${connectionId}/models`, + modelListResponse, + { + method: "PATCH", + headers: { "Content-Type": "application/json" }, + body: parsed.data, + } + ); + }; + testModel = async (id: number): Promise => { return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); };