From 9f6210ad089788304c1a2bbb1068e505cefe18c3 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:12:04 +0530 Subject: [PATCH] feat(model-connections): add test preview functionality for model connections --- .../app/routes/model_connections_routes.py | 46 +++++++++- surfsense_backend/app/schemas/__init__.py | 20 ++++- .../app/schemas/model_connections.py | 4 + .../app/services/model_connection_service.py | 87 +++++++++++++++++-- .../app/services/model_resolver.py | 2 + .../app/services/provider_registry.py | 27 ++++-- .../model-connections-mutation.atoms.ts | 15 ++++ .../settings/model-connections-settings.tsx | 82 +++++++++-------- .../connection-settings-dialog.tsx | 63 +++++++++----- .../provider-connect-dialog.tsx | 4 +- .../types/model-connections.types.ts | 5 ++ .../lib/apis/model-connections-api.service.ts | 16 ++++ 12 files changed, 294 insertions(+), 77 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 474d376d3..76e4a3dfb 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -26,11 +26,13 @@ from app.schemas import ( ModelRead, ModelRolesRead, ModelRolesUpdate, - ModelSelection, ModelsBulkUpdate, + ModelSelection, + ModelTestPreview, ModelUpdate, VerifyConnectionResponse, ) +from app.services.model_capabilities import has_capability from app.services.model_connection_service import ( ModelDiscoveryError, derive_capabilities, @@ -38,7 +40,6 @@ from app.services.model_connection_service import ( persist_verification, test_model, ) -from app.services.model_capabilities import has_capability from app.services.provider_registry import REGISTRY from app.users import current_active_user from app.utils.rbac import check_permission @@ -321,6 +322,47 @@ async def preview_connection_models( return [_preview_model_read(item) for item in discovered] +@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse) +async def test_preview_connection_model( + data: ModelTestPreview, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + + model_id = data.model_id.strip() + if not model_id: + raise HTTPException(status_code=400, detail="model_id is required") + + draft = Connection( + provider=data.provider, + base_url=data.base_url, + api_key=data.api_key, + extra=data.extra or {}, + scope=data.scope, + enabled=data.enabled, + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + model = Model( + connection_id=0, + model_id=model_id, + source=ModelSource.MANUAL, + enabled=True, + capabilities_override={}, + catalog={}, + ) + result = await test_model(draft, model) + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) async def update_connection( connection_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index efa448dcd..3c4fdfa83 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -54,8 +54,9 @@ from .model_connections import ( ModelRead, ModelRolesRead, ModelRolesUpdate, - ModelSelection, ModelsBulkUpdate, + ModelSelection, + ModelTestPreview, ModelUpdate, VerifyConnectionResponse, ) @@ -149,7 +150,7 @@ from .vision_llm import ( VisionLLMConfigUpdate, ) -__all__ = [ +__all__ = [ # Folder schemas "BulkDocumentMove", # Chat schemas (assistant-ui integration) @@ -159,6 +160,10 @@ __all__ = [ "ChunkCreate", "ChunkRead", "ChunkUpdate", + # Model connection schemas + "ConnectionCreate", + "ConnectionRead", + "ConnectionUpdate", "CreateCreditCheckoutSessionRequest", "CreateCreditCheckoutSessionResponse", "CreditPurchaseHistoryResponse", @@ -232,6 +237,16 @@ __all__ = [ "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", + "ModelCreate", + "ModelPreviewRead", + "ModelProviderRead", + "ModelRead", + "ModelRolesRead", + "ModelRolesUpdate", + "ModelSelection", + "ModelTestPreview", + "ModelUpdate", + "ModelsBulkUpdate", "NewChatMessageAppend", "NewChatMessageCreate", "NewChatMessageRead", @@ -282,6 +297,7 @@ __all__ = [ "UserRead", "UserSearchSpaceAccess", "UserUpdate", + "VerifyConnectionResponse", # Video Presentation schemas "VideoPresentationBase", "VideoPresentationCreate", diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 896532d6f..67d94f821 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -85,6 +85,10 @@ class ConnectionCreate(BaseModel): models: list[ModelSelection] = Field(default_factory=list) +class ModelTestPreview(ConnectionCreate): + model_id: str = Field(..., max_length=255) + + class ConnectionUpdate(BaseModel): provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index c9ee2779f..fbfdd437f 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -15,7 +15,7 @@ import litellm from app.db import Connection, Model, ModelSource from app.services.model_resolver import ensure_v1, to_litellm from app.services.openrouter_model_normalizer import normalize_openrouter_models -from app.services.provider_registry import Transport, spec_for +from app.services.provider_registry import Transport, provider_label, spec_for logger = logging.getLogger(__name__) @@ -77,6 +77,68 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: return raw +def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult: + provider_name = provider_label(conn.provider) + raw = str(exc) + normalized = raw.lower() + exc_name = exc.__class__.__name__.lower() + status_code = getattr(exc, "status_code", None) + + logger.info( + "Model test failed for provider=%s model=%s: %s", + conn.provider, + model_id, + raw, + ) + + if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized: + return VerifyResult( + "AUTH_FAILED", + False, + f"Authentication failed. Check your {provider_name} credentials and try again.", + ) + + if status_code == 404 or "notfound" in exc_name or "not found" in normalized: + if conn.provider == "azure": + message = ( + "Azure OpenAI deployment was not found. Check the deployment name, " + "API version, and endpoint." + ) + else: + message = f"Model '{model_id}' was not found on {provider_name}." + return VerifyResult("NOT_FOUND", False, message) + + if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized: + return VerifyResult( + "RATE_LIMITED", + False, + f"{provider_name} rate limited the model test. Try again later.", + ) + + if "timeout" in exc_name or "timed out" in normalized: + return VerifyResult( + "TIMEOUT", + False, + f"{provider_name} did not respond in time. Check the endpoint and try again.", + ) + + if "connection" in exc_name or "connect" in normalized: + return VerifyResult( + "UNREACHABLE", + False, + _docker_hint( + _base_url_or_default(conn), + f"Could not reach {provider_name}. Check the endpoint and try again.", + ), + ) + + return VerifyResult( + "UNREACHABLE", + False, + f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.", + ) + + async def verify_connection(conn: Connection) -> VerifyResult: spec = spec_for(conn.provider) base_url = _base_url_or_default(conn) @@ -321,15 +383,24 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: return [] def list_models() -> list[dict[str, Any]]: + import os + import boto3 - client_kwargs: dict[str, str] = {"region_name": region_name} - if params.get("aws_access_key_id"): - client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] - if params.get("aws_secret_access_key"): - client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + if bearer_token := params.get("aws_bearer_token_bedrock"): + try: + os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token + client = boto3.client("bedrock", region_name=region_name) + finally: + os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None) + else: + client_kwargs: dict[str, str] = {"region_name": region_name} + if params.get("aws_access_key_id"): + client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] + if params.get("aws_secret_access_key"): + client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + client = boto3.client("bedrock", **client_kwargs) - client = boto3.client("bedrock", **client_kwargs) response = client.list_foundation_models() results: list[dict[str, Any]] = [] for item in response.get("modelSummaries", []): @@ -393,7 +464,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: **kwargs, ) except Exception as exc: - return VerifyResult("UNREACHABLE", False, str(exc)) + return _model_test_error(conn, model.model_id, exc) model.supports_chat = True return VerifyResult("OK", True, "Model test succeeded.") diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ae6fd2877..599762824 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -55,6 +55,8 @@ def to_litellm( kwargs["api_version"] = api_version kwargs.update(extra.get("litellm_params", {})) kwargs.update(extra.get("kwargs", {})) + if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)): + kwargs["api_key"] = bearer_token return model_string, kwargs diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py index 98bfb63c1..2a58a3468 100644 --- a/surfsense_backend/app/services/provider_registry.py +++ b/surfsense_backend/app/services/provider_registry.py @@ -38,21 +38,24 @@ class ProviderSpec: default_base_url: str | None base_url_required: bool auth_style: AuthStyle + display_name: str | None = None REGISTRY: dict[str, ProviderSpec] = { "openai": ProviderSpec( - Transport.NATIVE, "openai", "openai_models", None, False, "bearer" + Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI" ), "anthropic": ProviderSpec( - Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key" + Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic" + ), + "azure": ProviderSpec( + Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI" ), - "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"), "vertex_ai": ProviderSpec( - Transport.NATIVE, "vertex_ai", "static", None, False, "native" + Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI" ), "bedrock": ProviderSpec( - Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native" + Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock" ), "openrouter": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -61,6 +64,7 @@ REGISTRY: dict[str, ProviderSpec] = { "https://openrouter.ai/api/v1", False, "bearer", + "OpenRouter", ), "openai_compatible": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -69,6 +73,7 @@ REGISTRY: dict[str, ProviderSpec] = { None, True, "bearer", + "OpenAI-compatible provider", ), "lm_studio": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -77,6 +82,7 @@ REGISTRY: dict[str, ProviderSpec] = { "http://localhost:1234/v1", True, "bearer", + "LM Studio", ), "ollama_chat": ProviderSpec( Transport.OLLAMA, @@ -85,6 +91,7 @@ REGISTRY: dict[str, ProviderSpec] = { "http://localhost:11434", True, "none", + "Ollama", ), } @@ -96,4 +103,12 @@ def spec_for(provider: str | None) -> ProviderSpec: ) -__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"] +def provider_label(provider: str | None) -> str: + provider_key = (provider or "").strip() + spec = spec_for(provider_key) + if spec.display_name: + return spec.display_name + return provider_key.replace("_", " ").title() if provider_key else "Provider" + + +__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"] diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index ea91c6483..00f8fa9ad 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -7,6 +7,7 @@ import type { ModelPreviewRead, ModelRead, ModelRoles, + ModelTestPreviewRequest, ModelsBulkUpdateRequest, ModelUpdateRequest, VerifyConnectionResponse, @@ -114,6 +115,20 @@ export const previewConnectionModelsMutationAtom = atomWithMutation(() => { }; }); +export const testPreviewModelMutationAtom = atomWithMutation(() => { + return { + mutationKey: ["model-connections", "test-preview"], + mutationFn: (request: ModelTestPreviewRequest) => + modelConnectionsApiService.testPreviewModel(request), + onSuccess: (result: VerifyConnectionResponse) => { + if (!result.ok) { + toast.error(result.message || "Model test failed"); + } + }, + onError: (error: Error) => toast.error(error.message || "Failed to test model"), + }; +}); + export const addManualModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 6c3d1a411..cf00ac6c9 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,12 +1,14 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, Trash2, XCircle } from "lucide-react"; +import { Trash2 } from "lucide-react"; import { useState } from "react"; +import { toast } from "sonner"; import { createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, previewConnectionModelsMutationAtom, + testPreviewModelMutationAtom, updateModelRolesMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -53,24 +55,6 @@ import { providerIcon, } from "./model-connections/provider-metadata"; -function StatusBadge({ connection }: { connection: ConnectionRead }) { - if (connection.last_status === "OK") { - return ( - - Healthy - - ); - } - if (connection.last_status) { - return ( - - {connection.last_status} - - ); - } - return Not tested; -} - function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ @@ -110,7 +94,6 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
- @@ -156,6 +139,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); const previewModels = useAtomValue(previewConnectionModelsMutationAtom); + const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); @@ -220,9 +204,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num }); } - // Each provider connect form builds its own credential payload; the backend - // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. - function handleCreate(draft: ConnectionDraft) { + function connectionModelsForDraft(draft: ConnectionDraft) { const models = [...connectModels]; if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { models.push({ @@ -233,22 +215,46 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num metadata: {}, }); } + return models; + } - createConnection.mutate( + function representativeTestModel(models: ModelSelection[]) { + const enabledModels = models.filter((model) => model.enabled); + return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; + } + + // Each provider connect form builds its own credential payload; the backend + // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. + function handleCreate(draft: ConnectionDraft) { + const models = connectionModelsForDraft(draft); + const testModel = representativeTestModel(models); + if (!testModel) { + toast.error("Select at least one model before connecting"); + return; + } + + const request = { + provider, + base_url: draft.base_url, + api_key: draft.api_key, + scope: "SEARCH_SPACE" as const, + search_space_id: searchSpaceId, + extra: draft.extra, + enabled: true, + models, + }; + + testPreviewModel.mutate( + { ...request, model_id: testModel.model_id }, { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models, - }, - { - onSuccess: () => { - setIsAddProviderOpen(false); - resetConnectState(); + onSuccess: (result) => { + if (!result.ok) return; + createConnection.mutate(request, { + onSuccess: () => { + setIsAddProviderOpen(false); + resetConnectState(); + }, + }); }, } ); @@ -380,7 +386,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num onOpenChange={handleConnectOpenChange} provider={provider} selectedProvider={selectedProvider} - isPending={createConnection.isPending} + isPending={createConnection.isPending || testPreviewModel.isPending} onSubmit={handleCreate} previewModels={connectModels} isPreviewingModels={previewModels.isPending} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index d0f8e6c16..badddb8d7 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -5,9 +5,9 @@ import { addManualModelMutationAtom, bulkUpdateModelsMutationAtom, discoverConnectionModelsMutationAtom, + testPreviewModelMutationAtom, updateModelConnectionMutationAtom, updateModelMutationAtom, - verifyModelConnectionMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { Button } from "@/components/ui/button"; import { @@ -26,7 +26,7 @@ import type { ConnectionRead, ConnectionUpdateRequest, } from "@/contracts/types/model-connections.types"; -import type { SelectableModel } from "./model-utils"; +import { capability, type SelectableModel } from "./model-utils"; import { ModelsSelectionPanel } from "./models-selection-panel"; import { providerIcon } from "./provider-metadata"; @@ -39,8 +39,8 @@ export function ConnectionSettingsDialog({ connection, providerLabel, }: ConnectionSettingsDialogProps) { - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); @@ -81,11 +81,45 @@ export function ConnectionSettingsDialog({ if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { data.api_key = apiKeyDraft.trim() || null; } + const apiKeyForTest = Object.hasOwn(data, "api_key") + ? (data.api_key ?? null) + : (connection.api_key ?? null); - updateConnection.mutate( - { id: connection.id, data }, + const enabledModels = connection.models.filter((model) => model.enabled); + const testModel = + enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; + if (!testModel) { + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + return; + } + + testPreviewModel.mutate( { - onSuccess: () => setApiKeyDraft(""), + provider: connection.provider, + base_url: data.base_url, + api_key: apiKeyForTest, + scope: "SEARCH_SPACE", + search_space_id: connection.search_space_id, + extra: connection.extra ?? {}, + enabled: connection.enabled, + models: [], + model_id: testModel.model_id, + }, + { + onSuccess: (result) => { + if (!result.ok) return; + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + }, } ); } @@ -219,26 +253,15 @@ export function ConnectionSettingsDialog({ onBulkToggle={handleBulkToggle} /> - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work; add model - IDs manually if discovery is unavailable. -

