From 407f2a9612acbab423a5b490e1caa079158a3ecb Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:41:21 +0530 Subject: [PATCH] feat(model-connections): enhance model connection functionality with preview and selection features --- .../app/routes/model_connections_routes.py | 84 +++++++- surfsense_backend/app/schemas/__init__.py | 2 + .../app/schemas/model_connections.py | 27 +++ .../app/services/model_connection_service.py | 45 ++++ .../app/services/provider_registry.py | 3 +- .../model-connections-mutation.atoms.ts | 11 + .../agent-action-log/action-log-dialog.tsx | 4 +- .../settings/model-connections-settings.tsx | 192 ++++++++++-------- .../model-connections/azure-connect-form.tsx | 65 +++--- .../bedrock-connect-form.tsx | 144 ++++++------- .../model-connections/connect-fields.tsx | 32 ++- .../connection-settings-dialog.tsx | 20 +- .../default-connect-form.tsx | 48 ++--- .../settings/model-connections/model-utils.ts | 13 +- .../models-selection-panel.tsx | 10 +- .../provider-connect-dialog.tsx | 171 ++++++++-------- .../model-connections/provider-metadata.tsx | 4 +- .../model-connections/vertex-connect-form.tsx | 149 +++++++------- .../types/model-connections.types.ts | 19 ++ .../lib/apis/model-connections-api.service.ts | 16 ++ 20 files changed, 630 insertions(+), 429 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 730c68565..2405843a7 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -21,10 +21,12 @@ from app.schemas import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelPreviewRead, ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelSelection, ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, @@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) +def _preview_model_read(item: dict) -> ModelPreviewRead: + return ModelPreviewRead( + model_id=item["model_id"], + display_name=item.get("display_name"), + source=item.get("source", ModelSource.DISCOVERED), + supports_chat=item.get("supports_chat"), + max_input_tokens=item.get("max_input_tokens"), + supports_image_input=item.get("supports_image_input"), + supports_tools=item.get("supports_tools"), + supports_image_generation=item.get("supports_image_generation"), + enabled=item.get("enabled", False), + metadata=item.get("metadata") or item.get("catalog") or {}, + ) + + def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: if isinstance(conn, dict): payload = { @@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None: model.supports_image_generation = facts.get("supports_image_generation") +def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model: + source = ( + selection.source + if isinstance(selection.source, ModelSource) + else ModelSource(selection.source) + ) + model = Model( + connection_id=conn.id, + model_id=selection.model_id.strip(), + display_name=selection.display_name, + source=source, + capabilities_override={}, + enabled=selection.enabled, + catalog=selection.metadata, + ) + _apply_model_facts(model, selection.model_dump()) + return model + + def _default_model_for(models: list[Model], capability: str) -> int | None: for model in models: if model.enabled and has_capability(model, capability): @@ -226,7 +262,7 @@ async def create_connection( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) - payload = data.model_dump(exclude={"search_space_id"}) + payload = data.model_dump(exclude={"search_space_id", "models"}) conn = Connection( **payload, @@ -234,9 +270,51 @@ async def create_connection( user_id=user.id, ) session.add(conn) + await session.flush() + + seen_model_ids: set[str] = set() + for selection in data.models: + model_id = selection.model_id.strip() + if not model_id or model_id in seen_model_ids: + continue + seen_model_ids.add(model_id) + session.add(_selection_to_model(conn, selection)) + await session.commit() - await session.refresh(conn) - return _connection_read(conn, []) + conn = await _load_connection(session, conn.id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() + conn = await _load_connection(session, conn.id) + return _connection_read(conn, list(conn.models)) + + +@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead]) +async def preview_connection_models( + data: ConnectionCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + + draft = Connection( + provider=data.provider, + base_url=data.base_url, + api_key=data.api_key, + extra=data.extra or {}, + scope=data.scope, + enabled=data.enabled, + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + discovered = await discover_models(draft) + return [_preview_model_read(item) for item in discovered] @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 55e712f12..efa448dcd 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -49,10 +49,12 @@ from .model_connections import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelPreviewRead, ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelSelection, ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 0b03c7fab..896532d6f 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -48,6 +48,32 @@ class ConnectionRead(BaseModel): model_config = ConfigDict(from_attributes=True) +class ModelSelection(BaseModel): + model_id: str = Field(..., max_length=255) + display_name: str | None = Field(None, max_length=255) + source: ModelSource | str = ModelSource.DISCOVERED + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None + enabled: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelPreviewRead(BaseModel): + model_id: str + display_name: str | None = None + source: ModelSource | str = ModelSource.DISCOVERED + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None + enabled: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + class ConnectionCreate(BaseModel): provider: str = Field(..., max_length=100) base_url: str | None = Field(None, max_length=500) @@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel): scope: ConnectionScope = ConnectionScope.SEARCH_SPACE search_space_id: int | None = None enabled: bool = True + models: list[ModelSelection] = Field(default_factory=list) class ConnectionUpdate(BaseModel): diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 428af736e..7742e837e 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from typing import Any +import anyio import httpx import litellm @@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: return results +async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: + params = (conn.extra or {}).get("litellm_params", {}) + region_name = params.get("aws_region_name") + if not region_name: + return [] + + def list_models() -> list[dict[str, Any]]: + import boto3 + + client_kwargs: dict[str, str] = {"region_name": region_name} + if params.get("aws_access_key_id"): + client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] + if params.get("aws_secret_access_key"): + client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + + client = boto3.client("bedrock", **client_kwargs) + response = client.list_foundation_models() + results: list[dict[str, Any]] = [] + for item in response.get("modelSummaries", []): + model_id = item.get("modelId") + if not model_id: + continue + input_modalities = set(item.get("inputModalities") or []) + output_modalities = set(item.get("outputModalities") or []) + results.append( + { + "model_id": model_id, + "display_name": item.get("modelName") or model_id, + "source": ModelSource.DISCOVERED, + "supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities, + "supports_image_input": "IMAGE" in input_modalities, + "supports_tools": None, + "supports_image_generation": "IMAGE" in output_modalities, + "max_input_tokens": None, + "metadata": item, + } + ) + return results + + return await anyio.to_thread.run_sync(list_models) + + async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) spec = spec_for(conn.provider) @@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: results = await _discover_anthropic_models(conn) elif spec.discovery == "openai_models": results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "bedrock_models": + results = await _discover_bedrock_models(conn) elif spec.discovery == "static": results = _litellm_static_models(conn) else: diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py index 871769f11..98bfb63c1 100644 --- a/surfsense_backend/app/services/provider_registry.py +++ b/surfsense_backend/app/services/provider_registry.py @@ -21,6 +21,7 @@ DiscoveryKind = Literal[ "ollama", "openai_models", "anthropic_models", + "bedrock_models", "openrouter", "static", "none", @@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = { Transport.NATIVE, "vertex_ai", "static", None, False, "native" ), "bedrock": ProviderSpec( - Transport.NATIVE, "bedrock", "static", None, False, "native" + Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native" ), "openrouter": ProviderSpec( Transport.OPENAI_COMPATIBLE, diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index fee3b95ba..ea91c6483 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -4,6 +4,7 @@ import type { ConnectionCreateRequest, ConnectionUpdateRequest, ModelCreateRequest, + ModelPreviewRead, ModelRead, ModelRoles, ModelsBulkUpdateRequest, @@ -103,6 +104,16 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { }; }); +export const previewConnectionModelsMutationAtom = atomWithMutation(() => { + return { + mutationKey: ["model-connections", "discover-preview"], + mutationFn: (request: ConnectionCreateRequest) => + modelConnectionsApiService.previewModels(request), + onSuccess: (_models: ModelPreviewRead[]) => {}, + onError: (error: Error) => toast.error(error.message || "Failed to discover models"), + }; +}); + export const addManualModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/agent-action-log/action-log-dialog.tsx b/surfsense_web/components/agent-action-log/action-log-dialog.tsx index 1d0eefc17..5f3b83db1 100644 --- a/surfsense_web/components/agent-action-log/action-log-dialog.tsx +++ b/surfsense_web/components/agent-action-log/action-log-dialog.tsx @@ -2,7 +2,7 @@ import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; -import { RefreshCcw, Workflow } from "lucide-react"; +import { RefreshCw, Workflow } from "lucide-react"; import { useCallback } from "react"; import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; @@ -112,7 +112,7 @@ export function ActionLogDialog() { className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground" aria-label="Refresh action log" > - +
diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 2e15ce2e9..6c3d1a411 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -4,12 +4,9 @@ import { useAtom, useAtomValue } from "jotai"; import { CheckCircle2, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { - addManualModelMutationAtom, - bulkUpdateModelsMutationAtom, createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, - discoverConnectionModelsMutationAtom, - updateModelMutationAtom, + previewConnectionModelsMutationAtom, updateModelRolesMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -41,9 +38,13 @@ import { SelectValue, } from "@/components/ui/select"; import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import type { + ConnectionRead, + ModelRead, + ModelSelection, +} from "@/contracts/types/model-connections.types"; import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog"; -import { capability, modelLabel } from "./model-connections/model-utils"; +import { capability, modelLabel, type SelectableModel } from "./model-connections/model-utils"; import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog"; import { type ConnectionDraft, @@ -154,16 +155,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); + const previewModels = useAtomValue(previewConnectionModelsMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); const [provider, setProvider] = useState("openai_compatible"); - const [connectedConnection, setConnectedConnection] = useState(null); - const [connectModels, setConnectModels] = useState([]); + const [connectModels, setConnectModels] = useState([]); const selectedProvider = providers.find((item) => item.provider === provider); const sortedProviders = [...providers].sort((left, right) => { @@ -185,7 +182,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); function resetConnectState() { - setConnectedConnection(null); setConnectModels([]); } @@ -196,15 +192,48 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num } } - function replaceConnectModels(updatedModels: ModelRead[]) { - setConnectModels((current) => - current.map((model) => updatedModels.find((updated) => updated.id === model.id) ?? model) - ); + function toModelSelection(model: SelectableModel): ModelSelection { + return { + model_id: model.model_id, + display_name: model.display_name, + source: model.source || "DISCOVERED", + supports_chat: model.supports_chat, + max_input_tokens: model.max_input_tokens, + supports_image_input: model.supports_image_input, + supports_tools: model.supports_tools, + supports_image_generation: model.supports_image_generation, + enabled: model.enabled, + metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}), + }; + } + + function mergePreviewModels(fetchedModels: SelectableModel[]) { + setConnectModels((current) => { + const currentById = new Map(current.map((model) => [model.model_id, model])); + return fetchedModels.map((model) => { + const prior = currentById.get(model.model_id); + return { + ...toModelSelection(model), + enabled: prior ? prior.enabled : model.enabled, + }; + }); + }); } // Each provider connect form builds its own credential payload; the backend // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. function handleCreate(draft: ConnectionDraft) { + const models = [...connectModels]; + if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { + models.push({ + model_id: draft.seedModelId, + display_name: draft.seedModelId, + source: "MANUAL", + enabled: true, + metadata: {}, + }); + } + createConnection.mutate( { provider, @@ -214,26 +243,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num search_space_id: searchSpaceId, extra: draft.extra, enabled: true, + models, }, { - onSuccess: (created) => { - setConnectedConnection(created); - setConnectModels([]); - if (draft.seedModelId) { - addManualModel.mutate( - { - connectionId: created.id, - data: { model_id: draft.seedModelId }, - }, - { - onSuccess: (model) => setConnectModels([model]), - } - ); - } else { - discoverModels.mutate(created.id, { - onSuccess: (models) => setConnectModels(models), - }); - } + onSuccess: () => { + setIsAddProviderOpen(false); + resetConnectState(); }, } ); @@ -243,52 +258,72 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num resetConnectState(); setProvider(providerId); setIsAddProviderOpen(true); + if (providerId === "vertex_ai") { + previewModels.mutate( + { + provider: providerId, + base_url: null, + api_key: null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + models: [], + }, + { + onSuccess: mergePreviewModels, + } + ); + } } - function refreshConnectModels() { - if (!connectedConnection) return; - discoverModels.mutate(connectedConnection.id, { - onSuccess: (models) => setConnectModels(models), - }); + function refreshConnectModels(draft: ConnectionDraft) { + previewModels.mutate( + { + provider, + base_url: draft.base_url, + api_key: draft.api_key, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: draft.extra, + enabled: true, + models: [], + }, + { + onSuccess: mergePreviewModels, + } + ); } function addConnectModel(modelId: string) { - if (!connectedConnection) return; - addManualModel.mutate( - { connectionId: connectedConnection.id, data: { model_id: modelId } }, - { - onSuccess: (model) => setConnectModels((current) => [...current, model]), - } + setConnectModels((current) => { + if (current.some((model) => model.model_id === modelId)) return current; + return [ + ...current, + { + model_id: modelId, + display_name: modelId, + source: "MANUAL", + enabled: true, + metadata: {}, + }, + ]; + }); + } + + function toggleConnectModel(model: SelectableModel, enabled: boolean) { + setConnectModels((current) => + current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item)) ); } - function toggleConnectModel(model: ModelRead, enabled: boolean) { - updateModel.mutate( - { id: model.id, data: { enabled } }, - { - onSuccess: (updated) => replaceConnectModels([updated]), - } + function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) { + const modelIds = new Set(models.map((model) => model.model_id)); + setConnectModels((current) => + current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item)) ); } - function bulkToggleConnectModels(models: ModelRead[], enabled: boolean) { - if (!connectedConnection) return; - bulkUpdateModels.mutate( - { - connectionId: connectedConnection.id, - data: { model_ids: models.map((model) => model.id), enabled }, - }, - { - onSuccess: replaceConnectModels, - } - ); - } - - function finishConnectFlow() { - setIsAddProviderOpen(false); - resetConnectState(); - } - function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { return ( @@ -347,17 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num selectedProvider={selectedProvider} isPending={createConnection.isPending} onSubmit={handleCreate} - connectedConnection={connectedConnection} - connectModels={connectModels} - isDiscoveringModels={discoverModels.isPending} - isAddingManualModel={addManualModel.isPending} - isUpdatingModel={updateModel.isPending} - isBulkUpdatingModels={bulkUpdateModels.isPending} - onRefreshModels={refreshConnectModels} - onAddManualModel={addConnectModel} - onToggleModel={toggleConnectModel} - onBulkToggleModels={bulkToggleConnectModels} - onDone={finishConnectFlow} + previewModels={connectModels} + isPreviewingModels={previewModels.isPending} + onPreviewModels={refreshConnectModels} + onAddPreviewModel={addConnectModel} + onTogglePreviewModel={toggleConnectModel} + onBulkTogglePreviewModels={bulkToggleConnectModels} /> {connections.length > 0 ? ( diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx index 11a2e25d3..451f053db 100644 --- a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx @@ -1,7 +1,7 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { ApiKeyField } from "./connect-fields"; import { isValidAzureTargetUri, type ProviderConnectFormProps, @@ -12,48 +12,43 @@ import { * Azure OpenAI connect form. The user pastes a single Target URI, which we parse * into api base, api version, and the deployment name (seeded as the model). */ -export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [targetUri, setTargetUri] = useState(""); const [apiKey, setApiKey] = useState(""); const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); - function handleSubmit() { + useEffect(() => { const parsed = parseAzureTargetUri(targetUri); - onSubmit({ - base_url: parsed?.origin ?? null, - api_key: apiKey || null, - extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, - seedModelId: parsed?.deploymentName || undefined, - }); - } + onDraftChange( + { + base_url: parsed?.origin ?? null, + api_key: apiKey || null, + extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, + seedModelId: parsed?.deploymentName || undefined, + }, + canSubmit + ); + }, [apiKey, canSubmit, onDraftChange, targetUri]); return ( - <> -
-
- - setTargetUri(event.target.value)} - placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" - /> -

- Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, - and API version). -

-
- +
+ + setTargetUri(event.target.value)} + placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" /> +

+ Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and + API version). +

- - +
); } diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx index 9466c7cd1..3115ac223 100644 --- a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { @@ -8,7 +8,7 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { ConnectFormFooter } from "./connect-fields"; +import { ApiKeyField } from "./connect-fields"; import { AWS_REGION_OPTIONS, BEDROCK_AUTH_ACCESS_KEY, @@ -21,7 +21,7 @@ import { * Amazon Bedrock connect form. Region + auth method drive which AWS credentials * are collected; everything rides along in `extra.litellm_params`. */ -export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [region, setRegion] = useState(""); const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); const [accessKeyId, setAccessKeyId] = useState(""); @@ -39,7 +39,7 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo return true; })(); - function handleSubmit() { + useEffect(() => { const params: Record = { aws_region_name: region }; if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { params.aws_access_key_id = accessKeyId; @@ -47,88 +47,74 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { params.aws_bearer_token_bedrock = bearerToken; } - onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); - } + onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); + }, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]); return ( - <> -
-
- - -
-
- - -
- {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( - <> -
- - setAccessKeyId(event.target.value)} - placeholder="AKIAIOSFODNN7EXAMPLE" - /> -
-
- - setSecretAccessKey(event.target.value)} - placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" - type="password" - /> -
- - ) : null} - {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( +
+
+ + +
+
+ + +
+ {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( + <>
- + setBearerToken(event.target.value)} - placeholder="Your long-term API key" - type="password" + value={accessKeyId} + onChange={(event) => setAccessKeyId(event.target.value)} + placeholder="AKIAIOSFODNN7EXAMPLE" />
- ) : null} - {authMethod === BEDROCK_AUTH_IAM ? ( -

- SurfSense will use the IAM role attached to the environment it's running in to - authenticate. -

- ) : null} + + + ) : null} + {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( + + ) : null} + {authMethod === BEDROCK_AUTH_IAM ? (

- Add Bedrock model IDs from the provider's settings after connecting. + SurfSense will use the IAM role attached to the environment it's running in to + authenticate.

-
- - + ) : null} +

