feat(model-connections): add test preview functionality for model connections

This commit is contained in:
Anish Sarkar 2026-06-13 00:12:04 +05:30
parent 55f004e1da
commit 9f6210ad08
12 changed files with 294 additions and 77 deletions

View file

@ -26,11 +26,13 @@ from app.schemas import (
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate,
ModelSelection,
ModelTestPreview,
ModelUpdate,
VerifyConnectionResponse,
)
from app.services.model_capabilities import has_capability
from app.services.model_connection_service import (
ModelDiscoveryError,
derive_capabilities,
@ -38,7 +40,6 @@ from app.services.model_connection_service import (
persist_verification,
test_model,
)
from app.services.model_capabilities import has_capability
from app.services.provider_registry import REGISTRY
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -321,6 +322,47 @@ async def preview_connection_models(
return [_preview_model_read(item) for item in discovered]
@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse)
async def test_preview_connection_model(
data: ModelTestPreview,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
model_id = data.model_id.strip()
if not model_id:
raise HTTPException(status_code=400, detail="model_id is required")
draft = Connection(
provider=data.provider,
base_url=data.base_url,
api_key=data.api_key,
extra=data.extra or {},
scope=data.scope,
enabled=data.enabled,
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
user_id=user.id,
)
model = Model(
connection_id=0,
model_id=model_id,
source=ModelSource.MANUAL,
enabled=True,
capabilities_override={},
catalog={},
)
result = await test_model(draft, model)
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
async def update_connection(
connection_id: int,

View file

@ -54,8 +54,9 @@ from .model_connections import (
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate,
ModelSelection,
ModelTestPreview,
ModelUpdate,
VerifyConnectionResponse,
)
@ -149,7 +150,7 @@ from .vision_llm import (
VisionLLMConfigUpdate,
)
__all__ = [
__all__ = [
# Folder schemas
"BulkDocumentMove",
# Chat schemas (assistant-ui integration)
@ -159,6 +160,10 @@ __all__ = [
"ChunkCreate",
"ChunkRead",
"ChunkUpdate",
# Model connection schemas
"ConnectionCreate",
"ConnectionRead",
"ConnectionUpdate",
"CreateCreditCheckoutSessionRequest",
"CreateCreditCheckoutSessionResponse",
"CreditPurchaseHistoryResponse",
@ -232,6 +237,16 @@ __all__ = [
"MembershipRead",
"MembershipReadWithUser",
"MembershipUpdate",
"ModelCreate",
"ModelPreviewRead",
"ModelProviderRead",
"ModelRead",
"ModelRolesRead",
"ModelRolesUpdate",
"ModelSelection",
"ModelTestPreview",
"ModelUpdate",
"ModelsBulkUpdate",
"NewChatMessageAppend",
"NewChatMessageCreate",
"NewChatMessageRead",
@ -282,6 +297,7 @@ __all__ = [
"UserRead",
"UserSearchSpaceAccess",
"UserUpdate",
"VerifyConnectionResponse",
# Video Presentation schemas
"VideoPresentationBase",
"VideoPresentationCreate",

View file

@ -85,6 +85,10 @@ class ConnectionCreate(BaseModel):
models: list[ModelSelection] = Field(default_factory=list)
class ModelTestPreview(ConnectionCreate):
model_id: str = Field(..., max_length=255)
class ConnectionUpdate(BaseModel):
provider: str | None = Field(None, max_length=100)
base_url: str | None = Field(None, max_length=500)

View file

@ -15,7 +15,7 @@ import litellm
from app.db import Connection, Model, ModelSource
from app.services.model_resolver import ensure_v1, to_litellm
from app.services.openrouter_model_normalizer import normalize_openrouter_models
from app.services.provider_registry import Transport, spec_for
from app.services.provider_registry import Transport, provider_label, spec_for
logger = logging.getLogger(__name__)
@ -77,6 +77,68 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
return raw
def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult:
provider_name = provider_label(conn.provider)
raw = str(exc)
normalized = raw.lower()
exc_name = exc.__class__.__name__.lower()
status_code = getattr(exc, "status_code", None)
logger.info(
"Model test failed for provider=%s model=%s: %s",
conn.provider,
model_id,
raw,
)
if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized:
return VerifyResult(
"AUTH_FAILED",
False,
f"Authentication failed. Check your {provider_name} credentials and try again.",
)
if status_code == 404 or "notfound" in exc_name or "not found" in normalized:
if conn.provider == "azure":
message = (
"Azure OpenAI deployment was not found. Check the deployment name, "
"API version, and endpoint."
)
else:
message = f"Model '{model_id}' was not found on {provider_name}."
return VerifyResult("NOT_FOUND", False, message)
if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized:
return VerifyResult(
"RATE_LIMITED",
False,
f"{provider_name} rate limited the model test. Try again later.",
)
if "timeout" in exc_name or "timed out" in normalized:
return VerifyResult(
"TIMEOUT",
False,
f"{provider_name} did not respond in time. Check the endpoint and try again.",
)
if "connection" in exc_name or "connect" in normalized:
return VerifyResult(
"UNREACHABLE",
False,
_docker_hint(
_base_url_or_default(conn),
f"Could not reach {provider_name}. Check the endpoint and try again.",
),
)
return VerifyResult(
"UNREACHABLE",
False,
f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.",
)
async def verify_connection(conn: Connection) -> VerifyResult:
spec = spec_for(conn.provider)
base_url = _base_url_or_default(conn)
@ -321,15 +383,24 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
return []
def list_models() -> list[dict[str, Any]]:
import os
import boto3
client_kwargs: dict[str, str] = {"region_name": region_name}
if params.get("aws_access_key_id"):
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
if params.get("aws_secret_access_key"):
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
if bearer_token := params.get("aws_bearer_token_bedrock"):
try:
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token
client = boto3.client("bedrock", region_name=region_name)
finally:
os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None)
else:
client_kwargs: dict[str, str] = {"region_name": region_name}
if params.get("aws_access_key_id"):
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
if params.get("aws_secret_access_key"):
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
client = boto3.client("bedrock", **client_kwargs)
client = boto3.client("bedrock", **client_kwargs)
response = client.list_foundation_models()
results: list[dict[str, Any]] = []
for item in response.get("modelSummaries", []):
@ -393,7 +464,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult:
**kwargs,
)
except Exception as exc:
return VerifyResult("UNREACHABLE", False, str(exc))
return _model_test_error(conn, model.model_id, exc)
model.supports_chat = True
return VerifyResult("OK", True, "Model test succeeded.")

View file

@ -55,6 +55,8 @@ def to_litellm(
kwargs["api_version"] = api_version
kwargs.update(extra.get("litellm_params", {}))
kwargs.update(extra.get("kwargs", {}))
if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)):
kwargs["api_key"] = bearer_token
return model_string, kwargs

View file

@ -38,21 +38,24 @@ class ProviderSpec:
default_base_url: str | None
base_url_required: bool
auth_style: AuthStyle
display_name: str | None = None
REGISTRY: dict[str, ProviderSpec] = {
"openai": ProviderSpec(
Transport.NATIVE, "openai", "openai_models", None, False, "bearer"
Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI"
),
"anthropic": ProviderSpec(
Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key"
Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic"
),
"azure": ProviderSpec(
Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI"
),
"azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"),
"vertex_ai": ProviderSpec(
Transport.NATIVE, "vertex_ai", "static", None, False, "native"
Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI"
),
"bedrock": ProviderSpec(
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock"
),
"openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
@ -61,6 +64,7 @@ REGISTRY: dict[str, ProviderSpec] = {
"https://openrouter.ai/api/v1",
False,
"bearer",
"OpenRouter",
),
"openai_compatible": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
@ -69,6 +73,7 @@ REGISTRY: dict[str, ProviderSpec] = {
None,
True,
"bearer",
"OpenAI-compatible provider",
),
"lm_studio": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
@ -77,6 +82,7 @@ REGISTRY: dict[str, ProviderSpec] = {
"http://localhost:1234/v1",
True,
"bearer",
"LM Studio",
),
"ollama_chat": ProviderSpec(
Transport.OLLAMA,
@ -85,6 +91,7 @@ REGISTRY: dict[str, ProviderSpec] = {
"http://localhost:11434",
True,
"none",
"Ollama",
),
}
@ -96,4 +103,12 @@ def spec_for(provider: str | None) -> ProviderSpec:
)
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"]
def provider_label(provider: str | None) -> str:
provider_key = (provider or "").strip()
spec = spec_for(provider_key)
if spec.display_name:
return spec.display_name
return provider_key.replace("_", " ").title() if provider_key else "Provider"
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"]

View file

@ -7,6 +7,7 @@ import type {
ModelPreviewRead,
ModelRead,
ModelRoles,
ModelTestPreviewRequest,
ModelsBulkUpdateRequest,
ModelUpdateRequest,
VerifyConnectionResponse,
@ -114,6 +115,20 @@ export const previewConnectionModelsMutationAtom = atomWithMutation(() => {
};
});
export const testPreviewModelMutationAtom = atomWithMutation(() => {
return {
mutationKey: ["model-connections", "test-preview"],
mutationFn: (request: ModelTestPreviewRequest) =>
modelConnectionsApiService.testPreviewModel(request),
onSuccess: (result: VerifyConnectionResponse) => {
if (!result.ok) {
toast.error(result.message || "Model test failed");
}
},
onError: (error: Error) => toast.error(error.message || "Failed to test model"),
};
});
export const addManualModelMutationAtom = atomWithMutation((get) => {
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
return {

View file

@ -1,12 +1,14 @@
"use client";
import { useAtom, useAtomValue } from "jotai";
import { CheckCircle2, Trash2, XCircle } from "lucide-react";
import { Trash2 } from "lucide-react";
import { useState } from "react";
import { toast } from "sonner";
import {
createModelConnectionMutationAtom,
deleteModelConnectionMutationAtom,
previewConnectionModelsMutationAtom,
testPreviewModelMutationAtom,
updateModelRolesMutationAtom,
} from "@/atoms/model-connections/model-connections-mutation.atoms";
import {
@ -53,24 +55,6 @@ import {
providerIcon,
} from "./model-connections/provider-metadata";
function StatusBadge({ connection }: { connection: ConnectionRead }) {
if (connection.last_status === "OK") {
return (
<Badge variant="outline" className="gap-1 text-green-600">
<CheckCircle2 className="h-3 w-3" /> Healthy
</Badge>
);
}
if (connection.last_status) {
return (
<Badge variant="outline" className="gap-1 text-destructive">
<XCircle className="h-3 w-3" /> {connection.last_status}
</Badge>
);
}
return <Badge variant="secondary">Not tested</Badge>;
}
function flattenModels(connections: ConnectionRead[]) {
return connections.flatMap((connection) =>
connection.models.map((model) => ({
@ -110,7 +94,6 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
</div>
</div>
<div className="flex shrink-0 items-center gap-2">
<StatusBadge connection={connection} />
<ConnectionSettingsDialog connection={connection} providerLabel={providerLabel} />
<AlertDialog>
<AlertDialogTrigger asChild>
@ -156,6 +139,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
const [{ data: roles }] = useAtom(modelRolesAtom);
const createConnection = useAtomValue(createModelConnectionMutationAtom);
const previewModels = useAtomValue(previewConnectionModelsMutationAtom);
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
const [isAddProviderOpen, setIsAddProviderOpen] = useState(false);
@ -220,9 +204,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
});
}
// Each provider connect form builds its own credential payload; the backend
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
function handleCreate(draft: ConnectionDraft) {
function connectionModelsForDraft(draft: ConnectionDraft) {
const models = [...connectModels];
if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) {
models.push({
@ -233,22 +215,46 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
metadata: {},
});
}
return models;
}
createConnection.mutate(
function representativeTestModel(models: ModelSelection[]) {
const enabledModels = models.filter((model) => model.enabled);
return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
}
// Each provider connect form builds its own credential payload; the backend
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
function handleCreate(draft: ConnectionDraft) {
const models = connectionModelsForDraft(draft);
const testModel = representativeTestModel(models);
if (!testModel) {
toast.error("Select at least one model before connecting");
return;
}
const request = {
provider,
base_url: draft.base_url,
api_key: draft.api_key,
scope: "SEARCH_SPACE" as const,
search_space_id: searchSpaceId,
extra: draft.extra,
enabled: true,
models,
};
testPreviewModel.mutate(
{ ...request, model_id: testModel.model_id },
{
provider,
base_url: draft.base_url,
api_key: draft.api_key,
scope: "SEARCH_SPACE",
search_space_id: searchSpaceId,
extra: draft.extra,
enabled: true,
models,
},
{
onSuccess: () => {
setIsAddProviderOpen(false);
resetConnectState();
onSuccess: (result) => {
if (!result.ok) return;
createConnection.mutate(request, {
onSuccess: () => {
setIsAddProviderOpen(false);
resetConnectState();
},
});
},
}
);
@ -380,7 +386,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
onOpenChange={handleConnectOpenChange}
provider={provider}
selectedProvider={selectedProvider}
isPending={createConnection.isPending}
isPending={createConnection.isPending || testPreviewModel.isPending}
onSubmit={handleCreate}
previewModels={connectModels}
isPreviewingModels={previewModels.isPending}

View file

@ -5,9 +5,9 @@ import {
addManualModelMutationAtom,
bulkUpdateModelsMutationAtom,
discoverConnectionModelsMutationAtom,
testPreviewModelMutationAtom,
updateModelConnectionMutationAtom,
updateModelMutationAtom,
verifyModelConnectionMutationAtom,
} from "@/atoms/model-connections/model-connections-mutation.atoms";
import { Button } from "@/components/ui/button";
import {
@ -26,7 +26,7 @@ import type {
ConnectionRead,
ConnectionUpdateRequest,
} from "@/contracts/types/model-connections.types";
import type { SelectableModel } from "./model-utils";
import { capability, type SelectableModel } from "./model-utils";
import { ModelsSelectionPanel } from "./models-selection-panel";
import { providerIcon } from "./provider-metadata";
@ -39,8 +39,8 @@ export function ConnectionSettingsDialog({
connection,
providerLabel,
}: ConnectionSettingsDialogProps) {
const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom);
const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom);
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
const updateConnection = useAtomValue(updateModelConnectionMutationAtom);
const addManualModel = useAtomValue(addManualModelMutationAtom);
const updateModel = useAtomValue(updateModelMutationAtom);
@ -81,11 +81,45 @@ export function ConnectionSettingsDialog({
if (apiKeyDraft.trim() !== (connection.api_key ?? "")) {
data.api_key = apiKeyDraft.trim() || null;
}
const apiKeyForTest = Object.hasOwn(data, "api_key")
? (data.api_key ?? null)
: (connection.api_key ?? null);
updateConnection.mutate(
{ id: connection.id, data },
const enabledModels = connection.models.filter((model) => model.enabled);
const testModel =
enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
if (!testModel) {
updateConnection.mutate(
{ id: connection.id, data },
{
onSuccess: () => setApiKeyDraft(""),
}
);
return;
}
testPreviewModel.mutate(
{
onSuccess: () => setApiKeyDraft(""),
provider: connection.provider,
base_url: data.base_url,
api_key: apiKeyForTest,
scope: "SEARCH_SPACE",
search_space_id: connection.search_space_id,
extra: connection.extra ?? {},
enabled: connection.enabled,
models: [],
model_id: testModel.model_id,
},
{
onSuccess: (result) => {
if (!result.ok) return;
updateConnection.mutate(
{ id: connection.id, data },
{
onSuccess: () => setApiKeyDraft(""),
}
);
},
}
);
}
@ -219,26 +253,15 @@ export function ConnectionSettingsDialog({
onBulkToggle={handleBulkToggle}
/>
{connection.last_status && connection.last_status !== "OK" ? (
<p className="rounded-lg bg-amber-500/10 px-3 py-2 text-sm text-amber-600 dark:text-amber-500">
{connection.last_error || "Could not list models."} Chat may still work; add model
IDs manually if discovery is unavailable.
</p>
) : null}
</div>
</div>
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
<Button
variant="secondary"
onClick={() => verifyConnection.mutate(connection.id)}
disabled={verifyConnection.isPending}
>
Test
</Button>
<Button
onClick={saveConnectionSettings}
disabled={updateConnection.isPending || !hasConnectionChanges}
disabled={
updateConnection.isPending || testPreviewModel.isPending || !hasConnectionChanges
}
>
Update
</Button>

View file

@ -94,6 +94,8 @@ export function ProviderConnectDialog({
})();
const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit);
const hasEnabledModel = previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId);
const canConnect = canSubmit && hasEnabledModel;
return (
<Dialog open={open} onOpenChange={onOpenChange}>
@ -134,7 +136,7 @@ export function ProviderConnectDialog({
<ConnectFormFooter
onCancel={() => onOpenChange(false)}
onSubmit={() => onSubmit(currentDraft)}
canSubmit={canSubmit}
canSubmit={canConnect}
isPending={isPending}
/>
</DialogContent>

View file

@ -66,6 +66,10 @@ export const connectionCreateRequest = z.object({
models: z.array(modelSelection).default([]),
});
export const modelTestPreviewRequest = connectionCreateRequest.extend({
model_id: z.string().min(1),
});
export const connectionUpdateRequest = z.object({
provider: z.string().nullable().optional(),
base_url: z.string().nullable().optional(),
@ -129,6 +133,7 @@ export type ModelPreviewRead = z.infer<typeof modelPreviewRead>;
export type ModelSelection = z.infer<typeof modelSelection>;
export type ConnectionRead = z.infer<typeof connectionRead>;
export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>;
export type ModelTestPreviewRequest = z.infer<typeof modelTestPreviewRequest>;
export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>;
export type ModelCreateRequest = z.infer<typeof modelCreateRequest>;
export type ModelUpdateRequest = z.infer<typeof modelUpdateRequest>;

View file

@ -11,6 +11,7 @@ import {
type ModelProviderRead,
type ModelRead,
type ModelRoles,
type ModelTestPreviewRequest,
type ModelsBulkUpdateRequest,
type ModelUpdateRequest,
modelCreateRequest,
@ -19,6 +20,7 @@ import {
modelProviderListResponse,
modelRead,
modelRoles,
modelTestPreviewRequest,
modelsBulkUpdateRequest,
modelUpdateRequest,
type VerifyConnectionResponse,
@ -92,6 +94,20 @@ class ModelConnectionsApiService {
);
};
testPreviewModel = async (request: ModelTestPreviewRequest): Promise<VerifyConnectionResponse> => {
const parsed = modelTestPreviewRequest.safeParse(request);
if (!parsed.success) {
throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", "));
}
return baseApiService.post(
`/api/v1/model-connections/test-preview`,
verifyConnectionResponse,
{
body: parsed.data,
}
);
};
addManualModel = async (
connectionId: number,
request: ModelCreateRequest