feat(model-connections): enhance model connection functionality with preview and selection features

This commit is contained in:
Anish Sarkar 2026-06-12 22:41:21 +05:30
parent 356f0e56c5
commit 407f2a9612
20 changed files with 630 additions and 429 deletions

View file

@ -21,10 +21,12 @@ from app.schemas import (
ConnectionRead, ConnectionRead,
ConnectionUpdate, ConnectionUpdate,
ModelCreate, ModelCreate,
ModelPreviewRead,
ModelProviderRead, ModelProviderRead,
ModelRead, ModelRead,
ModelRolesRead, ModelRolesRead,
ModelRolesUpdate, ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate, ModelsBulkUpdate,
ModelUpdate, ModelUpdate,
VerifyConnectionResponse, VerifyConnectionResponse,
@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model) return ModelRead.model_validate(model)
def _preview_model_read(item: dict) -> ModelPreviewRead:
return ModelPreviewRead(
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item.get("source", ModelSource.DISCOVERED),
supports_chat=item.get("supports_chat"),
max_input_tokens=item.get("max_input_tokens"),
supports_image_input=item.get("supports_image_input"),
supports_tools=item.get("supports_tools"),
supports_image_generation=item.get("supports_image_generation"),
enabled=item.get("enabled", False),
metadata=item.get("metadata") or item.get("catalog") or {},
)
def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead:
if isinstance(conn, dict): if isinstance(conn, dict):
payload = { payload = {
@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None:
model.supports_image_generation = facts.get("supports_image_generation") model.supports_image_generation = facts.get("supports_image_generation")
def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model:
source = (
selection.source
if isinstance(selection.source, ModelSource)
else ModelSource(selection.source)
)
model = Model(
connection_id=conn.id,
model_id=selection.model_id.strip(),
display_name=selection.display_name,
source=source,
capabilities_override={},
enabled=selection.enabled,
catalog=selection.metadata,
)
_apply_model_facts(model, selection.model_dump())
return model
def _default_model_for(models: list[Model], capability: str) -> int | None: def _default_model_for(models: list[Model], capability: str) -> int | None:
for model in models: for model in models:
if model.enabled and has_capability(model, capability): if model.enabled and has_capability(model, capability):
@ -226,7 +262,7 @@ async def create_connection(
Permission.LLM_CONFIGS_CREATE.value, Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space", "You don't have permission to create model connections in this search space",
) )
payload = data.model_dump(exclude={"search_space_id"}) payload = data.model_dump(exclude={"search_space_id", "models"})
conn = Connection( conn = Connection(
**payload, **payload,
@ -234,9 +270,51 @@ async def create_connection(
user_id=user.id, user_id=user.id,
) )
session.add(conn) session.add(conn)
await session.flush()
seen_model_ids: set[str] = set()
for selection in data.models:
model_id = selection.model_id.strip()
if not model_id or model_id in seen_model_ids:
continue
seen_model_ids.add(model_id)
session.add(_selection_to_model(conn, selection))
await session.commit() await session.commit()
await session.refresh(conn) conn = await _load_connection(session, conn.id)
return _connection_read(conn, []) await _default_unset_roles(session, conn, list(conn.models))
await session.commit()
conn = await _load_connection(session, conn.id)
return _connection_read(conn, list(conn.models))
@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead])
async def preview_connection_models(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
draft = Connection(
provider=data.provider,
base_url=data.base_url,
api_key=data.api_key,
extra=data.extra or {},
scope=data.scope,
enabled=data.enabled,
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
user_id=user.id,
)
discovered = await discover_models(draft)
return [_preview_model_read(item) for item in discovered]
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead) @router.put("/model-connections/{connection_id}", response_model=ConnectionRead)

View file

