feat(model-connections): enhance model connection functionality with preview and selection features

This commit is contained in:
Anish Sarkar 2026-06-12 22:41:21 +05:30
parent 356f0e56c5
commit 407f2a9612
20 changed files with 630 additions and 429 deletions

View file

@ -21,10 +21,12 @@ from app.schemas import (
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelPreviewRead,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate,
ModelUpdate,
VerifyConnectionResponse,
@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model)
def _preview_model_read(item: dict) -> ModelPreviewRead:
return ModelPreviewRead(
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item.get("source", ModelSource.DISCOVERED),
supports_chat=item.get("supports_chat"),
max_input_tokens=item.get("max_input_tokens"),
supports_image_input=item.get("supports_image_input"),
supports_tools=item.get("supports_tools"),
supports_image_generation=item.get("supports_image_generation"),
enabled=item.get("enabled", False),
metadata=item.get("metadata") or item.get("catalog") or {},
)
def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead:
if isinstance(conn, dict):
payload = {
@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None:
model.supports_image_generation = facts.get("supports_image_generation")
def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model:
source = (
selection.source
if isinstance(selection.source, ModelSource)
else ModelSource(selection.source)
)
model = Model(
connection_id=conn.id,
model_id=selection.model_id.strip(),
display_name=selection.display_name,
source=source,
capabilities_override={},
enabled=selection.enabled,
catalog=selection.metadata,
)
_apply_model_facts(model, selection.model_dump())
return model
def _default_model_for(models: list[Model], capability: str) -> int | None:
for model in models:
if model.enabled and has_capability(model, capability):
@ -226,7 +262,7 @@ async def create_connection(
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
payload = data.model_dump(exclude={"search_space_id"})
payload = data.model_dump(exclude={"search_space_id", "models"})
conn = Connection(
**payload,
@ -234,9 +270,51 @@ async def create_connection(
user_id=user.id,
)
session.add(conn)
await session.flush()
seen_model_ids: set[str] = set()
for selection in data.models:
model_id = selection.model_id.strip()
if not model_id or model_id in seen_model_ids:
continue
seen_model_ids.add(model_id)
session.add(_selection_to_model(conn, selection))
await session.commit()
await session.refresh(conn)
return _connection_read(conn, [])
conn = await _load_connection(session, conn.id)
await _default_unset_roles(session, conn, list(conn.models))
await session.commit()
conn = await _load_connection(session, conn.id)
return _connection_read(conn, list(conn.models))
@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead])
async def preview_connection_models(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
draft = Connection(
provider=data.provider,
base_url=data.base_url,
api_key=data.api_key,
extra=data.extra or {},
scope=data.scope,
enabled=data.enabled,
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
user_id=user.id,
)
discovered = await discover_models(draft)
return [_preview_model_read(item) for item in discovered]
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)

View file

@ -49,10 +49,12 @@ from .model_connections import (
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelPreviewRead,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate,
ModelUpdate,
VerifyConnectionResponse,

View file

@ -48,6 +48,32 @@ class ConnectionRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
class ModelSelection(BaseModel):
model_id: str = Field(..., max_length=255)
display_name: str | None = Field(None, max_length=255)
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ModelPreviewRead(BaseModel):
model_id: str
display_name: str | None = None
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ConnectionCreate(BaseModel):
provider: str = Field(..., max_length=100)
base_url: str | None = Field(None, max_length=500)
@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel):
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
search_space_id: int | None = None
enabled: bool = True
models: list[ModelSelection] = Field(default_factory=list)
class ConnectionUpdate(BaseModel):

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
import anyio
import httpx
import litellm
@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
return results
async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
params = (conn.extra or {}).get("litellm_params", {})
region_name = params.get("aws_region_name")
if not region_name:
return []
def list_models() -> list[dict[str, Any]]:
import boto3
client_kwargs: dict[str, str] = {"region_name": region_name}
if params.get("aws_access_key_id"):
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
if params.get("aws_secret_access_key"):
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
client = boto3.client("bedrock", **client_kwargs)
response = client.list_foundation_models()
results: list[dict[str, Any]] = []
for item in response.get("modelSummaries", []):
model_id = item.get("modelId")
if not model_id:
continue
input_modalities = set(item.get("inputModalities") or [])
output_modalities = set(item.get("outputModalities") or [])
results.append(
{
"model_id": model_id,
"display_name": item.get("modelName") or model_id,
"source": ModelSource.DISCOVERED,
"supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities,
"supports_image_input": "IMAGE" in input_modalities,
"supports_tools": None,
"supports_image_generation": "IMAGE" in output_modalities,
"max_input_tokens": None,
"metadata": item,
}
)
return results
return await anyio.to_thread.run_sync(list_models)
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
allowlist = _allowlist(conn)
spec = spec_for(conn.provider)
@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
results = await _discover_anthropic_models(conn)
elif spec.discovery == "openai_models":
results = await _discover_openai_shaped_models(conn, conn.base_url)
elif spec.discovery == "bedrock_models":
results = await _discover_bedrock_models(conn)
elif spec.discovery == "static":
results = _litellm_static_models(conn)
else:

View file

@ -21,6 +21,7 @@ DiscoveryKind = Literal[
"ollama",
"openai_models",
"anthropic_models",
"bedrock_models",
"openrouter",
"static",
"none",
@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = {
Transport.NATIVE, "vertex_ai", "static", None, False, "native"
),
"bedrock": ProviderSpec(
Transport.NATIVE, "bedrock", "static", None, False, "native"
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
),
"openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE,

View file

@ -4,6 +4,7 @@ import type {
ConnectionCreateRequest,
ConnectionUpdateRequest,
ModelCreateRequest,
ModelPreviewRead,
ModelRead,
ModelRoles,
ModelsBulkUpdateRequest,
@ -103,6 +104,16 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => {
};
});
export const previewConnectionModelsMutationAtom = atomWithMutation(() => {
return {
mutationKey: ["model-connections", "discover-preview"],
mutationFn: (request: ConnectionCreateRequest) =>
modelConnectionsApiService.previewModels(request),
onSuccess: (_models: ModelPreviewRead[]) => {},
onError: (error: Error) => toast.error(error.message || "Failed to discover models"),
};
});
export const addManualModelMutationAtom = atomWithMutation((get) => {
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
return {

View file

@ -2,7 +2,7 @@
import { useQueryClient } from "@tanstack/react-query";
import { useAtom, useAtomValue } from "jotai";
import { RefreshCcw, Workflow } from "lucide-react";
import { RefreshCw, Workflow } from "lucide-react";
import { useCallback } from "react";
import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom";
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
@ -112,7 +112,7 @@ export function ActionLogDialog() {
className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground"
aria-label="Refresh action log"
>
<RefreshCcw className={isFetching ? "size-3.5 animate-spin" : "size-3.5"} />
<RefreshCw className={isFetching ? "size-3.5 animate-spin" : "size-3.5"} />
</Button>
<div className="flex min-h-0 flex-1 flex-col overflow-y-auto scrollbar-thin">

View file

@ -4,12 +4,9 @@ import { useAtom, useAtomValue } from "jotai";
import { CheckCircle2, Trash2, XCircle } from "lucide-react";
import { useState } from "react";
import {
addManualModelMutationAtom,
bulkUpdateModelsMutationAtom,
createModelConnectionMutationAtom,
deleteModelConnectionMutationAtom,
discoverConnectionModelsMutationAtom,
updateModelMutationAtom,
previewConnectionModelsMutationAtom,
updateModelRolesMutationAtom,
} from "@/atoms/model-connections/model-connections-mutation.atoms";
import {
@ -41,9 +38,13 @@ import {
SelectValue,
} from "@/components/ui/select";
import { Separator } from "@/components/ui/separator";
import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types";
import type {
ConnectionRead,
ModelRead,
ModelSelection,
} from "@/contracts/types/model-connections.types";
import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog";
import { capability, modelLabel } from "./model-connections/model-utils";
import { capability, modelLabel, type SelectableModel } from "./model-connections/model-utils";
import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog";
import {
type ConnectionDraft,
@ -154,16 +155,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
const [{ data: providers = [] }] = useAtom(modelProvidersAtom);
const [{ data: roles }] = useAtom(modelRolesAtom);
const createConnection = useAtomValue(createModelConnectionMutationAtom);
const addManualModel = useAtomValue(addManualModelMutationAtom);
const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom);
const updateModel = useAtomValue(updateModelMutationAtom);
const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom);
const previewModels = useAtomValue(previewConnectionModelsMutationAtom);
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
const [isAddProviderOpen, setIsAddProviderOpen] = useState(false);
const [provider, setProvider] = useState("openai_compatible");
const [connectedConnection, setConnectedConnection] = useState<ConnectionRead | null>(null);
const [connectModels, setConnectModels] = useState<ModelRead[]>([]);
const [connectModels, setConnectModels] = useState<ModelSelection[]>([]);
const selectedProvider = providers.find((item) => item.provider === provider);
const sortedProviders = [...providers].sort((left, right) => {
@ -185,7 +182,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
const imageModels = enabledModels.filter((model) => capability(model, "image_gen"));
function resetConnectState() {
setConnectedConnection(null);
setConnectModels([]);
}
@ -196,15 +192,48 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
}
}
function replaceConnectModels(updatedModels: ModelRead[]) {
setConnectModels((current) =>
current.map((model) => updatedModels.find((updated) => updated.id === model.id) ?? model)
);
function toModelSelection(model: SelectableModel): ModelSelection {
return {
model_id: model.model_id,
display_name: model.display_name,
source: model.source || "DISCOVERED",
supports_chat: model.supports_chat,
max_input_tokens: model.max_input_tokens,
supports_image_input: model.supports_image_input,
supports_tools: model.supports_tools,
supports_image_generation: model.supports_image_generation,
enabled: model.enabled,
metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}),
};
}
function mergePreviewModels(fetchedModels: SelectableModel[]) {
setConnectModels((current) => {
const currentById = new Map(current.map((model) => [model.model_id, model]));
return fetchedModels.map((model) => {
const prior = currentById.get(model.model_id);
return {
...toModelSelection(model),
enabled: prior ? prior.enabled : model.enabled,
};
});
});
}
// Each provider connect form builds its own credential payload; the backend
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
function handleCreate(draft: ConnectionDraft) {
const models = [...connectModels];
if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) {
models.push({
model_id: draft.seedModelId,
display_name: draft.seedModelId,
source: "MANUAL",
enabled: true,
metadata: {},
});
}
createConnection.mutate(
{
provider,
@ -214,26 +243,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
search_space_id: searchSpaceId,
extra: draft.extra,
enabled: true,
models,
},
{
onSuccess: (created) => {
setConnectedConnection(created);
setConnectModels([]);
if (draft.seedModelId) {
addManualModel.mutate(
{
connectionId: created.id,
data: { model_id: draft.seedModelId },
},
{
onSuccess: (model) => setConnectModels([model]),
}
);
} else {
discoverModels.mutate(created.id, {
onSuccess: (models) => setConnectModels(models),
});
}
onSuccess: () => {
setIsAddProviderOpen(false);
resetConnectState();
},
}
);
@ -243,52 +258,72 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
resetConnectState();
setProvider(providerId);
setIsAddProviderOpen(true);
if (providerId === "vertex_ai") {
previewModels.mutate(
{
provider: providerId,
base_url: null,
api_key: null,
scope: "SEARCH_SPACE",
search_space_id: searchSpaceId,
extra: {},
enabled: true,
models: [],
},
{
onSuccess: mergePreviewModels,
}
);
}
}
function refreshConnectModels() {
if (!connectedConnection) return;
discoverModels.mutate(connectedConnection.id, {
onSuccess: (models) => setConnectModels(models),
});
function refreshConnectModels(draft: ConnectionDraft) {
previewModels.mutate(
{
provider,
base_url: draft.base_url,
api_key: draft.api_key,
scope: "SEARCH_SPACE",
search_space_id: searchSpaceId,
extra: draft.extra,
enabled: true,
models: [],
},
{
onSuccess: mergePreviewModels,
}
);
}
function addConnectModel(modelId: string) {
if (!connectedConnection) return;
addManualModel.mutate(
{ connectionId: connectedConnection.id, data: { model_id: modelId } },
{
onSuccess: (model) => setConnectModels((current) => [...current, model]),
}
setConnectModels((current) => {
if (current.some((model) => model.model_id === modelId)) return current;
return [
...current,
{
model_id: modelId,
display_name: modelId,
source: "MANUAL",
enabled: true,
metadata: {},
},
];
});
}
function toggleConnectModel(model: SelectableModel, enabled: boolean) {
setConnectModels((current) =>
current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item))
);
}
function toggleConnectModel(model: ModelRead, enabled: boolean) {
updateModel.mutate(
{ id: model.id, data: { enabled } },
{
onSuccess: (updated) => replaceConnectModels([updated]),
}
function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) {
const modelIds = new Set(models.map((model) => model.model_id));
setConnectModels((current) =>
current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item))
);
}
function bulkToggleConnectModels(models: ModelRead[], enabled: boolean) {
if (!connectedConnection) return;
bulkUpdateModels.mutate(
{
connectionId: connectedConnection.id,
data: { model_ids: models.map((model) => model.id), enabled },
},
{
onSuccess: replaceConnectModels,
}
);
}
function finishConnectFlow() {
setIsAddProviderOpen(false);
resetConnectState();
}
function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) {
return (
<SelectItem key={model.id} value={String(model.id)}>
@ -347,17 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
selectedProvider={selectedProvider}
isPending={createConnection.isPending}
onSubmit={handleCreate}
connectedConnection={connectedConnection}
connectModels={connectModels}
isDiscoveringModels={discoverModels.isPending}
isAddingManualModel={addManualModel.isPending}
isUpdatingModel={updateModel.isPending}
isBulkUpdatingModels={bulkUpdateModels.isPending}
onRefreshModels={refreshConnectModels}
onAddManualModel={addConnectModel}
onToggleModel={toggleConnectModel}
onBulkToggleModels={bulkToggleConnectModels}
onDone={finishConnectFlow}
previewModels={connectModels}
isPreviewingModels={previewModels.isPending}
onPreviewModels={refreshConnectModels}
onAddPreviewModel={addConnectModel}
onTogglePreviewModel={toggleConnectModel}
onBulkTogglePreviewModels={bulkToggleConnectModels}
/>
{connections.length > 0 ? (

View file

@ -1,7 +1,7 @@
import { useState } from "react";
import { useEffect, useState } from "react";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { ApiKeyField, ConnectFormFooter } from "./connect-fields";
import { ApiKeyField } from "./connect-fields";
import {
isValidAzureTargetUri,
type ProviderConnectFormProps,
@ -12,48 +12,43 @@ import {
* Azure OpenAI connect form. The user pastes a single Target URI, which we parse
* into api base, api version, and the deployment name (seeded as the model).
*/
export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) {
export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) {
const [targetUri, setTargetUri] = useState("");
const [apiKey, setApiKey] = useState("");
const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim());
function handleSubmit() {
useEffect(() => {
const parsed = parseAzureTargetUri(targetUri);
onSubmit({
base_url: parsed?.origin ?? null,
api_key: apiKey || null,
extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {},
seedModelId: parsed?.deploymentName || undefined,
});
}
onDraftChange(
{
base_url: parsed?.origin ?? null,
api_key: apiKey || null,
extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {},
seedModelId: parsed?.deploymentName || undefined,
},
canSubmit
);
}, [apiKey, canSubmit, onDraftChange, targetUri]);
return (
<>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>Target URI</Label>
<Input
value={targetUri}
onChange={(event) => setTargetUri(event.target.value)}
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
<p className="text-xs text-muted-foreground">
Paste your endpoint target URI from Azure OpenAI (including API base, deployment name,
and API version).
</p>
</div>
<ApiKeyField
value={apiKey}
onChange={setApiKey}
placeholder="Paste your API key from Azure"
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>Target URI</Label>
<Input
value={targetUri}
onChange={(event) => setTargetUri(event.target.value)}
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
<p className="text-xs text-muted-foreground">
Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and
API version).
</p>
</div>
<ConnectFormFooter
onCancel={onCancel}
onSubmit={handleSubmit}
canSubmit={canSubmit}
isPending={isPending}
<ApiKeyField
value={apiKey}
onChange={setApiKey}
placeholder="Paste your API key from Azure"
/>
</>
</div>
);
}

View file

@ -1,4 +1,4 @@
import { useState } from "react";
import { useEffect, useState } from "react";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
@ -8,7 +8,7 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { ConnectFormFooter } from "./connect-fields";
import { ApiKeyField } from "./connect-fields";
import {
AWS_REGION_OPTIONS,
BEDROCK_AUTH_ACCESS_KEY,
@ -21,7 +21,7 @@ import {
* Amazon Bedrock connect form. Region + auth method drive which AWS credentials
* are collected; everything rides along in `extra.litellm_params`.
*/
export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) {
export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) {
const [region, setRegion] = useState("");
const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY);
const [accessKeyId, setAccessKeyId] = useState("");
@ -39,7 +39,7 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo
return true;
})();
function handleSubmit() {
useEffect(() => {
const params: Record<string, string> = { aws_region_name: region };
if (authMethod === BEDROCK_AUTH_ACCESS_KEY) {
params.aws_access_key_id = accessKeyId;
@ -47,88 +47,74 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo
} else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) {
params.aws_bearer_token_bedrock = bearerToken;
}
onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } });
}
onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit);
}, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]);
return (
<>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>AWS Region</Label>
<Select value={region || undefined} onValueChange={setRegion}>
<SelectTrigger>
<SelectValue placeholder="Select a region" />
</SelectTrigger>
<SelectContent>
{AWS_REGION_OPTIONS.map((option) => (
<SelectItem key={option} value={option}>
{option}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex flex-col gap-2">
<Label>Authentication Method</Label>
<Select value={authMethod} onValueChange={setAuthMethod}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value={BEDROCK_AUTH_IAM}>Environment IAM Role</SelectItem>
<SelectItem value={BEDROCK_AUTH_ACCESS_KEY}>Access Key</SelectItem>
<SelectItem value={BEDROCK_AUTH_LONG_TERM_API_KEY}>Long-term API Key</SelectItem>
</SelectContent>
</Select>
</div>
{authMethod === BEDROCK_AUTH_ACCESS_KEY ? (
<>
<div className="flex flex-col gap-2">
<Label>AWS Access Key ID</Label>
<Input
value={accessKeyId}
onChange={(event) => setAccessKeyId(event.target.value)}
placeholder="AKIAIOSFODNN7EXAMPLE"
/>
</div>
<div className="flex flex-col gap-2">
<Label>AWS Secret Access Key</Label>
<Input
value={secretAccessKey}
onChange={(event) => setSecretAccessKey(event.target.value)}
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
type="password"
/>
</div>
</>
) : null}
{authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? (
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>AWS Region</Label>
<Select value={region || undefined} onValueChange={setRegion}>
<SelectTrigger>
<SelectValue placeholder="Select a region" />
</SelectTrigger>
<SelectContent>
{AWS_REGION_OPTIONS.map((option) => (
<SelectItem key={option} value={option}>
{option}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex flex-col gap-2">
<Label>Authentication Method</Label>
<Select value={authMethod} onValueChange={setAuthMethod}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value={BEDROCK_AUTH_IAM}>Environment IAM Role</SelectItem>
<SelectItem value={BEDROCK_AUTH_ACCESS_KEY}>Access Key</SelectItem>
<SelectItem value={BEDROCK_AUTH_LONG_TERM_API_KEY}>Long-term API Key</SelectItem>
</SelectContent>
</Select>
</div>
{authMethod === BEDROCK_AUTH_ACCESS_KEY ? (
<>
<div className="flex flex-col gap-2">
<Label>Long-term API Key</Label>
<Label>AWS Access Key ID</Label>
<Input
value={bearerToken}
onChange={(event) => setBearerToken(event.target.value)}
placeholder="Your long-term API key"
type="password"
value={accessKeyId}
onChange={(event) => setAccessKeyId(event.target.value)}
placeholder="AKIAIOSFODNN7EXAMPLE"
/>
</div>
) : null}
{authMethod === BEDROCK_AUTH_IAM ? (
<p className="text-xs text-muted-foreground">
SurfSense will use the IAM role attached to the environment it&apos;s running in to
authenticate.
</p>
) : null}
<ApiKeyField
value={secretAccessKey}
onChange={setSecretAccessKey}
label="AWS Secret Access Key"
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
/>
</>
) : null}
{authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? (
<ApiKeyField
value={bearerToken}
onChange={setBearerToken}
label="Long-term API Key"
placeholder="Your long-term API key"
/>
) : null}
{authMethod === BEDROCK_AUTH_IAM ? (
<p className="text-xs text-muted-foreground">
Add Bedrock model IDs from the provider&apos;s settings after connecting.
SurfSense will use the IAM role attached to the environment it&apos;s running in to
authenticate.
</p>
</div>
<ConnectFormFooter
onCancel={onCancel}
onSubmit={handleSubmit}
canSubmit={canSubmit}
isPending={isPending}
/>
</>
) : null}
<p className="text-xs text-muted-foreground">
Add Bedrock model IDs from the provider&apos;s settings after connecting.
</p>
</div>
);
}

View file

@ -1,3 +1,5 @@
import { Eye, EyeOff } from "lucide-react";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { DialogFooter } from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
@ -43,15 +45,31 @@ export function ApiKeyField({
label = "API Key",
placeholder = "API key",
}: ApiKeyFieldProps) {
const [showApiKey, setShowApiKey] = useState(false);
return (
<div className="flex flex-col gap-2">
<Label>{label}</Label>
<Input
value={value}
onChange={(event) => onChange(event.target.value)}
placeholder={placeholder}
type="password"
/>
<div className="relative">
<Input
value={value}
onChange={(event) => onChange(event.target.value)}
placeholder={placeholder}
type={showApiKey ? "text" : "password"}
className="pr-11"
/>
<Button
type="button"
variant="ghost"
size="icon"
className="absolute top-1/2 right-1 size-8 -translate-y-1/2 text-muted-foreground"
onClick={() => setShowApiKey((current) => !current)}
disabled={!value}
aria-label={showApiKey ? "Hide API key" : "Show API key"}
>
{showApiKey ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</Button>
</div>
</div>
);
}
@ -71,7 +89,7 @@ export function ConnectFormFooter({
isPending,
}: ConnectFormFooterProps) {
return (
<DialogFooter className="mt-6">
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
<Button variant="secondary" onClick={onCancel}>
Cancel
</Button>

View file

@ -25,8 +25,8 @@ import { Separator } from "@/components/ui/separator";
import type {
ConnectionRead,
ConnectionUpdateRequest,
ModelRead,
} from "@/contracts/types/model-connections.types";
import type { SelectableModel } from "./model-utils";
import { ModelsSelectionPanel } from "./models-selection-panel";
import { providerIcon } from "./provider-metadata";
@ -101,17 +101,22 @@ export function ConnectionSettingsDialog({
});
}
function handleToggleModel(model: ModelRead, enabled: boolean) {
function handleToggleModel(model: SelectableModel, enabled: boolean) {
if (typeof model.id !== "number") return;
updateModel.mutate({
id: model.id,
data: { enabled },
});
}
function handleBulkToggle(models: ModelRead[], enabled: boolean) {
function handleBulkToggle(models: SelectableModel[], enabled: boolean) {
const modelIds = models
.map((model) => model.id)
.filter((id): id is number => typeof id === "number");
if (modelIds.length === 0) return;
bulkUpdateModels.mutate({
connectionId: connection.id,
data: { model_ids: models.map((model) => model.id), enabled },
data: { model_ids: modelIds, enabled },
});
}
@ -184,12 +189,7 @@ export function ConnectionSettingsDialog({
onChange={(event) => setAllowlistText(event.target.value)}
placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro"
/>
<Button
variant="outline"
size="sm"
onClick={saveAllowlist}
disabled={updateConnection.isPending}
>
<Button size="sm" onClick={saveAllowlist} disabled={updateConnection.isPending}>
Save filter
</Button>
</div>

View file

@ -1,5 +1,5 @@
import { useState } from "react";
import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields";
import { useEffect, useState } from "react";
import { ApiBaseUrlField, ApiKeyField } from "./connect-fields";
import type { ProviderConnectFormProps } from "./provider-metadata";
/**
@ -11,41 +11,31 @@ export function DefaultConnectForm({
provider,
defaultBaseUrl,
baseUrlRequired,
isPending,
onCancel,
onSubmit,
onDraftChange,
}: ProviderConnectFormProps) {
const [baseUrl, setBaseUrl] = useState(defaultBaseUrl);
const [apiKey, setApiKey] = useState("");
const isOllama = provider === "ollama_chat";
const canSubmit = !(baseUrlRequired && !baseUrl.trim());
function handleSubmit() {
onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} });
}
useEffect(() => {
onDraftChange({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }, canSubmit);
}, [apiKey, baseUrl, canSubmit, onDraftChange]);
return (
<>
<div className="flex flex-col gap-4">
<ApiBaseUrlField
value={baseUrl}
onChange={setBaseUrl}
optional={!baseUrlRequired}
placeholder={defaultBaseUrl}
/>
<ApiKeyField
value={apiKey}
onChange={setApiKey}
label={isOllama ? "API Key (optional)" : "API Key"}
placeholder={isOllama ? "Optional for Ollama" : "API key"}
/>
</div>
<ConnectFormFooter
onCancel={onCancel}
onSubmit={handleSubmit}
canSubmit={canSubmit}
isPending={isPending}
<div className="flex flex-col gap-4">
<ApiBaseUrlField
value={baseUrl}
onChange={setBaseUrl}
optional={!baseUrlRequired}
placeholder={defaultBaseUrl}
/>
</>
<ApiKeyField
value={apiKey}
onChange={setApiKey}
label={isOllama ? "API Key (optional)" : "API Key"}
placeholder={isOllama ? "Optional for Ollama" : "API key"}
/>
</div>
);
}

View file

@ -1,4 +1,4 @@
import type { ModelRead } from "@/contracts/types/model-connections.types";
import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types";
export type ModelCapabilityFilter = "chat" | "vision" | "image_gen";
@ -8,17 +8,22 @@ export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: stri
{ key: "image_gen", label: "Image" },
];
export function modelLabel(model: ModelRead) {
export type SelectableModel = (ModelRead | ModelPreviewRead) & {
id?: number | string;
connection_id?: number;
};
export function modelLabel(model: SelectableModel) {
return model.display_name || model.model_id;
}
export function capability(model: ModelRead, key: ModelCapabilityFilter) {
export function capability(model: SelectableModel, key: ModelCapabilityFilter) {
if (key === "chat") return Boolean(model.supports_chat);
if (key === "vision") return Boolean(model.supports_image_input);
return Boolean(model.supports_image_generation);
}
export function capabilityLabels(model: ModelRead) {
export function capabilityLabels(model: SelectableModel) {
return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key))
.map((filter) => filter.label.toLowerCase())
.join(", ");

View file

@ -4,17 +4,17 @@ import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Input } from "@/components/ui/input";
import type { ModelRead } from "@/contracts/types/model-connections.types";
import {
capability,
capabilityLabels,
MODEL_CAPABILITY_FILTERS,
type ModelCapabilityFilter,
modelLabel,
type SelectableModel,
} from "./model-utils";
interface ModelsSelectionPanelProps {
models: ModelRead[];
models: SelectableModel[];
description?: string;
emptyMessage?: string;
manualInputPlaceholder?: string;
@ -25,8 +25,8 @@ interface ModelsSelectionPanelProps {
isBulkUpdating?: boolean;
onRefresh?: () => void;
onAddManual?: (modelId: string) => void;
onToggleModel?: (model: ModelRead, enabled: boolean) => void;
onBulkToggle?: (models: ModelRead[], enabled: boolean) => void;
onToggleModel?: (model: SelectableModel, enabled: boolean) => void;
onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void;
}
export function ModelsSelectionPanel({
@ -166,7 +166,7 @@ export function ModelsSelectionPanel({
<div className="space-y-2">
{filteredModels.map((model) => (
<div
key={model.id}
key={model.id ?? model.model_id}
className="flex items-center gap-3 rounded-lg px-3 py-2 transition-colors hover:bg-background"
>
<Checkbox

View file

@ -1,20 +1,18 @@
import { Button } from "@/components/ui/button";
import { useCallback, useState } from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import type {
ConnectionRead,
ModelProviderRead,
ModelRead,
} from "@/contracts/types/model-connections.types";
import { Separator } from "@/components/ui/separator";
import type { ModelProviderRead } from "@/contracts/types/model-connections.types";
import { AzureConnectForm } from "./azure-connect-form";
import { BedrockConnectForm } from "./bedrock-connect-form";
import { ConnectFormFooter } from "./connect-fields";
import { DefaultConnectForm } from "./default-connect-form";
import type { SelectableModel } from "./model-utils";
import { ModelsSelectionPanel } from "./models-selection-panel";
import {
type ConnectionDraft,
@ -32,17 +30,12 @@ interface ProviderConnectDialogProps {
selectedProvider?: ModelProviderRead;
isPending: boolean;
onSubmit: (draft: ConnectionDraft) => void;
connectedConnection?: ConnectionRead | null;
connectModels?: ModelRead[];
isDiscoveringModels?: boolean;
isAddingManualModel?: boolean;
isUpdatingModel?: boolean;
isBulkUpdatingModels?: boolean;
onRefreshModels?: () => void;
onAddManualModel?: (modelId: string) => void;
onToggleModel?: (model: ModelRead, enabled: boolean) => void;
onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void;
onDone?: () => void;
previewModels?: SelectableModel[];
isPreviewingModels?: boolean;
onPreviewModels?: (draft: ConnectionDraft) => void;
onAddPreviewModel?: (modelId: string) => void;
onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void;
onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void;
}
/**
@ -57,97 +50,93 @@ export function ProviderConnectDialog({
selectedProvider,
isPending,
onSubmit,
connectedConnection,
connectModels = [],
isDiscoveringModels = false,
isAddingManualModel = false,
isUpdatingModel = false,
isBulkUpdatingModels = false,
onRefreshModels,
onAddManualModel,
onToggleModel,
onBulkToggleModels,
onDone,
previewModels = [],
isPreviewingModels = false,
onPreviewModels,
onAddPreviewModel,
onTogglePreviewModel,
onBulkTogglePreviewModels,
}: ProviderConnectDialogProps) {
const meta = providerDisplay(provider);
const isModelSelectionStep = Boolean(connectedConnection);
const isAzure = provider === "azure";
const isBedrock = provider === "bedrock";
const isVertex = provider === "vertex_ai";
const [currentDraft, setCurrentDraft] = useState<ConnectionDraft>({
base_url: null,
api_key: null,
extra: {},
});
const [canSubmit, setCanSubmit] = useState(false);
const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => {
setCurrentDraft(draft);
setCanSubmit(nextCanSubmit);
}, []);
const formProps: ProviderConnectFormProps = {
provider,
defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url),
baseUrlRequired: Boolean(selectedProvider?.base_url_required),
isPending,
onCancel: () => onOpenChange(false),
onSubmit,
onDraftChange: handleDraftChange,
};
const modelDescription = (() => {
if (isAzure) {
return "Select the models to enable for Azure OpenAI";
}
if (isBedrock) {
return "Select the models to enable for Amazon Bedrock";
}
if (isVertex) {
return "Select the models to enable for Gemini";
}
return "Select the models to enable for this provider";
})();
const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit);
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent
className={`flex max-h-[90vh] ${
isModelSelectionStep ? "max-w-2xl" : "max-w-xl"
} flex-col overflow-hidden bg-popover p-0 text-popover-foreground`}
>
<DialogContent className="flex h-[85vh] max-h-[760px] min-h-[640px] max-w-2xl flex-col overflow-hidden bg-popover p-0 text-popover-foreground">
<DialogHeader className="shrink-0 border-b px-6 py-5">
<div className="flex items-center gap-3">
{providerIcon(provider, "size-5")}
<div>
<DialogTitle>
{isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`}
</DialogTitle>
<DialogDescription>
{isModelSelectionStep
? selectedProvider?.discovery === "static"
? "Choose from known model IDs or add one manually."
: "Choose which discovered models should be available in this search space."
: meta.subtitle}
</DialogDescription>
<DialogTitle>Connect {meta.name}</DialogTitle>
<DialogDescription>{meta.subtitle}</DialogDescription>
</div>
</div>
</DialogHeader>
{isModelSelectionStep ? (
<>
<div className="min-h-0 flex-1 overflow-y-auto px-6 py-5">
<ModelsSelectionPanel
models={connectModels}
description={
selectedProvider?.discovery === "static"
? "These are known model IDs for this provider. Select the ones to make available."
: "Select models to make available for this provider."
}
emptyMessage={
isDiscoveringModels
? "Discovering models..."
: "No models found. You can refresh discovery or add a model ID manually."
}
isRefreshing={isDiscoveringModels}
isAddingManual={isAddingManualModel}
isUpdatingModel={isUpdatingModel}
isBulkUpdating={isBulkUpdatingModels}
refreshLabel={`Refresh ${meta.name} models`}
onRefresh={onRefreshModels}
onAddManual={onAddManualModel}
onToggleModel={onToggleModel}
onBulkToggle={onBulkToggleModels}
/>
</div>
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
<Button onClick={onDone}>Done</Button>
</DialogFooter>
</>
) : (
<div className="overflow-y-auto px-6 py-5">
{provider === "azure" ? (
<AzureConnectForm {...formProps} />
) : provider === "bedrock" ? (
<BedrockConnectForm {...formProps} />
) : provider === "vertex_ai" ? (
<VertexConnectForm {...formProps} />
) : (
<DefaultConnectForm {...formProps} />
)}
</div>
)}
<div className="min-h-0 flex-1 space-y-5 overflow-y-auto px-6 py-5">
{provider === "azure" ? (
<AzureConnectForm {...formProps} />
) : provider === "bedrock" ? (
<BedrockConnectForm {...formProps} />
) : provider === "vertex_ai" ? (
<VertexConnectForm {...formProps} />
) : (
<DefaultConnectForm {...formProps} />
)}
<Separator className="bg-muted-foreground/20" />
<ModelsSelectionPanel
models={previewModels}
description={modelDescription}
isRefreshing={isPreviewingModels}
refreshLabel={`Refresh ${meta.name} models`}
onRefresh={canRefreshModels ? () => onPreviewModels?.(currentDraft) : undefined}
onAddManual={onAddPreviewModel}
onToggleModel={onTogglePreviewModel}
onBulkToggle={onBulkTogglePreviewModels}
/>
</div>
<ConnectFormFooter
onCancel={() => onOpenChange(false)}
onSubmit={() => onSubmit(currentDraft)}
canSubmit={canSubmit}
isPending={isPending}
/>
</DialogContent>
</Dialog>
);

View file

@ -133,7 +133,5 @@ export interface ProviderConnectFormProps {
provider: string;
defaultBaseUrl: string;
baseUrlRequired: boolean;
isPending: boolean;
onCancel: () => void;
onSubmit: (draft: ConnectionDraft) => void;
onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void;
}

View file

@ -1,4 +1,4 @@
import { useState } from "react";
import { useEffect, useState } from "react";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
@ -8,7 +8,6 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { ConnectFormFooter } from "./connect-fields";
import {
type ProviderConnectFormProps,
VERTEX_AUTH_SERVICE_ACCOUNT,
@ -21,7 +20,7 @@ import {
* credentials JSON file (read into a string); workload identity collects a
* project id. Credentials ride along in `extra.litellm_params`.
*/
export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) {
export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) {
const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT);
const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION);
const [credentials, setCredentials] = useState("");
@ -35,7 +34,7 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon
setCredentials(await file.text());
}
function handleSubmit() {
useEffect(() => {
const params: Record<string, string> = {};
if (location) params.vertex_location = location;
if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) {
@ -43,85 +42,77 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon
} else if (project) {
params.vertex_project = project;
}
onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } });
}
onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit);
}, [authMethod, canSubmit, credentials, location, onDraftChange, project]);
return (
<>
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>Authentication Method</Label>
<Select value={authMethod} onValueChange={setAuthMethod}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value={VERTEX_AUTH_SERVICE_ACCOUNT}>Service Account JSON</SelectItem>
<SelectItem value={VERTEX_AUTH_WORKLOAD_IDENTITY}>Workload Identity (GKE)</SelectItem>
</SelectContent>
</Select>
</div>
<div className="flex flex-col gap-2">
<Label>Google Cloud Region Name</Label>
<Input
value={location}
onChange={(event) => setLocation(event.target.value)}
placeholder={VERTEX_DEFAULT_LOCATION}
/>
<p className="text-xs text-muted-foreground">
Region where your Google Vertex AI models are hosted.
</p>
</div>
{authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? (
<div className="flex flex-col gap-2">
<Label>Service Account JSON</Label>
<Input
id="vertex-service-account-json"
type="file"
accept="application/json,.json"
className="sr-only"
onChange={(event) => handleCredentialsFile(event.target.files?.[0])}
/>
<Label
htmlFor="vertex-service-account-json"
className="flex min-h-28 cursor-pointer flex-col items-center justify-center gap-2 rounded-lg border-2 border-dashed border-muted-foreground/40 bg-muted/20 px-4 py-6 text-center transition-colors hover:border-muted-foreground/70 hover:bg-muted/40"
>
<span className="text-sm font-medium">
{credentials ? "Service account JSON selected" : "Upload service account JSON"}
</span>
<span className="text-xs text-muted-foreground">
Choose a .json file from Google Cloud
</span>
</Label>
<p className="text-xs text-muted-foreground">
{credentials
? "Credentials file loaded."
: "Attach your service account key JSON from Google Cloud."}
</p>
</div>
) : (
<div className="flex flex-col gap-2">
<Label>GCP Project ID</Label>
<Input
value={project}
onChange={(event) => setProject(event.target.value)}
placeholder="my-vertex-project"
/>
<p className="text-xs text-muted-foreground">
The GCP project where Vertex AI is enabled.
</p>
</div>
)}
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>Authentication Method</Label>
<Select value={authMethod} onValueChange={setAuthMethod}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value={VERTEX_AUTH_SERVICE_ACCOUNT}>Service Account JSON</SelectItem>
<SelectItem value={VERTEX_AUTH_WORKLOAD_IDENTITY}>Workload Identity (GKE)</SelectItem>
</SelectContent>
</Select>
</div>
<div className="flex flex-col gap-2">
<Label>Google Cloud Region Name</Label>
<Input
value={location}
onChange={(event) => setLocation(event.target.value)}
placeholder={VERTEX_DEFAULT_LOCATION}
/>
<p className="text-xs text-muted-foreground">
Add Vertex AI model IDs from the provider&apos;s settings after connecting.
Region where your Google Vertex AI models are hosted.
</p>
</div>
<ConnectFormFooter
onCancel={onCancel}
onSubmit={handleSubmit}
canSubmit={canSubmit}
isPending={isPending}
/>
</>
{authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? (
<div className="flex flex-col gap-2">
<Label>Service Account JSON</Label>
<Input
id="vertex-service-account-json"
type="file"
accept="application/json,.json"
className="sr-only"
onChange={(event) => handleCredentialsFile(event.target.files?.[0])}
/>
<Label
htmlFor="vertex-service-account-json"
className="flex min-h-28 cursor-pointer flex-col items-center justify-center gap-2 rounded-lg border-2 border-dashed border-muted-foreground/40 bg-muted/20 px-4 py-6 text-center transition-colors hover:border-muted-foreground/70 hover:bg-muted/40"
>
<span className="text-sm font-medium">
{credentials ? "Service account JSON selected" : "Upload service account JSON"}
</span>
<span className="text-xs text-muted-foreground">
Choose a .json file from Google Cloud
</span>
</Label>
<p className="text-xs text-muted-foreground">
{credentials
? "Credentials file loaded."
: "Attach your service account key JSON from Google Cloud."}
</p>
</div>
) : (
<div className="flex flex-col gap-2">
<Label>GCP Project ID</Label>
<Input
value={project}
onChange={(event) => setProject(event.target.value)}
placeholder="my-vertex-project"
/>
<p className="text-xs text-muted-foreground">
The GCP project where Vertex AI is enabled.
</p>
</div>
)}
<p className="text-xs text-muted-foreground">
Add Vertex AI model IDs from the provider&apos;s settings after connecting.
</p>
</div>
);
}

View file

@ -40,6 +40,21 @@ export const connectionRead = z.object({
created_at: z.string().nullable().optional(),
});
export const modelSelection = z.object({
model_id: z.string().min(1),
display_name: z.string().nullable().optional(),
source: z.union([modelSourceEnum, z.string()]).default("DISCOVERED"),
supports_chat: z.boolean().nullable().optional(),
max_input_tokens: z.number().nullable().optional(),
supports_image_input: z.boolean().nullable().optional(),
supports_tools: z.boolean().nullable().optional(),
supports_image_generation: z.boolean().nullable().optional(),
enabled: z.boolean().default(false),
metadata: z.record(z.string(), z.any()).default({}),
});
export const modelPreviewRead = modelSelection;
export const connectionCreateRequest = z.object({
provider: z.string().min(1),
base_url: z.string().nullable().optional(),
@ -48,6 +63,7 @@ export const connectionCreateRequest = z.object({
scope: connectionScopeEnum.default("SEARCH_SPACE"),
search_space_id: z.number().nullable().optional(),
enabled: z.boolean().default(true),
models: z.array(modelSelection).default([]),
});
export const connectionUpdateRequest = z.object({
@ -105,9 +121,12 @@ export const modelProviderListResponse = z.array(modelProviderRead);
export const connectionListResponse = z.array(connectionRead);
export const modelListResponse = z.array(modelRead);
export const modelPreviewListResponse = z.array(modelPreviewRead);
export type ConnectionScope = z.infer<typeof connectionScopeEnum>;
export type ModelRead = z.infer<typeof modelRead>;
export type ModelPreviewRead = z.infer<typeof modelPreviewRead>;
export type ModelSelection = z.infer<typeof modelSelection>;
export type ConnectionRead = z.infer<typeof connectionRead>;
export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>;
export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>;

View file

@ -7,6 +7,7 @@ import {
connectionRead,
connectionUpdateRequest,
type ModelCreateRequest,
type ModelPreviewRead,
type ModelProviderRead,
type ModelRead,
type ModelRoles,
@ -14,6 +15,7 @@ import {
type ModelUpdateRequest,
modelCreateRequest,
modelListResponse,
modelPreviewListResponse,
modelProviderListResponse,
modelRead,
modelRoles,
@ -76,6 +78,20 @@ class ModelConnectionsApiService {
return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse);
};
previewModels = async (request: ConnectionCreateRequest): Promise<ModelPreviewRead[]> => {
const parsed = connectionCreateRequest.safeParse(request);
if (!parsed.success) {
throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", "));
}
return baseApiService.post(
`/api/v1/model-connections/discover-preview`,
modelPreviewListResponse,
{
body: parsed.data,
}
);
};
addManualModel = async (
connectionId: number,
request: ModelCreateRequest