+ Add Bedrock model IDs from the provider's settings after connecting. +

+
); } diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx index af8db7f12..44b2d434f 100644 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -1,3 +1,5 @@ +import { Eye, EyeOff } from "lucide-react"; +import { useState } from "react"; import { Button } from "@/components/ui/button"; import { DialogFooter } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; @@ -43,15 +45,31 @@ export function ApiKeyField({ label = "API Key", placeholder = "API key", }: ApiKeyFieldProps) { + const [showApiKey, setShowApiKey] = useState(false); + return (
- onChange(event.target.value)} - placeholder={placeholder} - type="password" - /> +
+ onChange(event.target.value)} + placeholder={placeholder} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
); } @@ -71,7 +89,7 @@ export function ConnectFormFooter({ isPending, }: ConnectFormFooterProps) { return ( - + diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index f3821af46..d0f8e6c16 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -25,8 +25,8 @@ import { Separator } from "@/components/ui/separator"; import type { ConnectionRead, ConnectionUpdateRequest, - ModelRead, } from "@/contracts/types/model-connections.types"; +import type { SelectableModel } from "./model-utils"; import { ModelsSelectionPanel } from "./models-selection-panel"; import { providerIcon } from "./provider-metadata"; @@ -101,17 +101,22 @@ export function ConnectionSettingsDialog({ }); } - function handleToggleModel(model: ModelRead, enabled: boolean) { + function handleToggleModel(model: SelectableModel, enabled: boolean) { + if (typeof model.id !== "number") return; updateModel.mutate({ id: model.id, data: { enabled }, }); } - function handleBulkToggle(models: ModelRead[], enabled: boolean) { + function handleBulkToggle(models: SelectableModel[], enabled: boolean) { + const modelIds = models + .map((model) => model.id) + .filter((id): id is number => typeof id === "number"); + if (modelIds.length === 0) return; bulkUpdateModels.mutate({ connectionId: connection.id, - data: { model_ids: models.map((model) => model.id), enabled }, + data: { model_ids: modelIds, enabled }, }); } @@ -184,12 +189,7 @@ export function ConnectionSettingsDialog({ onChange={(event) => setAllowlistText(event.target.value)} placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" /> -
diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx index 3f261c6b2..768c0b5da 100644 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -1,5 +1,5 @@ -import { useState } from "react"; -import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { useEffect, useState } from "react"; +import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; import type { ProviderConnectFormProps } from "./provider-metadata"; /** @@ -11,41 +11,31 @@ export function DefaultConnectForm({ provider, defaultBaseUrl, baseUrlRequired, - isPending, - onCancel, - onSubmit, + onDraftChange, }: ProviderConnectFormProps) { const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); const [apiKey, setApiKey] = useState(""); const isOllama = provider === "ollama_chat"; const canSubmit = !(baseUrlRequired && !baseUrl.trim()); - function handleSubmit() { - onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }); - } + useEffect(() => { + onDraftChange({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }, canSubmit); + }, [apiKey, baseUrl, canSubmit, onDraftChange]); return ( - <> -
- - -
- + - + + ); } diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts index 1db14b3eb..2887f2179 100644 --- a/surfsense_web/components/settings/model-connections/model-utils.ts +++ b/surfsense_web/components/settings/model-connections/model-utils.ts @@ -1,4 +1,4 @@ -import type { ModelRead } from "@/contracts/types/model-connections.types"; +import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types"; export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; @@ -8,17 +8,22 @@ export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: stri { key: "image_gen", label: "Image" }, ]; -export function modelLabel(model: ModelRead) { +export type SelectableModel = (ModelRead | ModelPreviewRead) & { + id?: number | string; + connection_id?: number; +}; + +export function modelLabel(model: SelectableModel) { return model.display_name || model.model_id; } -export function capability(model: ModelRead, key: ModelCapabilityFilter) { +export function capability(model: SelectableModel, key: ModelCapabilityFilter) { if (key === "chat") return Boolean(model.supports_chat); if (key === "vision") return Boolean(model.supports_image_input); return Boolean(model.supports_image_generation); } -export function capabilityLabels(model: ModelRead) { +export function capabilityLabels(model: SelectableModel) { return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) .map((filter) => filter.label.toLowerCase()) .join(", "); diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx index 20fbe862f..01ff0d1e7 100644 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -4,17 +4,17 @@ import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { Input } from "@/components/ui/input"; -import type { ModelRead } from "@/contracts/types/model-connections.types"; import { capability, capabilityLabels, MODEL_CAPABILITY_FILTERS, type ModelCapabilityFilter, modelLabel, + type SelectableModel, } from "./model-utils"; interface ModelsSelectionPanelProps { - models: ModelRead[]; + models: SelectableModel[]; description?: string; emptyMessage?: string; manualInputPlaceholder?: string; @@ -25,8 +25,8 @@ interface ModelsSelectionPanelProps { isBulkUpdating?: boolean; onRefresh?: () => void; onAddManual?: (modelId: string) => void; - onToggleModel?: (model: ModelRead, enabled: boolean) => void; - onBulkToggle?: (models: ModelRead[], enabled: boolean) => void; + onToggleModel?: (model: SelectableModel, enabled: boolean) => void; + onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void; } export function ModelsSelectionPanel({ @@ -166,7 +166,7 @@ export function ModelsSelectionPanel({
{filteredModels.map((model) => (
void; - connectedConnection?: ConnectionRead | null; - connectModels?: ModelRead[]; - isDiscoveringModels?: boolean; - isAddingManualModel?: boolean; - isUpdatingModel?: boolean; - isBulkUpdatingModels?: boolean; - onRefreshModels?: () => void; - onAddManualModel?: (modelId: string) => void; - onToggleModel?: (model: ModelRead, enabled: boolean) => void; - onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void; - onDone?: () => void; + previewModels?: SelectableModel[]; + isPreviewingModels?: boolean; + onPreviewModels?: (draft: ConnectionDraft) => void; + onAddPreviewModel?: (modelId: string) => void; + onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void; + onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void; } /** @@ -57,97 +50,93 @@ export function ProviderConnectDialog({ selectedProvider, isPending, onSubmit, - connectedConnection, - connectModels = [], - isDiscoveringModels = false, - isAddingManualModel = false, - isUpdatingModel = false, - isBulkUpdatingModels = false, - onRefreshModels, - onAddManualModel, - onToggleModel, - onBulkToggleModels, - onDone, + previewModels = [], + isPreviewingModels = false, + onPreviewModels, + onAddPreviewModel, + onTogglePreviewModel, + onBulkTogglePreviewModels, }: ProviderConnectDialogProps) { const meta = providerDisplay(provider); - const isModelSelectionStep = Boolean(connectedConnection); + const isAzure = provider === "azure"; + const isBedrock = provider === "bedrock"; + const isVertex = provider === "vertex_ai"; + const [currentDraft, setCurrentDraft] = useState({ + base_url: null, + api_key: null, + extra: {}, + }); + const [canSubmit, setCanSubmit] = useState(false); + + const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => { + setCurrentDraft(draft); + setCanSubmit(nextCanSubmit); + }, []); const formProps: ProviderConnectFormProps = { provider, defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), baseUrlRequired: Boolean(selectedProvider?.base_url_required), - isPending, - onCancel: () => onOpenChange(false), - onSubmit, + onDraftChange: handleDraftChange, }; + const modelDescription = (() => { + if (isAzure) { + return "Select the models to enable for Azure OpenAI"; + } + if (isBedrock) { + return "Select the models to enable for Amazon Bedrock"; + } + if (isVertex) { + return "Select the models to enable for Gemini"; + } + return "Select the models to enable for this provider"; + })(); + + const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); + return ( - +
{providerIcon(provider, "size-5")}
- - {isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`} - - - {isModelSelectionStep - ? selectedProvider?.discovery === "static" - ? "Choose from known model IDs or add one manually." - : "Choose which discovered models should be available in this search space." - : meta.subtitle} - + Connect {meta.name} + {meta.subtitle}
- {isModelSelectionStep ? ( - <> -
- -
- - - - - ) : ( -
- {provider === "azure" ? ( - - ) : provider === "bedrock" ? ( - - ) : provider === "vertex_ai" ? ( - - ) : ( - - )} -
- )} +
+ {provider === "azure" ? ( + + ) : provider === "bedrock" ? ( + + ) : provider === "vertex_ai" ? ( + + ) : ( + + )} + + + + onPreviewModels?.(currentDraft) : undefined} + onAddManual={onAddPreviewModel} + onToggleModel={onTogglePreviewModel} + onBulkToggle={onBulkTogglePreviewModels} + /> +
+ onOpenChange(false)} + onSubmit={() => onSubmit(currentDraft)} + canSubmit={canSubmit} + isPending={isPending} + />
); diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx index 0ca8ae419..73e873393 100644 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -133,7 +133,5 @@ export interface ProviderConnectFormProps { provider: string; defaultBaseUrl: string; baseUrlRequired: boolean; - isPending: boolean; - onCancel: () => void; - onSubmit: (draft: ConnectionDraft) => void; + onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void; } diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx index 096d3df2e..1027742bc 100644 --- a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { @@ -8,7 +8,6 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { ConnectFormFooter } from "./connect-fields"; import { type ProviderConnectFormProps, VERTEX_AUTH_SERVICE_ACCOUNT, @@ -21,7 +20,7 @@ import { * credentials JSON file (read into a string); workload identity collects a * project id. Credentials ride along in `extra.litellm_params`. */ -export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); const [credentials, setCredentials] = useState(""); @@ -35,7 +34,7 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon setCredentials(await file.text()); } - function handleSubmit() { + useEffect(() => { const params: Record = {}; if (location) params.vertex_location = location; if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { @@ -43,85 +42,77 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon } else if (project) { params.vertex_project = project; } - onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); - } + onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); + }, [authMethod, canSubmit, credentials, location, onDraftChange, project]); return ( - <> -
-
- - -
-
- - setLocation(event.target.value)} - placeholder={VERTEX_DEFAULT_LOCATION} - /> -

- Region where your Google Vertex AI models are hosted. -

-
- {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( -
- - handleCredentialsFile(event.target.files?.[0])} - /> - -

- {credentials - ? "Credentials file loaded." - : "Attach your service account key JSON from Google Cloud."} -

-
- ) : ( -
- - setProject(event.target.value)} - placeholder="my-vertex-project" - /> -

- The GCP project where Vertex AI is enabled. -

-
- )} +
+
+ + +
+
+ + setLocation(event.target.value)} + placeholder={VERTEX_DEFAULT_LOCATION} + />