@ -49,10 +49,12 @@ from .model_connections import (
ConnectionRead, ConnectionRead,
ConnectionUpdate, ConnectionUpdate,
ModelCreate, ModelCreate,
ModelPreviewRead,
ModelProviderRead, ModelProviderRead,
ModelRead, ModelRead,
ModelRolesRead, ModelRolesRead,
ModelRolesUpdate, ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate, ModelsBulkUpdate,
ModelUpdate, ModelUpdate,
VerifyConnectionResponse, VerifyConnectionResponse,

View file

@ -48,6 +48,32 @@ class ConnectionRead(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class ModelSelection(BaseModel):
model_id: str = Field(..., max_length=255)
display_name: str | None = Field(None, max_length=255)
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ModelPreviewRead(BaseModel):
model_id: str
display_name: str | None = None
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ConnectionCreate(BaseModel): class ConnectionCreate(BaseModel):
provider: str = Field(..., max_length=100) provider: str = Field(..., max_length=100)
base_url: str | None = Field(None, max_length=500) base_url: str | None = Field(None, max_length=500)
@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel):
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
search_space_id: int | None = None search_space_id: int | None = None
enabled: bool = True enabled: bool = True
models: list[ModelSelection] = Field(default_factory=list)
class ConnectionUpdate(BaseModel): class ConnectionUpdate(BaseModel):

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
import anyio
import httpx import httpx
import litellm import litellm
@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
return results return results
async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
params = (conn.extra or {}).get("litellm_params", {})
region_name = params.get("aws_region_name")
if not region_name:
return []
def list_models() -> list[dict[str, Any]]:
import boto3
client_kwargs: dict[str, str] = {"region_name": region_name}
if params.get("aws_access_key_id"):
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
if params.get("aws_secret_access_key"):
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
client = boto3.client("bedrock", **client_kwargs)
response = client.list_foundation_models()
results: list[dict[str, Any]] = []
for item in response.get("modelSummaries", []):
model_id = item.get("modelId")
if not model_id:
continue
input_modalities = set(item.get("inputModalities") or [])
output_modalities = set(item.get("outputModalities") or [])
results.append(
{
"model_id": model_id,
"display_name": item.get("modelName") or model_id,
"source": ModelSource.DISCOVERED,
"supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities,
"supports_image_input": "IMAGE" in input_modalities,
"supports_tools": None,
"supports_image_generation": "IMAGE" in output_modalities,
"max_input_tokens": None,
"metadata": item,
}
)
return results
return await anyio.to_thread.run_sync(list_models)
async def discover_models(conn: Connection) -> list[dict[str, Any]]: async def discover_models(conn: Connection) -> list[dict[str, Any]]:
allowlist = _allowlist(conn) allowlist = _allowlist(conn)
spec = spec_for(conn.provider) spec = spec_for(conn.provider)
@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
results = await _discover_anthropic_models(conn) results = await _discover_anthropic_models(conn)
elif spec.discovery == "openai_models": elif spec.discovery == "openai_models":
results = await _discover_openai_shaped_models(conn, conn.base_url) results = await _discover_openai_shaped_models(conn, conn.base_url)
elif spec.discovery == "bedrock_models":
results = await _discover_bedrock_models(conn)
elif spec.discovery == "static": elif spec.discovery == "static":
results = _litellm_static_models(conn) results = _litellm_static_models(conn)
else: else:

View file

@ -21,6 +21,7 @@ DiscoveryKind = Literal[
"ollama", "ollama",
"openai_models", "openai_models",
"anthropic_models", "anthropic_models",
"bedrock_models",
"openrouter", "openrouter",
"static", "static",
"none", "none",
@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = {
Transport.NATIVE, "vertex_ai", "static", None, False, "native" Transport.NATIVE, "vertex_ai", "static", None, False, "native"
), ),
"bedrock": ProviderSpec( "bedrock": ProviderSpec(
Transport.NATIVE, "bedrock", "static", None, False, "native" Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
), ),
"openrouter": ProviderSpec( "openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE, Transport.OPENAI_COMPATIBLE,

View file

@ -4,6 +4,7 @@ import type {
ConnectionCreateRequest, ConnectionCreateRequest,
ConnectionUpdateRequest, ConnectionUpdateRequest,
ModelCreateRequest, ModelCreateRequest,
ModelPreviewRead,
ModelRead, ModelRead,
ModelRoles, ModelRoles,
ModelsBulkUpdateRequest, ModelsBulkUpdateRequest,
@ -103,6 +104,16 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => {
}; };
}); });
export const previewConnectionModelsMutationAtom = atomWithMutation(() => {
return {
mutationKey: ["model-connections", "discover-preview"],
mutationFn: (request: ConnectionCreateRequest) =>
modelConnectionsApiService.previewModels(request),
onSuccess: (_models: ModelPreviewRead[]) => {},
onError: (error: Error) => toast.error(error.message || "Failed to discover models"),
};
});
export const addManualModelMutationAtom = atomWithMutation((get) => { export const addManualModelMutationAtom = atomWithMutation((get) => {
const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
return { return {

View file

@ -2,7 +2,7 @@
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useAtom, useAtomValue } from "jotai"; import { useAtom, useAtomValue } from "jotai";
import { RefreshCcw, Workflow } from "lucide-react"; import { RefreshCw, Workflow } from "lucide-react";
import { useCallback } from "react"; import { useCallback } from "react";
import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom"; import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom";
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
@ -112,7 +112,7 @@ export function ActionLogDialog() {
className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground" className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground"
aria-label="Refresh action log" aria-label="Refresh action log"
> >
<RefreshCcw className={isFetching ? "size-3.5 animate-spin" : "size-3.5"} /> <RefreshCw className={isFetching ? "size-3.5 animate-spin" : "size-3.5"} />
</Button> </Button>
<div className="flex min-h-0 flex-1 flex-col overflow-y-auto scrollbar-thin"> <div className="flex min-h-0 flex-1 flex-col overflow-y-auto scrollbar-thin">

View file

@ -4,12 +4,9 @@ import { useAtom, useAtomValue } from "jotai";
import { CheckCircle2, Trash2, XCircle } from "lucide-react"; import { CheckCircle2, Trash2, XCircle } from "lucide-react";
import { useState } from "react"; import { useState } from "react";
import { import {
addManualModelMutationAtom,
bulkUpdateModelsMutationAtom,
createModelConnectionMutationAtom, createModelConnectionMutationAtom,
deleteModelConnectionMutationAtom, deleteModelConnectionMutationAtom,
discoverConnectionModelsMutationAtom, previewConnectionModelsMutationAtom,
updateModelMutationAtom,
updateModelRolesMutationAtom, updateModelRolesMutationAtom,
} from "@/atoms/model-connections/model-connections-mutation.atoms"; } from "@/atoms/model-connections/model-connections-mutation.atoms";
import { import {
@ -41,9 +38,13 @@ import {
SelectValue, SelectValue,
} from "@/components/ui/select"; } from "@/components/ui/select";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; import type {
ConnectionRead,
ModelRead,
ModelSelection,
} from "@/contracts/types/model-connections.types";
import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog"; import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog";
import { capability, modelLabel } from "./model-connections/model-utils"; import { capability, modelLabel, type SelectableModel } from "./model-connections/model-utils";
import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog"; import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog";
import { import {
type ConnectionDraft, type ConnectionDraft,
@ -154,16 +155,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: providers = [] }] = useAtom(modelProvidersAtom);
const [{ data: roles }] = useAtom(modelRolesAtom); const [{ data: roles }] = useAtom(modelRolesAtom);
const createConnection = useAtomValue(createModelConnectionMutationAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom);
const addManualModel = useAtomValue(addManualModelMutationAtom); const previewModels = useAtomValue(previewConnectionModelsMutationAtom);
const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom);
const updateModel = useAtomValue(updateModelMutationAtom);
const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom);
const updateRoles = useAtomValue(updateModelRolesMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom);
const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false);
const [provider, setProvider] = useState("openai_compatible"); const [provider, setProvider] = useState("openai_compatible");
const [connectedConnection, setConnectedConnection] = useState<ConnectionRead | null>(null); const [connectModels, setConnectModels] = useState<ModelSelection[]>([]);
const [connectModels, setConnectModels] = useState<ModelRead[]>([]);
const selectedProvider = providers.find((item) => item.provider === provider); const selectedProvider = providers.find((item) => item.provider === provider);
const sortedProviders = [...providers].sort((left, right) => { const sortedProviders = [...providers].sort((left, right) => {
@ -185,7 +182,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); const imageModels = enabledModels.filter((model) => capability(model, "image_gen"));
function resetConnectState() { function resetConnectState() {
setConnectedConnection(null);
setConnectModels([]); setConnectModels([]);
} }
@ -196,15 +192,48 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
} }
} }
function replaceConnectModels(updatedModels: ModelRead[]) { function toModelSelection(model: SelectableModel): ModelSelection {
setConnectModels((current) => return {
current.map((model) => updatedModels.find((updated) => updated.id === model.id) ?? model) model_id: model.model_id,
); display_name: model.display_name,
source: model.source || "DISCOVERED",
supports_chat: model.supports_chat,
max_input_tokens: model.max_input_tokens,
supports_image_input: model.supports_image_input,
supports_tools: model.supports_tools,
supports_image_generation: model.supports_image_generation,
enabled: model.enabled,
metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}),
};
}
function mergePreviewModels(fetchedModels: SelectableModel[]) {
setConnectModels((current) => {
const currentById = new Map(current.map((model) => [model.model_id, model]));
return fetchedModels.map((model) => {
const prior = currentById.get(model.model_id);
return {
...toModelSelection(model),
enabled: prior ? prior.enabled : model.enabled,
};
});
});
} }
// Each provider connect form builds its own credential payload; the backend // Each provider connect form builds its own credential payload; the backend
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
function handleCreate(draft: ConnectionDraft) { function handleCreate(draft: ConnectionDraft) {
const models = [...connectModels];
if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) {
models.push({
model_id: draft.seedModelId,
display_name: draft.seedModelId,
source: "MANUAL",
enabled: true,
metadata: {},
});
}
createConnection.mutate( createConnection.mutate(
{ {
provider, provider,
@ -214,26 +243,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
search_space_id: searchSpaceId, search_space_id: searchSpaceId,
extra: draft.extra, extra: draft.extra,
enabled: true, enabled: true,
models,
}, },
{ {
onSuccess: (created) => { onSuccess: () => {
setConnectedConnection(created); setIsAddProviderOpen(false);
setConnectModels([]); resetConnectState();
if (draft.seedModelId) {
addManualModel.mutate(
{
connectionId: created.id,
data: { model_id: draft.seedModelId },
},
{
onSuccess: (model) => setConnectModels([model]),
}
);
} else {
discoverModels.mutate(created.id, {
onSuccess: (models) => setConnectModels(models),
});
}
}, },
} }
); );
@ -243,52 +258,72 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
resetConnectState(); resetConnectState();
setProvider(providerId); setProvider(providerId);
setIsAddProviderOpen(true); setIsAddProviderOpen(true);
if (providerId === "vertex_ai") {
previewModels.mutate(
{
provider: providerId,
base_url: null,
api_key: null,
scope: "SEARCH_SPACE",
search_space_id: searchSpaceId,
extra: {},
enabled: true,
models: [],
},
{
onSuccess: mergePreviewModels,
}
);
}
} }
function refreshConnectModels() { function refreshConnectModels(draft: ConnectionDraft) {
if (!connectedConnection) return; previewModels.mutate(
discoverModels.mutate(connectedConnection.id, { {
onSuccess: (models) => setConnectModels(models), provider,
}); base_url: draft.base_url,
api_key: draft.api_key,
scope: "SEARCH_SPACE",
search_space_id: searchSpaceId,
extra: draft.extra,
enabled: true,
models: [],
},
{
onSuccess: mergePreviewModels,
}
);
} }
function addConnectModel(modelId: string) { function addConnectModel(modelId: string) {
if (!connectedConnection) return; setConnectModels((current) => {
addManualModel.mutate( if (current.some((model) => model.model_id === modelId)) return current;
{ connectionId: connectedConnection.id, data: { model_id: modelId } }, return [
{ ...current,
onSuccess: (model) => setConnectModels((current) => [...current, model]), {
} model_id: modelId,
display_name: modelId,
source: "MANUAL",
enabled: true,
metadata: {},
},
];
});
}
function toggleConnectModel(model: SelectableModel, enabled: boolean) {
setConnectModels((current) =>
current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item))
); );
} }
function toggleConnectModel(model: ModelRead, enabled: boolean) { function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) {
updateModel.mutate( const modelIds = new Set(models.map((model) => model.model_id));
{ id: model.id, data: { enabled } }, setConnectModels((current) =>
{ current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item))
onSuccess: (updated) => replaceConnectModels([updated]),
}
); );
} }
function bulkToggleConnectModels(models: ModelRead[], enabled: boolean) {
if (!connectedConnection) return;
bulkUpdateModels.mutate(
{
connectionId: connectedConnection.id,
data: { model_ids: models.map((model) => model.id), enabled },
},
{
onSuccess: replaceConnectModels,
}
);
}
function finishConnectFlow() {
setIsAddProviderOpen(false);
resetConnectState();
}
function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) {
return ( return (
<SelectItem key={model.id} value={String(model.id)}> <SelectItem key={model.id} value={String(model.id)}>
@ -347,17 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
selectedProvider={selectedProvider} selectedProvider={selectedProvider}
isPending={createConnection.isPending} isPending={createConnection.isPending}
onSubmit={handleCreate} onSubmit={handleCreate}
connectedConnection={connectedConnection} previewModels={connectModels}
connectModels={connectModels} isPreviewingModels={previewModels.isPending}
isDiscoveringModels={discoverModels.isPending} onPreviewModels={refreshConnectModels}
isAddingManualModel={addManualModel.isPending} onAddPreviewModel={addConnectModel}
isUpdatingModel={updateModel.isPending} onTogglePreviewModel={toggleConnectModel}
isBulkUpdatingModels={bulkUpdateModels.isPending} onBulkTogglePreviewModels={bulkToggleConnectModels}
onRefreshModels={refreshConnectModels}
onAddManualModel={addConnectModel}
onToggleModel={toggleConnectModel}
onBulkToggleModels={bulkToggleConnectModels}
onDone={finishConnectFlow}
/> />
{connections.length > 0 ? ( {connections.length > 0 ? (

View file

@ -1,7 +1,7 @@
import { useState } from "react"; import { useEffect, useState } from "react";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label"; import { Label } from "@/components/ui/label";
import { ApiKeyField, ConnectFormFooter } from "./connect-fields"; import { ApiKeyField } from "./connect-fields";
import { import {
isValidAzureTargetUri, isValidAzureTargetUri,
type ProviderConnectFormProps, type ProviderConnectFormProps,
@ -12,48 +12,43 @@ import {
* Azure OpenAI connect form. The user pastes a single Target URI, which we parse * Azure OpenAI connect form. The user pastes a single Target URI, which we parse
* into api base, api version, and the deployment name (seeded as the model). * into api base, api version, and the deployment name (seeded as the model).
*/ */
export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) {
const [targetUri, setTargetUri] = useState(""); const [targetUri, setTargetUri] = useState("");
const [apiKey, setApiKey] = useState(""); const [apiKey, setApiKey] = useState("");
const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim());
function handleSubmit() { useEffect(() => {
const parsed = parseAzureTargetUri(targetUri); const parsed = parseAzureTargetUri(targetUri);
onSubmit({ onDraftChange(
base_url: parsed?.origin ?? null, {
api_key: apiKey || null, base_url: parsed?.origin ?? null,
extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, api_key: apiKey || null,
seedModelId: parsed?.deploymentName || undefined, extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {},
}); seedModelId: parsed?.deploymentName || undefined,
} },
canSubmit
);
}, [apiKey, canSubmit, onDraftChange, targetUri]);
return ( return (
<> <div className="flex flex-col gap-4">
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-2">
<div className="flex flex-col gap-2"> <Label>Target URI</Label>
<Label>Target URI</Label> <Input
<Input value={targetUri}
value={targetUri} onChange={(event) => setTargetUri(event.target.value)}
onChange={(event) => setTargetUri(event.target.value)} placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
<p className="text-xs text-muted-foreground">
Paste your endpoint target URI from Azure OpenAI (including API base, deployment name,
and API version).
</p>
</div>
<ApiKeyField
value={apiKey}
onChange={setApiKey}
placeholder="Paste your API key from Azure"
/> />
<p className="text-xs text-muted-foreground">
Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and
API version).
</p>
</div> </div>
<ConnectFormFooter <ApiKeyField
onCancel={onCancel} value={apiKey}
onSubmit={handleSubmit} onChange={setApiKey}
canSubmit={canSubmit} placeholder="Paste your API key from Azure"
isPending={isPending}
/> />
</> </div>
); );
} }

View file

@ -1,4 +1,4 @@
import { useState } from "react"; import { useEffect, useState } from "react";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label"; import { Label } from "@/components/ui/label";
import { import {
@ -8,7 +8,7 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from "@/components/ui/select"; } from "@/components/ui/select";
import { ConnectFormFooter } from "./connect-fields"; import { ApiKeyField } from "./connect-fields";
import { import {
AWS_REGION_OPTIONS, AWS_REGION_OPTIONS,
BEDROCK_AUTH_ACCESS_KEY, BEDROCK_AUTH_ACCESS_KEY,
@ -21,7 +21,7 @@ import {
* Amazon Bedrock connect form. Region + auth method drive which AWS credentials * Amazon Bedrock connect form. Region + auth method drive which AWS credentials
* are collected; everything rides along in `extra.litellm_params`. * are collected; everything rides along in `extra.litellm_params`.
*/ */
export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) {
const [region, setRegion] = useState(""); const [region, setRegion] = useState("");
const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY);
const [accessKeyId, setAccessKeyId] = useState(""); const [accessKeyId, setAccessKeyId] = useState("");
@ -39,7 +39,7 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo
return true; return true;
})(); })();
function handleSubmit() { useEffect(() => {
const params: Record<string, string> = { aws_region_name: region }; const params: Record<string, string> = { aws_region_name: region };
if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { if (authMethod === BEDROCK_AUTH_ACCESS_KEY) {
params.aws_access_key_id = accessKeyId; params.aws_access_key_id = accessKeyId;
@ -47,88 +47,74 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo
} else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) {
params.aws_bearer_token_bedrock = bearerToken; params.aws_bearer_token_bedrock = bearerToken;
} }
onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit);
} }, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]);
return ( return (
<> <div className="flex flex-col gap-4">
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-2">
<div className="flex flex-col gap-2"> <Label>AWS Region</Label>
<Label>AWS Region</Label> <Select value={region || undefined} onValueChange={setRegion}>
<Select value={region || undefined} onValueChange={setRegion}> <SelectTrigger>
<SelectTrigger> <SelectValue placeholder="Select a region" />
<SelectValue placeholder="Select a region" /> </SelectTrigger>
</SelectTrigger> <SelectContent>
<SelectContent> {AWS_REGION_OPTIONS.map((option) => (
{AWS_REGION_OPTIONS.map((option) => ( <SelectItem key={option} value={option}>
<SelectItem key={option} value={option}> {option}
{option} </SelectItem>
</SelectItem> ))}
))} </SelectContent>
</SelectContent> </Select>
</Select> </div>
</div> <div className="flex flex-col gap-2">
<div className="flex flex-col gap-2"> <Label>Authentication Method</Label>
<Label>Authentication Method</Label> <Select value={authMethod} onValueChange={setAuthMethod}>
<Select value={authMethod} onValueChange={setAuthMethod}> <SelectTrigger>
<SelectTrigger> <SelectValue />
<SelectValue /> </SelectTrigger>
</SelectTrigger> <SelectContent>
<SelectContent> <SelectItem value={BEDROCK_AUTH_IAM}>Environment IAM Role</SelectItem>
<SelectItem value={BEDROCK_AUTH_IAM}>Environment IAM Role</SelectItem> <SelectItem value={BEDROCK_AUTH_ACCESS_KEY}>Access Key</SelectItem>
<SelectItem value={BEDROCK_AUTH_ACCESS_KEY}>Access Key</SelectItem> <SelectItem value={BEDROCK_AUTH_LONG_TERM_API_KEY}>Long-term API Key</SelectItem>
<SelectItem value={BEDROCK_AUTH_LONG_TERM_API_KEY}>Long-term API Key</SelectItem> </SelectContent>
</SelectContent> </Select>
</Select> </div>
</div> {authMethod === BEDROCK_AUTH_ACCESS_KEY ? (
{authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( <>
<>
<div className="flex flex-col gap-2">
<Label>AWS Access Key ID</Label>
<Input
value={accessKeyId}
onChange={(event) => setAccessKeyId(event.target.value)}
placeholder="AKIAIOSFODNN7EXAMPLE"
/>
</div>
<div className="flex flex-col gap-2">
<Label>AWS Secret Access Key</Label>
<Input
value={secretAccessKey}
onChange={(event) => setSecretAccessKey(event.target.value)}
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
type="password"
/>
</div>
</>
) : null}
{authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? (
<div className="flex flex-col gap-2"> <div className="flex flex-col gap-2">
<Label>Long-term API Key</Label> <Label>AWS Access Key ID</Label>
<Input <Input
value={bearerToken} value={accessKeyId}
onChange={(event) => setBearerToken(event.target.value)} onChange={(event) => setAccessKeyId(event.target.value)}
placeholder="Your long-term API key" placeholder="AKIAIOSFODNN7EXAMPLE"
type="password"
/> />
</div> </div>
) : null} <ApiKeyField
{authMethod === BEDROCK_AUTH_IAM ? ( value={secretAccessKey}
<p className="text-xs text-muted-foreground"> onChange={setSecretAccessKey}
SurfSense will use the IAM role attached to the environment it&apos;s running in to label="AWS Secret Access Key"
authenticate. placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
</p> />
) : null} </>
) : null}
{authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? (
<ApiKeyField
value={bearerToken}
onChange={setBearerToken}
label="Long-term API Key"
placeholder="Your long-term API key"
/>
) : null}
{authMethod === BEDROCK_AUTH_IAM ? (
<p className="text-xs text-muted-foreground"> <p className="text-xs text-muted-foreground">
Add Bedrock model IDs from the provider&apos;s settings after connecting. SurfSense will use the IAM role attached to the environment it&apos;s running in to
authenticate.
</p> </p>
</div> ) : null}
<ConnectFormFooter <p className="text-xs text-muted-foreground">
onCancel={onCancel} Add Bedrock model IDs from the provider&apos;s settings after connecting.
onSubmit={handleSubmit} </p>
canSubmit={canSubmit} </div>
isPending={isPending}
/>
</>
); );
} }

View file

@ -1,3 +1,5 @@
import { Eye, EyeOff } from "lucide-react";
import { useState } from "react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { DialogFooter } from "@/components/ui/dialog"; import { DialogFooter } from "@/components/ui/dialog";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
@ -43,15 +45,31 @@ export function ApiKeyField({
label = "API Key", label = "API Key",
placeholder = "API key", placeholder = "API key",
}: ApiKeyFieldProps) { }: ApiKeyFieldProps) {
const [showApiKey, setShowApiKey] = useState(false);
return ( return (
<div className="flex flex-col gap-2"> <div className="flex flex-col gap-2">
<Label>{label}</Label> <Label>{label}</Label>
<Input <div className="relative">
value={value} <Input
onChange={(event) => onChange(event.target.value)} value={value}
placeholder={placeholder} onChange={(event) => onChange(event.target.value)}
type="password" placeholder={placeholder}
/> type={showApiKey ? "text" : "password"}
className="pr-11"
/>
<Button
type="button"
variant="ghost"
size="icon"
className="absolute top-1/2 right-1 size-8 -translate-y-1/2 text-muted-foreground"
onClick={() => setShowApiKey((current) => !current)}
disabled={!value}
aria-label={showApiKey ? "Hide API key" : "Show API key"}
>
{showApiKey ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</Button>
</div>
</div> </div>
); );
} }
@ -71,7 +89,7 @@ export function ConnectFormFooter({
isPending, isPending,
}: ConnectFormFooterProps) { }: ConnectFormFooterProps) {
return ( return (
<DialogFooter className="mt-6"> <DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
<Button variant="secondary" onClick={onCancel}> <Button variant="secondary" onClick={onCancel}>
Cancel Cancel
</Button> </Button>

View file

@ -25,8 +25,8 @@ import { Separator } from "@/components/ui/separator";
import type { import type {
ConnectionRead, ConnectionRead,
ConnectionUpdateRequest, ConnectionUpdateRequest,
ModelRead,
} from "@/contracts/types/model-connections.types"; } from "@/contracts/types/model-connections.types";
import type { SelectableModel } from "./model-utils";
import { ModelsSelectionPanel } from "./models-selection-panel"; import { ModelsSelectionPanel } from "./models-selection-panel";
import { providerIcon } from "./provider-metadata"; import { providerIcon } from "./provider-metadata";
@ -101,17 +101,22 @@ export function ConnectionSettingsDialog({
}); });
} }
function handleToggleModel(model: ModelRead, enabled: boolean) { function handleToggleModel(model: SelectableModel, enabled: boolean) {
if (typeof model.id !== "number") return;
updateModel.mutate({ updateModel.mutate({
id: model.id, id: model.id,
data: { enabled }, data: { enabled },
}); });
} }
function handleBulkToggle(models: ModelRead[], enabled: boolean) { function handleBulkToggle(models: SelectableModel[], enabled: boolean) {
const modelIds = models
.map((model) => model.id)
.filter((id): id is number => typeof id === "number");
if (modelIds.length === 0) return;
bulkUpdateModels.mutate({ bulkUpdateModels.mutate({
connectionId: connection.id, connectionId: connection.id,
data: { model_ids: models.map((model) => model.id), enabled }, data: { model_ids: modelIds, enabled },
}); });
} }
@ -184,12 +189,7 @@ export function ConnectionSettingsDialog({
onChange={(event) => setAllowlistText(event.target.value)} onChange={(event) => setAllowlistText(event.target.value)}
placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro"
/> />
<Button <Button size="sm" onClick={saveAllowlist} disabled={updateConnection.isPending}>
variant="outline"
size="sm"
onClick={saveAllowlist}
disabled={updateConnection.isPending}
>
Save filter Save filter
</Button> </Button>
</div> </div>

View file

@ -1,5 +1,5 @@
import { useState } from "react"; import { useEffect, useState } from "react";
import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields"; import { ApiBaseUrlField, ApiKeyField } from "./connect-fields";
import type { ProviderConnectFormProps } from "./provider-metadata"; import type { ProviderConnectFormProps } from "./provider-metadata";
/** /**
@ -11,41 +11,31 @@ export function DefaultConnectForm({
provider, provider,
defaultBaseUrl, defaultBaseUrl,
baseUrlRequired, baseUrlRequired,
isPending, onDraftChange,
onCancel,
onSubmit,
}: ProviderConnectFormProps) { }: ProviderConnectFormProps) {
const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); const [baseUrl, setBaseUrl] = useState(defaultBaseUrl);
const [apiKey, setApiKey] = useState(""); const [apiKey, setApiKey] = useState("");
const isOllama = provider === "ollama_chat"; const isOllama = provider === "ollama_chat";
const canSubmit = !(baseUrlRequired && !baseUrl.trim()); const canSubmit = !(baseUrlRequired && !baseUrl.trim());
function handleSubmit() { useEffect(() => {
onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }); onDraftChange({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }, canSubmit);
} }, [apiKey, baseUrl, canSubmit, onDraftChange]);
return ( return (
<> <div className="flex flex-col gap-4">
<div className="flex flex-col gap-4"> <ApiBaseUrlField
<ApiBaseUrlField value={baseUrl}
value={baseUrl} onChange={setBaseUrl}
onChange={setBaseUrl} optional={!baseUrlRequired}
optional={!baseUrlRequired} placeholder={defaultBaseUrl}
placeholder={defaultBaseUrl}
/>
<ApiKeyField
value={apiKey}
onChange={setApiKey}
label={isOllama ? "API Key (optional)" : "API Key"}
placeholder={isOllama ? "Optional for Ollama" : "API key"}
/>
</div>
<ConnectFormFooter
onCancel={onCancel}
onSubmit={handleSubmit}
canSubmit={canSubmit}
isPending={isPending}
/> />
</> <ApiKeyField
value={apiKey}
onChange={setApiKey}
label={isOllama ? "API Key (optional)" : "API Key"}
placeholder={isOllama ? "Optional for Ollama" : "API key"}
/>
</div>
); );
} }

View file

@ -1,4 +1,4 @@
import type { ModelRead } from "@/contracts/types/model-connections.types"; import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types";
export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; export type ModelCapabilityFilter = "chat" | "vision" | "image_gen";
@ -8,17 +8,22 @@ export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: stri
{ key: "image_gen", label: "Image" }, { key: "image_gen", label: "Image" },
]; ];
export function modelLabel(model: ModelRead) { export type SelectableModel = (ModelRead | ModelPreviewRead) & {
id?: number | string;
connection_id?: number;
};
export function modelLabel(model: SelectableModel) {
return model.display_name || model.model_id; return model.display_name || model.model_id;
} }
export function capability(model: ModelRead, key: ModelCapabilityFilter) { export function capability(model: SelectableModel, key: ModelCapabilityFilter) {
if (key === "chat") return Boolean(model.supports_chat); if (key === "chat") return Boolean(model.supports_chat);
if (key === "vision") return Boolean(model.supports_image_input); if (key === "vision") return Boolean(model.supports_image_input);
return Boolean(model.supports_image_generation); return Boolean(model.supports_image_generation);
} }
export function capabilityLabels(model: ModelRead) { export function capabilityLabels(model: SelectableModel) {
return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key))
.map((filter) => filter.label.toLowerCase()) .map((filter) => filter.label.toLowerCase())
.join(", "); .join(", ");

View file

@ -4,17 +4,17 @@ import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox"; import { Checkbox } from "@/components/ui/checkbox";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import type { ModelRead } from "@/contracts/types/model-connections.types";
import { import {
capability, capability,
capabilityLabels, capabilityLabels,
MODEL_CAPABILITY_FILTERS, MODEL_CAPABILITY_FILTERS,
type ModelCapabilityFilter, type ModelCapabilityFilter,
modelLabel, modelLabel,
type SelectableModel,
} from "./model-utils"; } from "./model-utils";
interface ModelsSelectionPanelProps { interface ModelsSelectionPanelProps {
models: ModelRead[]; models: SelectableModel[];
description?: string; description?: string;
emptyMessage?: string; emptyMessage?: string;
manualInputPlaceholder?: string; manualInputPlaceholder?: string;
@ -25,8 +25,8 @@ interface ModelsSelectionPanelProps {
isBulkUpdating?: boolean; isBulkUpdating?: boolean;
onRefresh?: () => void; onRefresh?: () => void;
onAddManual?: (modelId: string) => void; onAddManual?: (modelId: string) => void;
onToggleModel?: (model: ModelRead, enabled: boolean) => void; onToggleModel?: (model: SelectableModel, enabled: boolean) => void;
onBulkToggle?: (models: ModelRead[], enabled: boolean) => void; onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void;
} }
export function ModelsSelectionPanel({ export function ModelsSelectionPanel({
@ -166,7 +166,7 @@ export function ModelsSelectionPanel({
<div className="space-y-2"> <div className="space-y-2">
{filteredModels.map((model) => ( {filteredModels.map((model) => (
<div <div
key={model.id} key={model.id ?? model.model_id}
className="flex items-center gap-3 rounded-lg px-3 py-2 transition-colors hover:bg-background" className="flex items-center gap-3 rounded-lg px-3 py-2 transition-colors hover:bg-background"
> >
<Checkbox <Checkbox

View file

@ -1,20 +1,18 @@
import { Button } from "@/components/ui/button"; import { useCallback, useState } from "react";
import { import {
Dialog, Dialog,
DialogContent, DialogContent,
DialogDescription, DialogDescription,
DialogFooter,
DialogHeader, DialogHeader,
DialogTitle, DialogTitle,
} from "@/components/ui/dialog"; } from "@/components/ui/dialog";
import type { import { Separator } from "@/components/ui/separator";
ConnectionRead, import type { ModelProviderRead } from "@/contracts/types/model-connections.types";
ModelProviderRead,
ModelRead,
} from "@/contracts/types/model-connections.types";
import { AzureConnectForm } from "./azure-connect-form"; import { AzureConnectForm } from "./azure-connect-form";
import { BedrockConnectForm } from "./bedrock-connect-form"; import { BedrockConnectForm } from "./bedrock-connect-form";
import { ConnectFormFooter } from "./connect-fields";
import { DefaultConnectForm } from "./default-connect-form"; import { DefaultConnectForm } from "./default-connect-form";
import type { SelectableModel } from "./model-utils";
import { ModelsSelectionPanel } from "./models-selection-panel"; import { ModelsSelectionPanel } from "./models-selection-panel";
import { import {
type ConnectionDraft, type ConnectionDraft,
@ -32,17 +30,12 @@ interface ProviderConnectDialogProps {
selectedProvider?: ModelProviderRead; selectedProvider?: ModelProviderRead;
isPending: boolean; isPending: boolean;
onSubmit: (draft: ConnectionDraft) => void; onSubmit: (draft: ConnectionDraft) => void;
connectedConnection?: ConnectionRead | null; previewModels?: SelectableModel[];
connectModels?: ModelRead[]; isPreviewingModels?: boolean;
isDiscoveringModels?: boolean; onPreviewModels?: (draft: ConnectionDraft) => void;
isAddingManualModel?: boolean; onAddPreviewModel?: (modelId: string) => void;
isUpdatingModel?: boolean; onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void;
isBulkUpdatingModels?: boolean; onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void;
onRefreshModels?: () => void;
onAddManualModel?: (modelId: string) => void;
onToggleModel?: (model: ModelRead, enabled: boolean) => void;
onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void;
onDone?: () => void;
} }
/** /**
@ -57,97 +50,93 @@ export function ProviderConnectDialog({
selectedProvider, selectedProvider,
isPending, isPending,
onSubmit, onSubmit,
connectedConnection, previewModels = [],
connectModels = [], isPreviewingModels = false,
isDiscoveringModels = false, onPreviewModels,
isAddingManualModel = false, onAddPreviewModel,
isUpdatingModel = false, onTogglePreviewModel,
isBulkUpdatingModels = false, onBulkTogglePreviewModels,
onRefreshModels,
onAddManualModel,
onToggleModel,
onBulkToggleModels,
onDone,
}: ProviderConnectDialogProps) { }: ProviderConnectDialogProps) {
const meta = providerDisplay(provider); const meta = providerDisplay(provider);
const isModelSelectionStep = Boolean(connectedConnection); const isAzure = provider === "azure";
const isBedrock = provider === "bedrock";
const isVertex = provider === "vertex_ai";
const [currentDraft, setCurrentDraft] = useState<ConnectionDraft>({
base_url: null,
api_key: null,
extra: {},
});
const [canSubmit, setCanSubmit] = useState(false);
const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => {
setCurrentDraft(draft);
setCanSubmit(nextCanSubmit);
}, []);
const formProps: ProviderConnectFormProps = { const formProps: ProviderConnectFormProps = {
provider, provider,
defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url),
baseUrlRequired: Boolean(selectedProvider?.base_url_required), baseUrlRequired: Boolean(selectedProvider?.base_url_required),
isPending, onDraftChange: handleDraftChange,
onCancel: () => onOpenChange(false),
onSubmit,
}; };
const modelDescription = (() => {
if (isAzure) {
return "Select the models to enable for Azure OpenAI";
}
if (isBedrock) {
return "Select the models to enable for Amazon Bedrock";
}
if (isVertex) {
return "Select the models to enable for Gemini";
}
return "Select the models to enable for this provider";
})();
const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit);
return ( return (
<Dialog open={open} onOpenChange={onOpenChange}> <Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent <DialogContent className="flex h-[85vh] max-h-[760px] min-h-[640px] max-w-2xl flex-col overflow-hidden bg-popover p-0 text-popover-foreground">
className={`flex max-h-[90vh] ${
isModelSelectionStep ? "max-w-2xl" : "max-w-xl"
} flex-col overflow-hidden bg-popover p-0 text-popover-foreground`}
>
<DialogHeader className="shrink-0 border-b px-6 py-5"> <DialogHeader className="shrink-0 border-b px-6 py-5">
<div className="flex items-center gap-3"> <div className="flex items-center gap-3">
{providerIcon(provider, "size-5")} {providerIcon(provider, "size-5")}
<div> <div>
<DialogTitle> <DialogTitle>Connect {meta.name}</DialogTitle>
{isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`} <DialogDescription>{meta.subtitle}</DialogDescription>
</DialogTitle>
<DialogDescription>
{isModelSelectionStep
? selectedProvider?.discovery === "static"
? "Choose from known model IDs or add one manually."
: "Choose which discovered models should be available in this search space."
: meta.subtitle}
</DialogDescription>
</div> </div>
</div> </div>
</DialogHeader> </DialogHeader>
{isModelSelectionStep ? ( <div className="min-h-0 flex-1 space-y-5 overflow-y-auto px-6 py-5">
<> {provider === "azure" ? (
<div className="min-h-0 flex-1 overflow-y-auto px-6 py-5"> <AzureConnectForm {...formProps} />
<ModelsSelectionPanel ) : provider === "bedrock" ? (
models={connectModels} <BedrockConnectForm {...formProps} />
description={ ) : provider === "vertex_ai" ? (
selectedProvider?.discovery === "static" <VertexConnectForm {...formProps} />
? "These are known model IDs for this provider. Select the ones to make available." ) : (
: "Select models to make available for this provider." <DefaultConnectForm {...formProps} />
} )}
emptyMessage={
isDiscoveringModels <Separator className="bg-muted-foreground/20" />
? "Discovering models..."
: "No models found. You can refresh discovery or add a model ID manually." <ModelsSelectionPanel
} models={previewModels}
isRefreshing={isDiscoveringModels} description={modelDescription}
isAddingManual={isAddingManualModel} isRefreshing={isPreviewingModels}
isUpdatingModel={isUpdatingModel} refreshLabel={`Refresh ${meta.name} models`}
isBulkUpdating={isBulkUpdatingModels} onRefresh={canRefreshModels ? () => onPreviewModels?.(currentDraft) : undefined}
refreshLabel={`Refresh ${meta.name} models`} onAddManual={onAddPreviewModel}
onRefresh={onRefreshModels} onToggleModel={onTogglePreviewModel}
onAddManual={onAddManualModel} onBulkToggle={onBulkTogglePreviewModels}
onToggleModel={onToggleModel} />
onBulkToggle={onBulkToggleModels} </div>
/> <ConnectFormFooter
</div> onCancel={() => onOpenChange(false)}
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4"> onSubmit={() => onSubmit(currentDraft)}
<Button onClick={onDone}>Done</Button> canSubmit={canSubmit}
</DialogFooter> isPending={isPending}
</> />
) : (
<div className="overflow-y-auto px-6 py-5">
{provider === "azure" ? (
<AzureConnectForm {...formProps} />
) : provider === "bedrock" ? (
<BedrockConnectForm {...formProps} />
) : provider === "vertex_ai" ? (
<VertexConnectForm {...formProps} />
) : (
<DefaultConnectForm {...formProps} />
)}
</div>
)}
</DialogContent> </DialogContent>
</Dialog> </Dialog>
); );

View file

@ -133,7 +133,5 @@ export interface ProviderConnectFormProps {
provider: string; provider: string;
defaultBaseUrl: string; defaultBaseUrl: string;
baseUrlRequired: boolean; baseUrlRequired: boolean;
isPending: boolean; onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void;
onCancel: () => void;
onSubmit: (draft: ConnectionDraft) => void;
} }

View file

@ -1,4 +1,4 @@
import { useState } from "react"; import { useEffect, useState } from "react";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label"; import { Label } from "@/components/ui/label";
import { import {
@ -8,7 +8,6 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from "@/components/ui/select"; } from "@/components/ui/select";
import { ConnectFormFooter } from "./connect-fields";
import { import {
type ProviderConnectFormProps, type ProviderConnectFormProps,
VERTEX_AUTH_SERVICE_ACCOUNT, VERTEX_AUTH_SERVICE_ACCOUNT,
@ -21,7 +20,7 @@ import {
* credentials JSON file (read into a string); workload identity collects a * credentials JSON file (read into a string); workload identity collects a
* project id. Credentials ride along in `extra.litellm_params`. * project id. Credentials ride along in `extra.litellm_params`.
*/ */
export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) {
const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT);
const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION);
const [credentials, setCredentials] = useState(""); const [credentials, setCredentials] = useState("");
@ -35,7 +34,7 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon
setCredentials(await file.text()); setCredentials(await file.text());
} }
function handleSubmit() { useEffect(() => {
const params: Record<string, string> = {}; const params: Record<string, string> = {};
if (location) params.vertex_location = location; if (location) params.vertex_location = location;
if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) {
@ -43,85 +42,77 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon
} else if (project) { } else if (project) {
params.vertex_project = project; params.vertex_project = project;
} }
onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit);
} }, [authMethod, canSubmit, credentials, location, onDraftChange, project]);
return ( return (
<> <div className="flex flex-col gap-4">
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-2">
<div className="flex flex-col gap-2"> <Label>Authentication Method</Label>
<Label>Authentication Method</Label> <Select value={authMethod} onValueChange={setAuthMethod}>
<Select value={authMethod} onValueChange={setAuthMethod}> <SelectTrigger>
<SelectTrigger> <SelectValue />
<SelectValue /> </SelectTrigger>
</SelectTrigger> <SelectContent>
<SelectContent> <SelectItem value={VERTEX_AUTH_SERVICE_ACCOUNT}>Service Account JSON</SelectItem>
<SelectItem value={VERTEX_AUTH_SERVICE_ACCOUNT}>Service Account JSON</SelectItem> <SelectItem value={VERTEX_AUTH_WORKLOAD_IDENTITY}>Workload Identity (GKE)</SelectItem>
<SelectItem value={VERTEX_AUTH_WORKLOAD_IDENTITY}>Workload Identity (GKE)</SelectItem> </SelectContent>
</SelectContent> </Select>
</Select> </div>
</div> <div className="flex flex-col gap-2">
<div className="flex flex-col gap-2"> <Label>Google Cloud Region Name</Label>
<Label>Google Cloud Region Name</Label> <Input
<Input value={location}
value={location} onChange={(event) => setLocation(event.target.value)}
onChange={(event) => setLocation(event.target.value)} placeholder={VERTEX_DEFAULT_LOCATION}
placeholder={VERTEX_DEFAULT_LOCATION} />
/>
<p className="text-xs text-muted-foreground">
Region where your Google Vertex AI models are hosted.
</p>
</div>
{authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? (
<div className="flex flex-col gap-2">
<Label>Service Account JSON</Label>
<Input
id="vertex-service-account-json"
type="file"
accept="application/json,.json"
className="sr-only"
onChange={(event) => handleCredentialsFile(event.target.files?.[0])}
/>
<Label
htmlFor="vertex-service-account-json"
className="flex min-h-28 cursor-pointer flex-col items-center justify-center gap-2 rounded-lg border-2 border-dashed border-muted-foreground/40 bg-muted/20 px-4 py-6 text-center transition-colors hover:border-muted-foreground/70 hover:bg-muted/40"
>
<span className="text-sm font-medium">
{credentials ? "Service account JSON selected" : "Upload service account JSON"}
</span>
<span className="text-xs text-muted-foreground">
Choose a .json file from Google Cloud
</span>
</Label>
<p className="text-xs text-muted-foreground">
{credentials
? "Credentials file loaded."
: "Attach your service account key JSON from Google Cloud."}
</p>
</div>
) : (
<div className="flex flex-col gap-2">
<Label>GCP Project ID</Label>
<Input
value={project}
onChange={(event) => setProject(event.target.value)}
placeholder="my-vertex-project"
/>
<p className="text-xs text-muted-foreground">
The GCP project where Vertex AI is enabled.
</p>
</div>
)}
<p className="text-xs text-muted-foreground"> <p className="text-xs text-muted-foreground">
Add Vertex AI model IDs from the provider&apos;s settings after connecting. Region where your Google Vertex AI models are hosted.
</p> </p>
</div> </div>
<ConnectFormFooter {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? (
onCancel={onCancel} <div className="flex flex-col gap-2">
onSubmit={handleSubmit} <Label>Service Account JSON</Label>
canSubmit={canSubmit} <Input
isPending={isPending} id="vertex-service-account-json"
/> type="file"
</> accept="application/json,.json"
className="sr-only"
onChange={(event) => handleCredentialsFile(event.target.files?.[0])}
/>
<Label
htmlFor="vertex-service-account-json"
className="flex min-h-28 cursor-pointer flex-col items-center justify-center gap-2 rounded-lg border-2 border-dashed border-muted-foreground/40 bg-muted/20 px-4 py-6 text-center transition-colors hover:border-muted-foreground/70 hover:bg-muted/40"
>
<span className="text-sm font-medium">
{credentials ? "Service account JSON selected" : "Upload service account JSON"}
</span>
<span className="text-xs text-muted-foreground">
Choose a .json file from Google Cloud
</span>
</Label>
<p className="text-xs text-muted-foreground">
{credentials
? "Credentials file loaded."
: "Attach your service account key JSON from Google Cloud."}
</p>
</div>
) : (
<div className="flex flex-col gap-2">
<Label>GCP Project ID</Label>
<Input
value={project}
onChange={(event) => setProject(event.target.value)}
placeholder="my-vertex-project"
/>
<p className="text-xs text-muted-foreground">
The GCP project where Vertex AI is enabled.
</p>
</div>
)}
<p className="text-xs text-muted-foreground">
Add Vertex AI model IDs from the provider&apos;s settings after connecting.
</p>
</div>
); );
} }