- ) : null}
- diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx index 315b9d3fa..51263d5f5 100644 --- a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx @@ -94,6 +94,8 @@ export function ProviderConnectDialog({ })(); const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); + const hasEnabledModel = previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId); + const canConnect = canSubmit && hasEnabledModel; return ( @@ -134,7 +136,7 @@ export function ProviderConnectDialog({ onOpenChange(false)} onSubmit={() => onSubmit(currentDraft)} - canSubmit={canSubmit} + canSubmit={canConnect} isPending={isPending} /> diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index 134c740b2..16db93868 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -66,6 +66,10 @@ export const connectionCreateRequest = z.object({ models: z.array(modelSelection).default([]), }); +export const modelTestPreviewRequest = connectionCreateRequest.extend({ + model_id: z.string().min(1), +}); + export const connectionUpdateRequest = z.object({ provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), @@ -129,6 +133,7 @@ export type ModelPreviewRead = z.infer; export type ModelSelection = z.infer; export type ConnectionRead = z.infer; export type ConnectionCreateRequest = z.infer; +export type ModelTestPreviewRequest = z.infer; export type ConnectionUpdateRequest = z.infer; export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index f463a27e7..d875255ad 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -11,6 +11,7 @@ import { type ModelProviderRead, type ModelRead, type ModelRoles, + type ModelTestPreviewRequest, type ModelsBulkUpdateRequest, type ModelUpdateRequest, modelCreateRequest, @@ -19,6 +20,7 @@ import { modelProviderListResponse, modelRead, modelRoles, + modelTestPreviewRequest, modelsBulkUpdateRequest, modelUpdateRequest, type VerifyConnectionResponse, @@ -92,6 +94,20 @@ class ModelConnectionsApiService { ); }; + testPreviewModel = async (request: ModelTestPreviewRequest): Promise => { + const parsed = modelTestPreviewRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post( + `/api/v1/model-connections/test-preview`, + verifyConnectionResponse, + { + body: parsed.data, + } + ); + }; + addManualModel = async ( connectionId: number, request: ModelCreateRequest