mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
feat(model-connections): add test preview functionality for model connections
This commit is contained in:
parent
55f004e1da
commit
9f6210ad08
12 changed files with 294 additions and 77 deletions
|
|
@ -26,11 +26,13 @@ from app.schemas import (
|
|||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelSelection,
|
||||
ModelsBulkUpdate,
|
||||
ModelSelection,
|
||||
ModelTestPreview,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.model_connection_service import (
|
||||
ModelDiscoveryError,
|
||||
derive_capabilities,
|
||||
|
|
@ -38,7 +40,6 @@ from app.services.model_connection_service import (
|
|||
persist_verification,
|
||||
test_model,
|
||||
)
|
||||
from app.services.model_capabilities import has_capability
|
||||
from app.services.provider_registry import REGISTRY
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -321,6 +322,47 @@ async def preview_connection_models(
|
|||
return [_preview_model_read(item) for item in discovered]
|
||||
|
||||
|
||||
@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse)
|
||||
async def test_preview_connection_model(
|
||||
data: ModelTestPreview,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
if not model_id:
|
||||
raise HTTPException(status_code=400, detail="model_id is required")
|
||||
|
||||
draft = Connection(
|
||||
provider=data.provider,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key,
|
||||
extra=data.extra or {},
|
||||
scope=data.scope,
|
||||
enabled=data.enabled,
|
||||
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
|
||||
user_id=user.id,
|
||||
)
|
||||
model = Model(
|
||||
connection_id=0,
|
||||
model_id=model_id,
|
||||
source=ModelSource.MANUAL,
|
||||
enabled=True,
|
||||
capabilities_override={},
|
||||
catalog={},
|
||||
)
|
||||
result = await test_model(draft, model)
|
||||
return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message)
|
||||
|
||||
|
||||
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
|
||||
async def update_connection(
|
||||
connection_id: int,
|
||||
|
|
|
|||
|
|
@ -54,8 +54,9 @@ from .model_connections import (
|
|||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelSelection,
|
||||
ModelsBulkUpdate,
|
||||
ModelSelection,
|
||||
ModelTestPreview,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
|
|
@ -149,7 +150,7 @@ from .vision_llm import (
|
|||
VisionLLMConfigUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
__all__ = [
|
||||
# Folder schemas
|
||||
"BulkDocumentMove",
|
||||
# Chat schemas (assistant-ui integration)
|
||||
|
|
@ -159,6 +160,10 @@ __all__ = [
|
|||
"ChunkCreate",
|
||||
"ChunkRead",
|
||||
"ChunkUpdate",
|
||||
# Model connection schemas
|
||||
"ConnectionCreate",
|
||||
"ConnectionRead",
|
||||
"ConnectionUpdate",
|
||||
"CreateCreditCheckoutSessionRequest",
|
||||
"CreateCreditCheckoutSessionResponse",
|
||||
"CreditPurchaseHistoryResponse",
|
||||
|
|
@ -232,6 +237,16 @@ __all__ = [
|
|||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
"ModelCreate",
|
||||
"ModelPreviewRead",
|
||||
"ModelProviderRead",
|
||||
"ModelRead",
|
||||
"ModelRolesRead",
|
||||
"ModelRolesUpdate",
|
||||
"ModelSelection",
|
||||
"ModelTestPreview",
|
||||
"ModelUpdate",
|
||||
"ModelsBulkUpdate",
|
||||
"NewChatMessageAppend",
|
||||
"NewChatMessageCreate",
|
||||
"NewChatMessageRead",
|
||||
|
|
@ -282,6 +297,7 @@ __all__ = [
|
|||
"UserRead",
|
||||
"UserSearchSpaceAccess",
|
||||
"UserUpdate",
|
||||
"VerifyConnectionResponse",
|
||||
# Video Presentation schemas
|
||||
"VideoPresentationBase",
|
||||
"VideoPresentationCreate",
|
||||
|
|
|
|||
|
|
@ -85,6 +85,10 @@ class ConnectionCreate(BaseModel):
|
|||
models: list[ModelSelection] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ModelTestPreview(ConnectionCreate):
|
||||
model_id: str = Field(..., max_length=255)
|
||||
|
||||
|
||||
class ConnectionUpdate(BaseModel):
|
||||
provider: str | None = Field(None, max_length=100)
|
||||
base_url: str | None = Field(None, max_length=500)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import litellm
|
|||
from app.db import Connection, Model, ModelSource
|
||||
from app.services.model_resolver import ensure_v1, to_litellm
|
||||
from app.services.openrouter_model_normalizer import normalize_openrouter_models
|
||||
from app.services.provider_registry import Transport, spec_for
|
||||
from app.services.provider_registry import Transport, provider_label, spec_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -77,6 +77,68 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
|
|||
return raw
|
||||
|
||||
|
||||
def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult:
|
||||
provider_name = provider_label(conn.provider)
|
||||
raw = str(exc)
|
||||
normalized = raw.lower()
|
||||
exc_name = exc.__class__.__name__.lower()
|
||||
status_code = getattr(exc, "status_code", None)
|
||||
|
||||
logger.info(
|
||||
"Model test failed for provider=%s model=%s: %s",
|
||||
conn.provider,
|
||||
model_id,
|
||||
raw,
|
||||
)
|
||||
|
||||
if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized:
|
||||
return VerifyResult(
|
||||
"AUTH_FAILED",
|
||||
False,
|
||||
f"Authentication failed. Check your {provider_name} credentials and try again.",
|
||||
)
|
||||
|
||||
if status_code == 404 or "notfound" in exc_name or "not found" in normalized:
|
||||
if conn.provider == "azure":
|
||||
message = (
|
||||
"Azure OpenAI deployment was not found. Check the deployment name, "
|
||||
"API version, and endpoint."
|
||||
)
|
||||
else:
|
||||
message = f"Model '{model_id}' was not found on {provider_name}."
|
||||
return VerifyResult("NOT_FOUND", False, message)
|
||||
|
||||
if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized:
|
||||
return VerifyResult(
|
||||
"RATE_LIMITED",
|
||||
False,
|
||||
f"{provider_name} rate limited the model test. Try again later.",
|
||||
)
|
||||
|
||||
if "timeout" in exc_name or "timed out" in normalized:
|
||||
return VerifyResult(
|
||||
"TIMEOUT",
|
||||
False,
|
||||
f"{provider_name} did not respond in time. Check the endpoint and try again.",
|
||||
)
|
||||
|
||||
if "connection" in exc_name or "connect" in normalized:
|
||||
return VerifyResult(
|
||||
"UNREACHABLE",
|
||||
False,
|
||||
_docker_hint(
|
||||
_base_url_or_default(conn),
|
||||
f"Could not reach {provider_name}. Check the endpoint and try again.",
|
||||
),
|
||||
)
|
||||
|
||||
return VerifyResult(
|
||||
"UNREACHABLE",
|
||||
False,
|
||||
f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.",
|
||||
)
|
||||
|
||||
|
||||
async def verify_connection(conn: Connection) -> VerifyResult:
|
||||
spec = spec_for(conn.provider)
|
||||
base_url = _base_url_or_default(conn)
|
||||
|
|
@ -321,15 +383,24 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
|
|||
return []
|
||||
|
||||
def list_models() -> list[dict[str, Any]]:
|
||||
import os
|
||||
|
||||
import boto3
|
||||
|
||||
client_kwargs: dict[str, str] = {"region_name": region_name}
|
||||
if params.get("aws_access_key_id"):
|
||||
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
|
||||
if params.get("aws_secret_access_key"):
|
||||
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
|
||||
if bearer_token := params.get("aws_bearer_token_bedrock"):
|
||||
try:
|
||||
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token
|
||||
client = boto3.client("bedrock", region_name=region_name)
|
||||
finally:
|
||||
os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None)
|
||||
else:
|
||||
client_kwargs: dict[str, str] = {"region_name": region_name}
|
||||
if params.get("aws_access_key_id"):
|
||||
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
|
||||
if params.get("aws_secret_access_key"):
|
||||
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
|
||||
client = boto3.client("bedrock", **client_kwargs)
|
||||
|
||||
client = boto3.client("bedrock", **client_kwargs)
|
||||
response = client.list_foundation_models()
|
||||
results: list[dict[str, Any]] = []
|
||||
for item in response.get("modelSummaries", []):
|
||||
|
|
@ -393,7 +464,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult:
|
|||
**kwargs,
|
||||
)
|
||||
except Exception as exc:
|
||||
return VerifyResult("UNREACHABLE", False, str(exc))
|
||||
return _model_test_error(conn, model.model_id, exc)
|
||||
|
||||
model.supports_chat = True
|
||||
return VerifyResult("OK", True, "Model test succeeded.")
|
||||
|
|
|
|||
|
|
@ -55,6 +55,8 @@ def to_litellm(
|
|||
kwargs["api_version"] = api_version
|
||||
kwargs.update(extra.get("litellm_params", {}))
|
||||
kwargs.update(extra.get("kwargs", {}))
|
||||
if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)):
|
||||
kwargs["api_key"] = bearer_token
|
||||
return model_string, kwargs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,21 +38,24 @@ class ProviderSpec:
|
|||
default_base_url: str | None
|
||||
base_url_required: bool
|
||||
auth_style: AuthStyle
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
REGISTRY: dict[str, ProviderSpec] = {
|
||||
"openai": ProviderSpec(
|
||||
Transport.NATIVE, "openai", "openai_models", None, False, "bearer"
|
||||
Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI"
|
||||
),
|
||||
"anthropic": ProviderSpec(
|
||||
Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key"
|
||||
Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic"
|
||||
),
|
||||
"azure": ProviderSpec(
|
||||
Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI"
|
||||
),
|
||||
"azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"),
|
||||
"vertex_ai": ProviderSpec(
|
||||
Transport.NATIVE, "vertex_ai", "static", None, False, "native"
|
||||
Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI"
|
||||
),
|
||||
"bedrock": ProviderSpec(
|
||||
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
|
||||
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock"
|
||||
),
|
||||
"openrouter": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
|
|
@ -61,6 +64,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
"https://openrouter.ai/api/v1",
|
||||
False,
|
||||
"bearer",
|
||||
"OpenRouter",
|
||||
),
|
||||
"openai_compatible": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
|
|
@ -69,6 +73,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
None,
|
||||
True,
|
||||
"bearer",
|
||||
"OpenAI-compatible provider",
|
||||
),
|
||||
"lm_studio": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
|
|
@ -77,6 +82,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
"http://localhost:1234/v1",
|
||||
True,
|
||||
"bearer",
|
||||
"LM Studio",
|
||||
),
|
||||
"ollama_chat": ProviderSpec(
|
||||
Transport.OLLAMA,
|
||||
|
|
@ -85,6 +91,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
"http://localhost:11434",
|
||||
True,
|
||||
"none",
|
||||
"Ollama",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -96,4 +103,12 @@ def spec_for(provider: str | None) -> ProviderSpec:
|
|||
)
|
||||
|
||||
|
||||
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"]
|
||||
def provider_label(provider: str | None) -> str:
|
||||
provider_key = (provider or "").strip()
|
||||
spec = spec_for(provider_key)
|
||||
if spec.display_name:
|
||||
return spec.display_name
|
||||
return provider_key.replace("_", " ").title() if provider_key else "Provider"
|
||||
|
||||
|
||||
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import type {
|
|||
ModelPreviewRead,
|
||||
ModelRead,
|
||||
ModelRoles,
|
||||
ModelTestPreviewRequest,
|
||||
ModelsBulkUpdateRequest,
|
||||
ModelUpdateRequest,
|
||||
VerifyConnectionResponse,
|
||||
|
|
@ -114,6 +115,20 @@ export const previewConnectionModelsMutationAtom = atomWithMutation(() => {
|
|||
};
|
||||
});
|
||||
|
||||
export const testPreviewModelMutationAtom = atomWithMutation(() => {
|
||||
return {
|
||||
mutationKey: ["model-connections", "test-preview"],
|
||||
mutationFn: (request: ModelTestPreviewRequest) =>
|
||||
modelConnectionsApiService.testPreviewModel(request),
|
||||
onSuccess: (result: VerifyConnectionResponse) => {
|
||||
if (!result.ok) {
|
||||
toast.error(result.message || "Model test failed");
|
||||
}
|
||||
},
|
||||
onError: (error: Error) => toast.error(error.message || "Failed to test model"),
|
||||
};
|
||||
});
|
||||
|
||||
export const addManualModelMutationAtom = atomWithMutation((get) => {
|
||||
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom, useAtomValue } from "jotai";
|
||||
import { CheckCircle2, Trash2, XCircle } from "lucide-react";
|
||||
import { Trash2 } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createModelConnectionMutationAtom,
|
||||
deleteModelConnectionMutationAtom,
|
||||
previewConnectionModelsMutationAtom,
|
||||
testPreviewModelMutationAtom,
|
||||
updateModelRolesMutationAtom,
|
||||
} from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import {
|
||||
|
|
@ -53,24 +55,6 @@ import {
|
|||
providerIcon,
|
||||
} from "./model-connections/provider-metadata";
|
||||
|
||||
function StatusBadge({ connection }: { connection: ConnectionRead }) {
|
||||
if (connection.last_status === "OK") {
|
||||
return (
|
||||
<Badge variant="outline" className="gap-1 text-green-600">
|
||||
<CheckCircle2 className="h-3 w-3" /> Healthy
|
||||
</Badge>
|
||||
);
|
||||
}
|
||||
if (connection.last_status) {
|
||||
return (
|
||||
<Badge variant="outline" className="gap-1 text-destructive">
|
||||
<XCircle className="h-3 w-3" /> {connection.last_status}
|
||||
</Badge>
|
||||
);
|
||||
}
|
||||
return <Badge variant="secondary">Not tested</Badge>;
|
||||
}
|
||||
|
||||
function flattenModels(connections: ConnectionRead[]) {
|
||||
return connections.flatMap((connection) =>
|
||||
connection.models.map((model) => ({
|
||||
|
|
@ -110,7 +94,6 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
</div>
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-2">
|
||||
<StatusBadge connection={connection} />
|
||||
<ConnectionSettingsDialog connection={connection} providerLabel={providerLabel} />
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
|
|
@ -156,6 +139,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
const [{ data: roles }] = useAtom(modelRolesAtom);
|
||||
const createConnection = useAtomValue(createModelConnectionMutationAtom);
|
||||
const previewModels = useAtomValue(previewConnectionModelsMutationAtom);
|
||||
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
|
||||
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
|
||||
|
||||
const [isAddProviderOpen, setIsAddProviderOpen] = useState(false);
|
||||
|
|
@ -220,9 +204,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
});
|
||||
}
|
||||
|
||||
// Each provider connect form builds its own credential payload; the backend
|
||||
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
|
||||
function handleCreate(draft: ConnectionDraft) {
|
||||
function connectionModelsForDraft(draft: ConnectionDraft) {
|
||||
const models = [...connectModels];
|
||||
if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) {
|
||||
models.push({
|
||||
|
|
@ -233,22 +215,46 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
metadata: {},
|
||||
});
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
createConnection.mutate(
|
||||
function representativeTestModel(models: ModelSelection[]) {
|
||||
const enabledModels = models.filter((model) => model.enabled);
|
||||
return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||
}
|
||||
|
||||
// Each provider connect form builds its own credential payload; the backend
|
||||
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
|
||||
function handleCreate(draft: ConnectionDraft) {
|
||||
const models = connectionModelsForDraft(draft);
|
||||
const testModel = representativeTestModel(models);
|
||||
if (!testModel) {
|
||||
toast.error("Select at least one model before connecting");
|
||||
return;
|
||||
}
|
||||
|
||||
const request = {
|
||||
provider,
|
||||
base_url: draft.base_url,
|
||||
api_key: draft.api_key,
|
||||
scope: "SEARCH_SPACE" as const,
|
||||
search_space_id: searchSpaceId,
|
||||
extra: draft.extra,
|
||||
enabled: true,
|
||||
models,
|
||||
};
|
||||
|
||||
testPreviewModel.mutate(
|
||||
{ ...request, model_id: testModel.model_id },
|
||||
{
|
||||
provider,
|
||||
base_url: draft.base_url,
|
||||
api_key: draft.api_key,
|
||||
scope: "SEARCH_SPACE",
|
||||
search_space_id: searchSpaceId,
|
||||
extra: draft.extra,
|
||||
enabled: true,
|
||||
models,
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
setIsAddProviderOpen(false);
|
||||
resetConnectState();
|
||||
onSuccess: (result) => {
|
||||
if (!result.ok) return;
|
||||
createConnection.mutate(request, {
|
||||
onSuccess: () => {
|
||||
setIsAddProviderOpen(false);
|
||||
resetConnectState();
|
||||
},
|
||||
});
|
||||
},
|
||||
}
|
||||
);
|
||||
|
|
@ -380,7 +386,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
onOpenChange={handleConnectOpenChange}
|
||||
provider={provider}
|
||||
selectedProvider={selectedProvider}
|
||||
isPending={createConnection.isPending}
|
||||
isPending={createConnection.isPending || testPreviewModel.isPending}
|
||||
onSubmit={handleCreate}
|
||||
previewModels={connectModels}
|
||||
isPreviewingModels={previewModels.isPending}
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import {
|
|||
addManualModelMutationAtom,
|
||||
bulkUpdateModelsMutationAtom,
|
||||
discoverConnectionModelsMutationAtom,
|
||||
testPreviewModelMutationAtom,
|
||||
updateModelConnectionMutationAtom,
|
||||
updateModelMutationAtom,
|
||||
verifyModelConnectionMutationAtom,
|
||||
} from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
|
|
@ -26,7 +26,7 @@ import type {
|
|||
ConnectionRead,
|
||||
ConnectionUpdateRequest,
|
||||
} from "@/contracts/types/model-connections.types";
|
||||
import type { SelectableModel } from "./model-utils";
|
||||
import { capability, type SelectableModel } from "./model-utils";
|
||||
import { ModelsSelectionPanel } from "./models-selection-panel";
|
||||
import { providerIcon } from "./provider-metadata";
|
||||
|
||||
|
|
@ -39,8 +39,8 @@ export function ConnectionSettingsDialog({
|
|||
connection,
|
||||
providerLabel,
|
||||
}: ConnectionSettingsDialogProps) {
|
||||
const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom);
|
||||
const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom);
|
||||
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
|
||||
const updateConnection = useAtomValue(updateModelConnectionMutationAtom);
|
||||
const addManualModel = useAtomValue(addManualModelMutationAtom);
|
||||
const updateModel = useAtomValue(updateModelMutationAtom);
|
||||
|
|
@ -81,11 +81,45 @@ export function ConnectionSettingsDialog({
|
|||
if (apiKeyDraft.trim() !== (connection.api_key ?? "")) {
|
||||
data.api_key = apiKeyDraft.trim() || null;
|
||||
}
|
||||
const apiKeyForTest = Object.hasOwn(data, "api_key")
|
||||
? (data.api_key ?? null)
|
||||
: (connection.api_key ?? null);
|
||||
|
||||
updateConnection.mutate(
|
||||
{ id: connection.id, data },
|
||||
const enabledModels = connection.models.filter((model) => model.enabled);
|
||||
const testModel =
|
||||
enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||
if (!testModel) {
|
||||
updateConnection.mutate(
|
||||
{ id: connection.id, data },
|
||||
{
|
||||
onSuccess: () => setApiKeyDraft(""),
|
||||
}
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
testPreviewModel.mutate(
|
||||
{
|
||||
onSuccess: () => setApiKeyDraft(""),
|
||||
provider: connection.provider,
|
||||
base_url: data.base_url,
|
||||
api_key: apiKeyForTest,
|
||||
scope: "SEARCH_SPACE",
|
||||
search_space_id: connection.search_space_id,
|
||||
extra: connection.extra ?? {},
|
||||
enabled: connection.enabled,
|
||||
models: [],
|
||||
model_id: testModel.model_id,
|
||||
},
|
||||
{
|
||||
onSuccess: (result) => {
|
||||
if (!result.ok) return;
|
||||
updateConnection.mutate(
|
||||
{ id: connection.id, data },
|
||||
{
|
||||
onSuccess: () => setApiKeyDraft(""),
|
||||
}
|
||||
);
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
@ -219,26 +253,15 @@ export function ConnectionSettingsDialog({
|
|||
onBulkToggle={handleBulkToggle}
|
||||
/>
|
||||
|
||||
{connection.last_status && connection.last_status !== "OK" ? (
|
||||
<p className="rounded-lg bg-amber-500/10 px-3 py-2 text-sm text-amber-600 dark:text-amber-500">
|
||||
{connection.last_error || "Could not list models."} Chat may still work; add model
|
||||
IDs manually if discovery is unavailable.
|
||||
</p>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={() => verifyConnection.mutate(connection.id)}
|
||||
disabled={verifyConnection.isPending}
|
||||
>
|
||||
Test
|
||||
</Button>
|
||||
<Button
|
||||
onClick={saveConnectionSettings}
|
||||
disabled={updateConnection.isPending || !hasConnectionChanges}
|
||||
disabled={
|
||||
updateConnection.isPending || testPreviewModel.isPending || !hasConnectionChanges
|
||||
}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
|
|
|
|||
|
|
@ -94,6 +94,8 @@ export function ProviderConnectDialog({
|
|||
})();
|
||||
|
||||
const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit);
|
||||
const hasEnabledModel = previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId);
|
||||
const canConnect = canSubmit && hasEnabledModel;
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
|
|
@ -134,7 +136,7 @@ export function ProviderConnectDialog({
|
|||
<ConnectFormFooter
|
||||
onCancel={() => onOpenChange(false)}
|
||||
onSubmit={() => onSubmit(currentDraft)}
|
||||
canSubmit={canSubmit}
|
||||
canSubmit={canConnect}
|
||||
isPending={isPending}
|
||||
/>
|
||||
</DialogContent>
|
||||
|
|
|
|||
|
|
@ -66,6 +66,10 @@ export const connectionCreateRequest = z.object({
|
|||
models: z.array(modelSelection).default([]),
|
||||
});
|
||||
|
||||
export const modelTestPreviewRequest = connectionCreateRequest.extend({
|
||||
model_id: z.string().min(1),
|
||||
});
|
||||
|
||||
export const connectionUpdateRequest = z.object({
|
||||
provider: z.string().nullable().optional(),
|
||||
base_url: z.string().nullable().optional(),
|
||||
|
|
@ -129,6 +133,7 @@ export type ModelPreviewRead = z.infer<typeof modelPreviewRead>;
|
|||
export type ModelSelection = z.infer<typeof modelSelection>;
|
||||
export type ConnectionRead = z.infer<typeof connectionRead>;
|
||||
export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>;
|
||||
export type ModelTestPreviewRequest = z.infer<typeof modelTestPreviewRequest>;
|
||||
export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>;
|
||||
export type ModelCreateRequest = z.infer<typeof modelCreateRequest>;
|
||||
export type ModelUpdateRequest = z.infer<typeof modelUpdateRequest>;
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import {
|
|||
type ModelProviderRead,
|
||||
type ModelRead,
|
||||
type ModelRoles,
|
||||
type ModelTestPreviewRequest,
|
||||
type ModelsBulkUpdateRequest,
|
||||
type ModelUpdateRequest,
|
||||
modelCreateRequest,
|
||||
|
|
@ -19,6 +20,7 @@ import {
|
|||
modelProviderListResponse,
|
||||
modelRead,
|
||||
modelRoles,
|
||||
modelTestPreviewRequest,
|
||||
modelsBulkUpdateRequest,
|
||||
modelUpdateRequest,
|
||||
type VerifyConnectionResponse,
|
||||
|
|
@ -92,6 +94,20 @@ class ModelConnectionsApiService {
|
|||
);
|
||||
};
|
||||
|
||||
testPreviewModel = async (request: ModelTestPreviewRequest): Promise<VerifyConnectionResponse> => {
|
||||
const parsed = modelTestPreviewRequest.safeParse(request);
|
||||
if (!parsed.success) {
|
||||
throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", "));
|
||||
}
|
||||
return baseApiService.post(
|
||||
`/api/v1/model-connections/test-preview`,
|
||||
verifyConnectionResponse,
|
||||
{
|
||||
body: parsed.data,
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
addManualModel = async (
|
||||
connectionId: number,
|
||||
request: ModelCreateRequest
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue