diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py
index 474d376d3..76e4a3dfb 100644
--- a/surfsense_backend/app/routes/model_connections_routes.py
+++ b/surfsense_backend/app/routes/model_connections_routes.py
@@ -26,11 +26,13 @@ from app.schemas import (
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
- ModelSelection,
ModelsBulkUpdate,
+ ModelSelection,
+ ModelTestPreview,
ModelUpdate,
VerifyConnectionResponse,
)
+from app.services.model_capabilities import has_capability
from app.services.model_connection_service import (
ModelDiscoveryError,
derive_capabilities,
@@ -38,7 +40,6 @@ from app.services.model_connection_service import (
persist_verification,
test_model,
)
-from app.services.model_capabilities import has_capability
from app.services.provider_registry import REGISTRY
from app.users import current_active_user
from app.utils.rbac import check_permission
@@ -321,6 +322,47 @@ async def preview_connection_models(
return [_preview_model_read(item) for item in discovered]
+@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse)
+async def test_preview_connection_model(
+ data: ModelTestPreview,
+ session: AsyncSession = Depends(get_async_session),
+ user: User = Depends(current_active_user),
+):
+ if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
+ await check_permission(
+ session,
+ user,
+ data.search_space_id,
+ Permission.LLM_CONFIGS_CREATE.value,
+ "You don't have permission to create model connections in this search space",
+ )
+
+ model_id = data.model_id.strip()
+ if not model_id:
+ raise HTTPException(status_code=400, detail="model_id is required")
+
+ draft = Connection(
+ provider=data.provider,
+ base_url=data.base_url,
+ api_key=data.api_key,
+ extra=data.extra or {},
+ scope=data.scope,
+ enabled=data.enabled,
+ search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
+ user_id=user.id,
+ )
+ model = Model(
+ connection_id=0,
+ model_id=model_id,
+ source=ModelSource.MANUAL,
+ enabled=True,
+ capabilities_override={},
+ catalog={},
+ )
+ result = await test_model(draft, model)
+ return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
+
+
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
async def update_connection(
connection_id: int,
diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py
index efa448dcd..3c4fdfa83 100644
--- a/surfsense_backend/app/schemas/__init__.py
+++ b/surfsense_backend/app/schemas/__init__.py
@@ -54,8 +54,9 @@ from .model_connections import (
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
- ModelSelection,
ModelsBulkUpdate,
+ ModelSelection,
+ ModelTestPreview,
ModelUpdate,
VerifyConnectionResponse,
)
@@ -149,7 +150,7 @@ from .vision_llm import (
VisionLLMConfigUpdate,
)
-__all__ = [
+__all__ = [
# Folder schemas
"BulkDocumentMove",
# Chat schemas (assistant-ui integration)
@@ -159,6 +160,10 @@ __all__ = [
"ChunkCreate",
"ChunkRead",
"ChunkUpdate",
+ # Model connection schemas
+ "ConnectionCreate",
+ "ConnectionRead",
+ "ConnectionUpdate",
"CreateCreditCheckoutSessionRequest",
"CreateCreditCheckoutSessionResponse",
"CreditPurchaseHistoryResponse",
@@ -232,6 +237,16 @@ __all__ = [
"MembershipRead",
"MembershipReadWithUser",
"MembershipUpdate",
+ "ModelCreate",
+ "ModelPreviewRead",
+ "ModelProviderRead",
+ "ModelRead",
+ "ModelRolesRead",
+ "ModelRolesUpdate",
+ "ModelSelection",
+ "ModelTestPreview",
+ "ModelUpdate",
+ "ModelsBulkUpdate",
"NewChatMessageAppend",
"NewChatMessageCreate",
"NewChatMessageRead",
@@ -282,6 +297,7 @@ __all__ = [
"UserRead",
"UserSearchSpaceAccess",
"UserUpdate",
+ "VerifyConnectionResponse",
# Video Presentation schemas
"VideoPresentationBase",
"VideoPresentationCreate",
diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py
index 896532d6f..67d94f821 100644
--- a/surfsense_backend/app/schemas/model_connections.py
+++ b/surfsense_backend/app/schemas/model_connections.py
@@ -85,6 +85,10 @@ class ConnectionCreate(BaseModel):
models: list[ModelSelection] = Field(default_factory=list)
+class ModelTestPreview(ConnectionCreate):
+ model_id: str = Field(..., max_length=255)
+
+
class ConnectionUpdate(BaseModel):
provider: str | None = Field(None, max_length=100)
base_url: str | None = Field(None, max_length=500)
diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py
index c9ee2779f..fbfdd437f 100644
--- a/surfsense_backend/app/services/model_connection_service.py
+++ b/surfsense_backend/app/services/model_connection_service.py
@@ -15,7 +15,7 @@ import litellm
from app.db import Connection, Model, ModelSource
from app.services.model_resolver import ensure_v1, to_litellm
from app.services.openrouter_model_normalizer import normalize_openrouter_models
-from app.services.provider_registry import Transport, spec_for
+from app.services.provider_registry import Transport, provider_label, spec_for
logger = logging.getLogger(__name__)
@@ -77,6 +77,68 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
return raw
+def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult:
+ provider_name = provider_label(conn.provider)
+ raw = str(exc)
+ normalized = raw.lower()
+ exc_name = exc.__class__.__name__.lower()
+ status_code = getattr(exc, "status_code", None)
+
+ logger.info(
+ "Model test failed for provider=%s model=%s: %s",
+ conn.provider,
+ model_id,
+ raw,
+ )
+
+ if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized:
+ return VerifyResult(
+ "AUTH_FAILED",
+ False,
+ f"Authentication failed. Check your {provider_name} credentials and try again.",
+ )
+
+ if status_code == 404 or "notfound" in exc_name or "not found" in normalized:
+ if conn.provider == "azure":
+ message = (
+ "Azure OpenAI deployment was not found. Check the deployment name, "
+ "API version, and endpoint."
+ )
+ else:
+ message = f"Model '{model_id}' was not found on {provider_name}."
+ return VerifyResult("NOT_FOUND", False, message)
+
+ if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized:
+ return VerifyResult(
+ "RATE_LIMITED",
+ False,
+ f"{provider_name} rate limited the model test. Try again later.",
+ )
+
+ if "timeout" in exc_name or "timed out" in normalized:
+ return VerifyResult(
+ "TIMEOUT",
+ False,
+ f"{provider_name} did not respond in time. Check the endpoint and try again.",
+ )
+
+ if "connection" in exc_name or "connect" in normalized:
+ return VerifyResult(
+ "UNREACHABLE",
+ False,
+ _docker_hint(
+ _base_url_or_default(conn),
+ f"Could not reach {provider_name}. Check the endpoint and try again.",
+ ),
+ )
+
+ return VerifyResult(
+ "UNREACHABLE",
+ False,
+ f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.",
+ )
+
+
async def verify_connection(conn: Connection) -> VerifyResult:
spec = spec_for(conn.provider)
base_url = _base_url_or_default(conn)
@@ -321,15 +383,24 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
return []
def list_models() -> list[dict[str, Any]]:
+ import os
+
import boto3
- client_kwargs: dict[str, str] = {"region_name": region_name}
- if params.get("aws_access_key_id"):
- client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
- if params.get("aws_secret_access_key"):
- client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
+ if bearer_token := params.get("aws_bearer_token_bedrock"):
+ try:
+ os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token
+ client = boto3.client("bedrock", region_name=region_name)
+ finally:
+ os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None)
+ else:
+ client_kwargs: dict[str, str] = {"region_name": region_name}
+ if params.get("aws_access_key_id"):
+ client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
+ if params.get("aws_secret_access_key"):
+ client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
+ client = boto3.client("bedrock", **client_kwargs)
- client = boto3.client("bedrock", **client_kwargs)
response = client.list_foundation_models()
results: list[dict[str, Any]] = []
for item in response.get("modelSummaries", []):
@@ -393,7 +464,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult:
**kwargs,
)
except Exception as exc:
- return VerifyResult("UNREACHABLE", False, str(exc))
+ return _model_test_error(conn, model.model_id, exc)
model.supports_chat = True
return VerifyResult("OK", True, "Model test succeeded.")
diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py
index ae6fd2877..599762824 100644
--- a/surfsense_backend/app/services/model_resolver.py
+++ b/surfsense_backend/app/services/model_resolver.py
@@ -55,6 +55,8 @@ def to_litellm(
kwargs["api_version"] = api_version
kwargs.update(extra.get("litellm_params", {}))
kwargs.update(extra.get("kwargs", {}))
+ if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)):
+ kwargs["api_key"] = bearer_token
return model_string, kwargs
diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py
index 98bfb63c1..2a58a3468 100644
--- a/surfsense_backend/app/services/provider_registry.py
+++ b/surfsense_backend/app/services/provider_registry.py
@@ -38,21 +38,24 @@ class ProviderSpec:
default_base_url: str | None
base_url_required: bool
auth_style: AuthStyle
+ display_name: str | None = None
REGISTRY: dict[str, ProviderSpec] = {
"openai": ProviderSpec(
- Transport.NATIVE, "openai", "openai_models", None, False, "bearer"
+ Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI"
),
"anthropic": ProviderSpec(
- Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key"
+ Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic"
+ ),
+ "azure": ProviderSpec(
+ Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI"
),
- "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"),
"vertex_ai": ProviderSpec(
- Transport.NATIVE, "vertex_ai", "static", None, False, "native"
+ Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI"
),
"bedrock": ProviderSpec(
- Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
+ Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock"
),
"openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
@@ -61,6 +64,7 @@ REGISTRY: dict[str, ProviderSpec] = {
"https://openrouter.ai/api/v1",
False,
"bearer",
+ "OpenRouter",
),
"openai_compatible": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
@@ -69,6 +73,7 @@ REGISTRY: dict[str, ProviderSpec] = {
None,
True,
"bearer",
+ "OpenAI-compatible provider",
),
"lm_studio": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
@@ -77,6 +82,7 @@ REGISTRY: dict[str, ProviderSpec] = {
"http://localhost:1234/v1",
True,
"bearer",
+ "LM Studio",
),
"ollama_chat": ProviderSpec(
Transport.OLLAMA,
@@ -85,6 +91,7 @@ REGISTRY: dict[str, ProviderSpec] = {
"http://localhost:11434",
True,
"none",
+ "Ollama",
),
}
@@ -96,4 +103,12 @@ def spec_for(provider: str | None) -> ProviderSpec:
)
-__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"]
+def provider_label(provider: str | None) -> str:
+ provider_key = (provider or "").strip()
+ spec = spec_for(provider_key)
+ if spec.display_name:
+ return spec.display_name
+ return provider_key.replace("_", " ").title() if provider_key else "Provider"
+
+
+__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"]
diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts
index ea91c6483..00f8fa9ad 100644
--- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts
+++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts
@@ -7,6 +7,7 @@ import type {
ModelPreviewRead,
ModelRead,
ModelRoles,
+ ModelTestPreviewRequest,
ModelsBulkUpdateRequest,
ModelUpdateRequest,
VerifyConnectionResponse,
@@ -114,6 +115,20 @@ export const previewConnectionModelsMutationAtom = atomWithMutation(() => {
};
});
+export const testPreviewModelMutationAtom = atomWithMutation(() => {
+ return {
+ mutationKey: ["model-connections", "test-preview"],
+ mutationFn: (request: ModelTestPreviewRequest) =>
+ modelConnectionsApiService.testPreviewModel(request),
+ onSuccess: (result: VerifyConnectionResponse) => {
+ if (!result.ok) {
+ toast.error(result.message || "Model test failed");
+ }
+ },
+ onError: (error: Error) => toast.error(error.message || "Failed to test model"),
+ };
+});
+
export const addManualModelMutationAtom = atomWithMutation((get) => {
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
return {
diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx
index 6c3d1a411..cf00ac6c9 100644
--- a/surfsense_web/components/settings/model-connections-settings.tsx
+++ b/surfsense_web/components/settings/model-connections-settings.tsx
@@ -1,12 +1,14 @@
"use client";
import { useAtom, useAtomValue } from "jotai";
-import { CheckCircle2, Trash2, XCircle } from "lucide-react";
+import { Trash2 } from "lucide-react";
import { useState } from "react";
+import { toast } from "sonner";
import {
createModelConnectionMutationAtom,
deleteModelConnectionMutationAtom,
previewConnectionModelsMutationAtom,
+ testPreviewModelMutationAtom,
updateModelRolesMutationAtom,
} from "@/atoms/model-connections/model-connections-mutation.atoms";
import {
@@ -53,24 +55,6 @@ import {
providerIcon,
} from "./model-connections/provider-metadata";
-function StatusBadge({ connection }: { connection: ConnectionRead }) {
- if (connection.last_status === "OK") {
- return (
-
- {connection.last_error || "Could not list models."} Chat may still work; add model - IDs manually if discovery is unavailable. -
- ) : null}