View file

@ -40,6 +40,21 @@ export const connectionRead = z.object({
created_at: z.string().nullable().optional(), created_at: z.string().nullable().optional(),
}); });
export const modelSelection = z.object({
model_id: z.string().min(1),
display_name: z.string().nullable().optional(),
source: z.union([modelSourceEnum, z.string()]).default("DISCOVERED"),
supports_chat: z.boolean().nullable().optional(),
max_input_tokens: z.number().nullable().optional(),
supports_image_input: z.boolean().nullable().optional(),
supports_tools: z.boolean().nullable().optional(),
supports_image_generation: z.boolean().nullable().optional(),
enabled: z.boolean().default(false),
metadata: z.record(z.string(), z.any()).default({}),
});
export const modelPreviewRead = modelSelection;
export const connectionCreateRequest = z.object({ export const connectionCreateRequest = z.object({
provider: z.string().min(1), provider: z.string().min(1),
base_url: z.string().nullable().optional(), base_url: z.string().nullable().optional(),
@ -48,6 +63,7 @@ export const connectionCreateRequest = z.object({
scope: connectionScopeEnum.default("SEARCH_SPACE"), scope: connectionScopeEnum.default("SEARCH_SPACE"),
search_space_id: z.number().nullable().optional(), search_space_id: z.number().nullable().optional(),
enabled: z.boolean().default(true), enabled: z.boolean().default(true),
models: z.array(modelSelection).default([]),
}); });
export const connectionUpdateRequest = z.object({ export const connectionUpdateRequest = z.object({
@ -105,9 +121,12 @@ export const modelProviderListResponse = z.array(modelProviderRead);
export const connectionListResponse = z.array(connectionRead); export const connectionListResponse = z.array(connectionRead);
export const modelListResponse = z.array(modelRead); export const modelListResponse = z.array(modelRead);
export const modelPreviewListResponse = z.array(modelPreviewRead);
export type ConnectionScope = z.infer<typeof connectionScopeEnum>; export type ConnectionScope = z.infer<typeof connectionScopeEnum>;
export type ModelRead = z.infer<typeof modelRead>; export type ModelRead = z.infer<typeof modelRead>;
export type ModelPreviewRead = z.infer<typeof modelPreviewRead>;
export type ModelSelection = z.infer<typeof modelSelection>;
export type ConnectionRead = z.infer<typeof connectionRead>; export type ConnectionRead = z.infer<typeof connectionRead>;
export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>; export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>;
export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>; export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>;

View file

@ -7,6 +7,7 @@ import {
connectionRead, connectionRead,
connectionUpdateRequest, connectionUpdateRequest,
type ModelCreateRequest, type ModelCreateRequest,
type ModelPreviewRead,
type ModelProviderRead, type ModelProviderRead,
type ModelRead, type ModelRead,
type ModelRoles, type ModelRoles,
@ -14,6 +15,7 @@ import {
type ModelUpdateRequest, type ModelUpdateRequest,
modelCreateRequest, modelCreateRequest,
modelListResponse, modelListResponse,
modelPreviewListResponse,
modelProviderListResponse, modelProviderListResponse,
modelRead, modelRead,
modelRoles, modelRoles,
@ -76,6 +78,20 @@ class ModelConnectionsApiService {
return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse);
}; };
previewModels = async (request: ConnectionCreateRequest): Promise<ModelPreviewRead[]> => {
const parsed = connectionCreateRequest.safeParse(request);
if (!parsed.success) {
throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", "));
}
return baseApiService.post(
`/api/v1/model-connections/discover-preview`,
modelPreviewListResponse,
{
body: parsed.data,
}
);
};
addManualModel = async ( addManualModel = async (
connectionId: number, connectionId: number,
request: ModelCreateRequest request: ModelCreateRequest