- Add Vertex AI model IDs from the provider's settings after connecting. + Region where your Google Vertex AI models are hosted.

- - + {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( +
+ + handleCredentialsFile(event.target.files?.[0])} + /> + +

+ {credentials + ? "Credentials file loaded." + : "Attach your service account key JSON from Google Cloud."} +

+
+ ) : ( +
+ + setProject(event.target.value)} + placeholder="my-vertex-project" + /> +

+ The GCP project where Vertex AI is enabled. +

+
+ )} +

+ Add Vertex AI model IDs from the provider's settings after connecting. +

+
); } diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index c75f4c90a..134c740b2 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -40,6 +40,21 @@ export const connectionRead = z.object({ created_at: z.string().nullable().optional(), }); +export const modelSelection = z.object({ + model_id: z.string().min(1), + display_name: z.string().nullable().optional(), + source: z.union([modelSourceEnum, z.string()]).default("DISCOVERED"), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), + enabled: z.boolean().default(false), + metadata: z.record(z.string(), z.any()).default({}), +}); + +export const modelPreviewRead = modelSelection; + export const connectionCreateRequest = z.object({ provider: z.string().min(1), base_url: z.string().nullable().optional(), @@ -48,6 +63,7 @@ export const connectionCreateRequest = z.object({ scope: connectionScopeEnum.default("SEARCH_SPACE"), search_space_id: z.number().nullable().optional(), enabled: z.boolean().default(true), + models: z.array(modelSelection).default([]), }); export const connectionUpdateRequest = z.object({ @@ -105,9 +121,12 @@ export const modelProviderListResponse = z.array(modelProviderRead); export const connectionListResponse = z.array(connectionRead); export const modelListResponse = z.array(modelRead); +export const modelPreviewListResponse = z.array(modelPreviewRead); export type ConnectionScope = z.infer; export type ModelRead = z.infer; +export type ModelPreviewRead = z.infer; +export type ModelSelection = z.infer; export type ConnectionRead = z.infer; export type ConnectionCreateRequest = z.infer; export type ConnectionUpdateRequest = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index bd5aa1309..f463a27e7 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -7,6 +7,7 @@ import { connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelPreviewRead, type ModelProviderRead, type ModelRead, type ModelRoles, @@ -14,6 +15,7 @@ import { type ModelUpdateRequest, modelCreateRequest, modelListResponse, + modelPreviewListResponse, modelProviderListResponse, modelRead, modelRoles, @@ -76,6 +78,20 @@ class ModelConnectionsApiService { return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); }; + previewModels = async (request: ConnectionCreateRequest): Promise => { + const parsed = connectionCreateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post( + `/api/v1/model-connections/discover-preview`, + modelPreviewListResponse, + { + body: parsed.data, + } + ); + }; + addManualModel = async ( connectionId: number, request: ModelCreateRequest