mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
feat(model-connections): implement bulk model update endpoint and related schema changes
This commit is contained in:
parent
ad404b2dbc
commit
ced1bb85ed
7 changed files with 538 additions and 168 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
|
@ -25,6 +25,7 @@ from app.schemas import (
|
|||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelsBulkUpdate,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
|
|
@ -62,6 +63,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
|
|||
id=conn.id,
|
||||
provider=conn.provider,
|
||||
base_url=conn.base_url,
|
||||
api_key=conn.api_key,
|
||||
extra=conn.extra or {},
|
||||
scope=conn.scope,
|
||||
search_space_id=conn.search_space_id,
|
||||
|
|
@ -351,6 +353,33 @@ async def add_manual_model(
|
|||
return _model_read(model)
|
||||
|
||||
|
||||
@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead])
|
||||
async def bulk_update_models(
|
||||
connection_id: int,
|
||||
data: ModelsBulkUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
|
||||
model_ids = set(data.model_ids)
|
||||
await session.execute(
|
||||
update(Model)
|
||||
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
|
||||
.values(enabled=data.enabled)
|
||||
)
|
||||
await session.commit()
|
||||
session.expire_all()
|
||||
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
|
||||
.order_by(Model.id)
|
||||
)
|
||||
return [_model_read(model) for model in result.scalars().all()]
|
||||
|
||||
|
||||
@router.put("/models/{model_id}", response_model=ModelRead)
|
||||
async def update_model(
|
||||
model_id: int,
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from .model_connections import (
|
|||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelsBulkUpdate,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class ConnectionRead(BaseModel):
|
|||
id: int
|
||||
provider: str
|
||||
base_url: str | None = None
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
scope: ConnectionScope | str
|
||||
search_space_id: int | None = None
|
||||
|
|
@ -87,6 +88,11 @@ class ModelUpdate(BaseModel):
|
|||
capabilities_override: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ModelsBulkUpdate(BaseModel):
|
||||
model_ids: list[int] = Field(..., min_length=1, max_length=1000)
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ModelProviderRead(BaseModel):
|
||||
provider: str
|
||||
transport: str
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import type {
|
|||
ModelCreateRequest,
|
||||
ModelRead,
|
||||
ModelRoles,
|
||||
ModelsBulkUpdateRequest,
|
||||
ModelUpdateRequest,
|
||||
VerifyConnectionResponse,
|
||||
} from "@/contracts/types/model-connections.types";
|
||||
|
|
@ -127,6 +128,17 @@ export const updateModelMutationAtom = atomWithMutation((get) => {
|
|||
};
|
||||
});
|
||||
|
||||
export const bulkUpdateModelsMutationAtom = atomWithMutation((get) => {
|
||||
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
|
||||
return {
|
||||
mutationKey: ["models", "bulk-update"],
|
||||
mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelsBulkUpdateRequest }) =>
|
||||
modelConnectionsApiService.bulkUpdateModels(connectionId, data),
|
||||
onSuccess: () => invalidateModelConnections(searchSpaceId),
|
||||
onError: (error: Error) => toast.error(error.message || "Failed to update models"),
|
||||
};
|
||||
});
|
||||
|
||||
export const testModelMutationAtom = atomWithMutation((get) => {
|
||||
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,24 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom, useAtomValue } from "jotai";
|
||||
import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react";
|
||||
import {
|
||||
Check,
|
||||
CheckCircle2,
|
||||
ChevronsUpDown,
|
||||
Eye,
|
||||
EyeOff,
|
||||
RefreshCcw,
|
||||
Settings,
|
||||
Trash2,
|
||||
XCircle,
|
||||
} from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
addManualModelMutationAtom,
|
||||
bulkUpdateModelsMutationAtom,
|
||||
createModelConnectionMutationAtom,
|
||||
deleteModelConnectionMutationAtom,
|
||||
discoverConnectionModelsMutationAtom,
|
||||
testModelMutationAtom,
|
||||
updateModelConnectionMutationAtom,
|
||||
updateModelMutationAtom,
|
||||
updateModelRolesMutationAtom,
|
||||
|
|
@ -20,11 +30,41 @@ import {
|
|||
modelProvidersAtom,
|
||||
modelRolesAtom,
|
||||
} from "@/atoms/model-connections/model-connections-query.atoms";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "@/components/ui/command";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
|
|
@ -32,8 +72,14 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import type {
|
||||
ConnectionRead,
|
||||
ConnectionUpdateRequest,
|
||||
ModelRead,
|
||||
} from "@/contracts/types/model-connections.types";
|
||||
import { getProviderIcon } from "@/lib/provider-icons";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict
|
||||
// what the user can type — any OpenAI-compatible endpoint works.
|
||||
|
|
@ -69,6 +115,67 @@ const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[]
|
|||
{ key: "image_gen", label: "Image" },
|
||||
];
|
||||
|
||||
function UrlSuggestionCombobox({
|
||||
value,
|
||||
onChange,
|
||||
placeholder,
|
||||
}: {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
placeholder: string;
|
||||
}) {
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={open}
|
||||
className="w-full justify-between bg-transparent font-normal"
|
||||
>
|
||||
<span className={cn("truncate", !value && "text-muted-foreground")}>
|
||||
{value || placeholder}
|
||||
</span>
|
||||
<ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-[var(--radix-popover-trigger-width)] p-0" align="start">
|
||||
<Command className="bg-transparent">
|
||||
<CommandInput
|
||||
placeholder="Search or type URL..."
|
||||
value={value}
|
||||
onValueChange={onChange}
|
||||
/>
|
||||
<CommandList>
|
||||
<CommandEmpty>
|
||||
<span className="text-xs text-muted-foreground">Use the custom URL you typed</span>
|
||||
</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{URL_SUGGESTIONS.map((url) => (
|
||||
<CommandItem
|
||||
key={url}
|
||||
value={url}
|
||||
onSelect={() => {
|
||||
onChange(url);
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn("mr-2 h-4 w-4", value === url ? "opacity-100" : "opacity-0")}
|
||||
/>
|
||||
<span className="truncate font-mono text-sm">{url}</span>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
function StatusBadge({ connection }: { connection: ConnectionRead }) {
|
||||
if (connection.last_status === "OK") {
|
||||
return (
|
||||
|
|
@ -105,11 +212,15 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
||||
const addManualModel = useAtomValue(addManualModelMutationAtom);
|
||||
const updateModel = useAtomValue(updateModelMutationAtom);
|
||||
const testModel = useAtomValue(testModelMutationAtom);
|
||||
const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom);
|
||||
|
||||
const allowlist = Array.isArray(connection.extra?.model_ids)
|
||||
? (connection.extra.model_ids as string[])
|
||||
: [];
|
||||
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
|
||||
const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? "");
|
||||
const [apiKeyDraft, setApiKeyDraft] = useState("");
|
||||
const [showApiKey, setShowApiKey] = useState(false);
|
||||
const [allowlistText, setAllowlistText] = useState(allowlist.join(", "));
|
||||
const [manualModelId, setManualModelId] = useState("");
|
||||
const [modelFilter, setModelFilter] = useState<ModelCapabilityFilter | null>(null);
|
||||
|
|
@ -122,6 +233,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
const filteredModels = modelFilter
|
||||
? connection.models.filter((model) => capability(model, modelFilter))
|
||||
: connection.models;
|
||||
const allFilteredModelsEnabled =
|
||||
filteredModels.length > 0 && filteredModels.every((model) => model.enabled);
|
||||
const hasConnectionChanges =
|
||||
baseUrlDraft.trim() !== (connection.base_url ?? "") ||
|
||||
apiKeyDraft.trim() !== (connection.api_key ?? "");
|
||||
|
||||
function handleSettingsOpenChange(open: boolean) {
|
||||
setIsSettingsOpen(open);
|
||||
if (open) {
|
||||
setBaseUrlDraft(connection.base_url ?? "");
|
||||
setApiKeyDraft(connection.api_key ?? "");
|
||||
setShowApiKey(false);
|
||||
setAllowlistText(allowlist.join(", "));
|
||||
}
|
||||
}
|
||||
|
||||
function saveConnectionSettings() {
|
||||
const data: ConnectionUpdateRequest = {
|
||||
base_url: baseUrlDraft.trim() || null,
|
||||
};
|
||||
|
||||
if (apiKeyDraft.trim() !== (connection.api_key ?? "")) {
|
||||
data.api_key = apiKeyDraft.trim() || null;
|
||||
}
|
||||
|
||||
updateConnection.mutate(
|
||||
{ id: connection.id, data },
|
||||
{
|
||||
onSuccess: () => setApiKeyDraft(""),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function saveAllowlist() {
|
||||
const ids = allowlistText
|
||||
|
|
@ -144,170 +287,321 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
}
|
||||
|
||||
function deleteCurrentConnection() {
|
||||
const confirmed = window.confirm(
|
||||
`Delete the ${providerLabel} connection and all of its models? This cannot be undone.`
|
||||
);
|
||||
if (!confirmed) return;
|
||||
deleteConnection.mutate(connection.id);
|
||||
}
|
||||
|
||||
function toggleFilteredModels() {
|
||||
const nextEnabled = !allFilteredModelsEnabled;
|
||||
const modelIds = filteredModels
|
||||
.filter((model) => model.enabled !== nextEnabled)
|
||||
.map((model) => model.id);
|
||||
|
||||
if (modelIds.length === 0) return;
|
||||
|
||||
bulkUpdateModels.mutate({
|
||||
connectionId: connection.id,
|
||||
data: { model_ids: modelIds, enabled: nextEnabled },
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border p-4">
|
||||
<div className="flex flex-wrap items-center justify-between gap-3">
|
||||
<div>
|
||||
<div className="flex items-center gap-2 font-medium">
|
||||
<div className="rounded-xl border bg-background p-4 shadow-sm">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-2 font-semibold">
|
||||
{getProviderIcon(providerLabel, { className: "size-4" })}
|
||||
{providerLabel}
|
||||
<span className="truncate">{providerLabel}</span>
|
||||
{connection.scope === "GLOBAL" ? (
|
||||
<Badge variant="outline" className="text-[10px]">
|
||||
Default
|
||||
</Badge>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
<div className="truncate text-sm text-muted-foreground">
|
||||
{connection.base_url || "Provider default endpoint"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<div className="flex shrink-0 items-center gap-2">
|
||||
<StatusBadge connection={connection} />
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => verifyConnection.mutate(connection.id)}
|
||||
>
|
||||
Test
|
||||
</Button>
|
||||
<Button variant="outline" size="sm" onClick={() => discoverModels.mutate(connection.id)}>
|
||||
<RefreshCcw className="mr-2 h-4 w-4" /> Discover
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={deleteCurrentConnection}
|
||||
disabled={deleteConnection.isPending}
|
||||
>
|
||||
<Trash2 className="mr-2 h-4 w-4" /> Delete
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{connection.last_status && connection.last_status !== "OK" ? (
|
||||
<p className="mt-2 text-sm text-amber-600 dark:text-amber-500">
|
||||
{connection.last_error || "Could not list models."} Chat may still work — add model IDs
|
||||
manually below.
|
||||
</p>
|
||||
) : null}
|
||||
|
||||
{!isLocal ? (
|
||||
<div className="mt-4 space-y-1">
|
||||
<Label className="text-xs">Model IDs filter (optional)</Label>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
value={allowlistText}
|
||||
onChange={(event) => setAllowlistText(event.target.value)}
|
||||
placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro"
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={saveAllowlist}
|
||||
disabled={updateConnection.isPending}
|
||||
>
|
||||
Save filter
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Leave empty to discover all models. Recommended for providers with large catalogs (e.g.
|
||||
OpenRouter).
|
||||
</p>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<div className="mt-4 flex gap-2">
|
||||
<Input
|
||||
value={manualModelId}
|
||||
onChange={(event) => setManualModelId(event.target.value)}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === "Enter") {
|
||||
event.preventDefault();
|
||||
addModel();
|
||||
}
|
||||
}}
|
||||
placeholder="Add a model ID manually (for providers without /models)"
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={addModel}
|
||||
disabled={addManualModel.isPending || !manualModelId.trim()}
|
||||
>
|
||||
<Plus className="mr-2 h-4 w-4" /> Add model
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{connection.models.length > 0 ? (
|
||||
<div className="mt-4 flex flex-wrap items-center gap-2">
|
||||
<span className="text-xs font-medium text-muted-foreground">Filter models</span>
|
||||
{MODEL_CAPABILITY_FILTERS.map((filter) => {
|
||||
const count = connection.models.filter((model) => capability(model, filter.key)).length;
|
||||
const isActive = modelFilter === filter.key;
|
||||
|
||||
return (
|
||||
<Button
|
||||
key={filter.key}
|
||||
type="button"
|
||||
variant={isActive ? "secondary" : "outline"}
|
||||
size="sm"
|
||||
className="h-7 rounded-full px-3 text-xs"
|
||||
onClick={() => setModelFilter(isActive ? null : filter.key)}
|
||||
>
|
||||
{filter.label}
|
||||
<span className="ml-1 text-muted-foreground">{count}</span>
|
||||
<Dialog open={isSettingsOpen} onOpenChange={handleSettingsOpenChange}>
|
||||
<DialogTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label={`Configure ${providerLabel}`}>
|
||||
<Settings className="h-4 w-4" />
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : null}
|
||||
</DialogTrigger>
|
||||
<DialogContent className="flex max-h-[90vh] max-w-3xl flex-col overflow-hidden bg-popover p-0 text-popover-foreground">
|
||||
<DialogHeader className="shrink-0 border-b px-6 py-5">
|
||||
<div className="flex items-center gap-3">
|
||||
{getProviderIcon(providerLabel, { className: "size-5" })}
|
||||
<div>
|
||||
<DialogTitle>
|
||||
Configure <span className="italic">{providerLabel}</span>
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Manage credentials and choose which models are available from this provider.
|
||||
</DialogDescription>
|
||||
</div>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="mt-4 grid gap-2">
|
||||
{filteredModels.length === 0 && modelFilter ? (
|
||||
<div className="rounded-md bg-muted/30 px-3 py-2 text-xs text-muted-foreground">
|
||||
No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "}
|
||||
models found on this connection.
|
||||
</div>
|
||||
) : null}
|
||||
{filteredModels.map((model) => (
|
||||
<div
|
||||
key={model.id}
|
||||
className="flex flex-wrap items-center justify-between gap-2 rounded-md bg-muted/40 px-3 py-2"
|
||||
>
|
||||
<div>
|
||||
<div className="flex items-center gap-2 text-sm font-medium">
|
||||
{getProviderIcon(providerLabel, { className: "size-4" })}
|
||||
{modelLabel(model)}
|
||||
{model.source === "MANUAL" ? (
|
||||
<Badge variant="outline" className="text-[10px]">
|
||||
manual
|
||||
</Badge>
|
||||
) : null}
|
||||
<div className="min-h-0 flex-1 overflow-y-auto px-6 py-5">
|
||||
<div className="space-y-6">
|
||||
<div className="space-y-2">
|
||||
<Label>API Base URL</Label>
|
||||
<UrlSuggestionCombobox
|
||||
value={baseUrlDraft}
|
||||
onChange={setBaseUrlDraft}
|
||||
placeholder="https://api.example.com/v1"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Leave empty to use the provider default endpoint.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>API Key</Label>
|
||||
<div className="relative">
|
||||
<Input
|
||||
value={apiKeyDraft}
|
||||
onChange={(event) => setApiKeyDraft(event.target.value)}
|
||||
placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"}
|
||||
type={showApiKey ? "text" : "password"}
|
||||
className="pr-11"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="absolute top-1/2 right-1 size-8 -translate-y-1/2 text-muted-foreground"
|
||||
onClick={() => setShowApiKey((current) => !current)}
|
||||
disabled={!apiKeyDraft}
|
||||
aria-label={showApiKey ? "Hide API key" : "Show API key"}
|
||||
>
|
||||
{showApiKey ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{!isLocal ? (
|
||||
<div className="space-y-2">
|
||||
<Label className="text-xs">Model IDs filter (optional)</Label>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
value={allowlistText}
|
||||
onChange={(event) => setAllowlistText(event.target.value)}
|
||||
placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro"
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={saveAllowlist}
|
||||
disabled={updateConnection.isPending}
|
||||
>
|
||||
Save filter
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Leave empty to discover all models. Recommended for providers with large
|
||||
catalogs.
|
||||
</p>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<Separator className="bg-muted-foreground/20" />
|
||||
|
||||
<div className="space-y-3">
|
||||
<div className="flex flex-wrap items-start justify-between gap-3">
|
||||
<div>
|
||||
<div className="font-semibold">Models</div>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Select models to make available for this provider.
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
type="button"
|
||||
onClick={toggleFilteredModels}
|
||||
disabled={bulkUpdateModels.isPending || filteredModels.length === 0}
|
||||
>
|
||||
{allFilteredModelsEnabled ? "Deselect All" : "Select All"}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => discoverModels.mutate(connection.id)}
|
||||
disabled={discoverModels.isPending}
|
||||
aria-label={`Refresh ${providerLabel} models`}
|
||||
>
|
||||
<RefreshCcw className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
value={manualModelId}
|
||||
onChange={(event) => setManualModelId(event.target.value)}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === "Enter") {
|
||||
event.preventDefault();
|
||||
addModel();
|
||||
}
|
||||
}}
|
||||
placeholder="Add a model ID manually"
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={addModel}
|
||||
disabled={addManualModel.isPending || !manualModelId.trim()}
|
||||
>
|
||||
Add model
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{connection.models.length > 0 ? (
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<span className="text-xs font-medium text-muted-foreground">
|
||||
Filter models
|
||||
</span>
|
||||
{MODEL_CAPABILITY_FILTERS.map((filter) => {
|
||||
const count = connection.models.filter((model) =>
|
||||
capability(model, filter.key)
|
||||
).length;
|
||||
const isActive = modelFilter === filter.key;
|
||||
|
||||
return (
|
||||
<Button
|
||||
key={filter.key}
|
||||
type="button"
|
||||
variant={isActive ? "secondary" : "outline"}
|
||||
size="sm"
|
||||
className="h-7 rounded-full px-3 text-xs"
|
||||
onClick={() => setModelFilter(isActive ? null : filter.key)}
|
||||
>
|
||||
{filter.label}
|
||||
<span className="ml-1 text-muted-foreground">{count}</span>
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<div className="max-h-80 overflow-y-auto rounded-xl border bg-muted/20 p-2">
|
||||
{connection.models.length === 0 ? (
|
||||
<div className="rounded-lg px-3 py-6 text-center text-sm text-muted-foreground">
|
||||
No models yet. Use the refresh button to discover models or add one
|
||||
manually.
|
||||
</div>
|
||||
) : null}
|
||||
{filteredModels.length === 0 && modelFilter ? (
|
||||
<div className="rounded-lg px-3 py-6 text-center text-sm text-muted-foreground">
|
||||
No{" "}
|
||||
{MODEL_CAPABILITY_FILTERS.find(
|
||||
(filter) => filter.key === modelFilter
|
||||
)?.label.toLowerCase()}{" "}
|
||||
models found on this connection.
|
||||
</div>
|
||||
) : null}
|
||||
<div className="space-y-2">
|
||||
{filteredModels.map((model) => (
|
||||
<div
|
||||
key={model.id}
|
||||
className="flex items-center gap-3 rounded-lg px-3 py-2 transition-colors hover:bg-background"
|
||||
>
|
||||
<Checkbox
|
||||
checked={model.enabled}
|
||||
onCheckedChange={(checked) =>
|
||||
updateModel.mutate({
|
||||
id: model.id,
|
||||
data: { enabled: checked === true },
|
||||
})
|
||||
}
|
||||
disabled={updateModel.isPending}
|
||||
/>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2 text-sm font-medium">
|
||||
<span className="truncate">{modelLabel(model)}</span>
|
||||
{model.source === "MANUAL" ? (
|
||||
<Badge variant="outline" className="text-[10px]">
|
||||
manual
|
||||
</Badge>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{["chat", "vision", "image_gen"]
|
||||
.filter((key) =>
|
||||
capability(model, key as "chat" | "vision" | "image_gen")
|
||||
)
|
||||
.join(", ") || "No discovered capabilities"}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{connection.last_status && connection.last_status !== "OK" ? (
|
||||
<p className="rounded-lg bg-amber-500/10 px-3 py-2 text-sm text-amber-600 dark:text-amber-500">
|
||||
{connection.last_error || "Could not list models."} Chat may still work; add
|
||||
model IDs manually if discovery is unavailable.
|
||||
</p>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{["chat", "vision", "image_gen"]
|
||||
.filter((key) => capability(model, key as "chat" | "vision" | "image_gen"))
|
||||
.join(", ") || "No discovered capabilities"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={() => testModel.mutate(model.id)}>
|
||||
Test
|
||||
</Button>
|
||||
|
||||
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={() => verifyConnection.mutate(connection.id)}
|
||||
disabled={verifyConnection.isPending}
|
||||
>
|
||||
Test
|
||||
</Button>
|
||||
<Button
|
||||
onClick={saveConnectionSettings}
|
||||
disabled={updateConnection.isPending || !hasConnectionChanges}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<Button
|
||||
variant={model.enabled ? "secondary" : "outline"}
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
updateModel.mutate({ id: model.id, data: { enabled: !model.enabled } })
|
||||
}
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
disabled={deleteConnection.isPending}
|
||||
aria-label={`Delete ${providerLabel}`}
|
||||
>
|
||||
{model.enabled ? "Enabled" : "Enable"}
|
||||
<Trash2 className="h-4 w-4 text-destructive" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</AlertDialogTrigger>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete this provider?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
<span className="font-medium text-foreground">{providerLabel}</span> and all of
|
||||
its models will be removed from this search space. This cannot be undone.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={deleteConnection.isPending}>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={deleteCurrentConnection}
|
||||
disabled={deleteConnection.isPending}
|
||||
className="bg-destructive text-white hover:bg-destructive/90"
|
||||
>
|
||||
Delete
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -394,19 +688,13 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>Base URL</Label>
|
||||
<Input
|
||||
<UrlSuggestionCombobox
|
||||
value={baseUrl}
|
||||
onChange={(event) => setBaseUrl(event.target.value)}
|
||||
onChange={setBaseUrl}
|
||||
placeholder={
|
||||
isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1"
|
||||
}
|
||||
list="model-conn-url-suggestions"
|
||||
/>
|
||||
<datalist id="model-conn-url-suggestions">
|
||||
{URL_SUGGESTIONS.map((url) => (
|
||||
<option key={url} value={url} />
|
||||
))}
|
||||
</datalist>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>{isOllama ? "API Key (optional)" : "API Key"}</Label>
|
||||
|
|
@ -425,7 +713,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
Boolean(selectedProvider?.base_url_required && !baseUrl.trim())
|
||||
}
|
||||
>
|
||||
<PlugZap className="mr-2 h-4 w-4" /> Add
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -439,11 +727,17 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
{connections.map((connection) => (
|
||||
<ConnectionCard key={connection.id} connection={connection} />
|
||||
))}
|
||||
</div>
|
||||
{connections.length > 0 ? (
|
||||
<div className="flex flex-col gap-3">
|
||||
<Separator />
|
||||
<h3 className="text-sm font-semibold">Available Providers</h3>
|
||||
<div className="flex flex-col gap-3">
|
||||
{connections.map((connection) => (
|
||||
<ConnectionCard key={connection.id} connection={connection} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ export const connectionRead = z.object({
|
|||
id: z.number(),
|
||||
provider: z.string(),
|
||||
base_url: z.string().nullable().optional(),
|
||||
api_key: z.string().nullable().optional(),
|
||||
extra: z.record(z.string(), z.any()).default({}),
|
||||
scope: z.union([connectionScopeEnum, z.string()]),
|
||||
search_space_id: z.number().nullable().optional(),
|
||||
|
|
@ -73,6 +74,11 @@ export const modelUpdateRequest = z.object({
|
|||
capabilities_override: z.record(z.string(), z.any()).optional(),
|
||||
});
|
||||
|
||||
export const modelsBulkUpdateRequest = z.object({
|
||||
model_ids: z.array(z.number()).min(1).max(1000),
|
||||
enabled: z.boolean(),
|
||||
});
|
||||
|
||||
export const verifyConnectionResponse = z.object({
|
||||
status: z.string(),
|
||||
ok: z.boolean(),
|
||||
|
|
@ -107,6 +113,7 @@ export type ConnectionCreateRequest = z.infer<typeof connectionCreateRequest>;
|
|||
export type ConnectionUpdateRequest = z.infer<typeof connectionUpdateRequest>;
|
||||
export type ModelCreateRequest = z.infer<typeof modelCreateRequest>;
|
||||
export type ModelUpdateRequest = z.infer<typeof modelUpdateRequest>;
|
||||
export type ModelsBulkUpdateRequest = z.infer<typeof modelsBulkUpdateRequest>;
|
||||
export type ModelRoles = z.infer<typeof modelRoles>;
|
||||
export type VerifyConnectionResponse = z.infer<typeof verifyConnectionResponse>;
|
||||
export type ModelProviderRead = z.infer<typeof modelProviderRead>;
|
||||
|
|
|
|||
|
|
@ -10,12 +10,14 @@ import {
|
|||
type ModelProviderRead,
|
||||
type ModelRead,
|
||||
type ModelRoles,
|
||||
type ModelsBulkUpdateRequest,
|
||||
type ModelUpdateRequest,
|
||||
modelCreateRequest,
|
||||
modelProviderListResponse,
|
||||
modelListResponse,
|
||||
modelProviderListResponse,
|
||||
modelRead,
|
||||
modelRoles,
|
||||
modelsBulkUpdateRequest,
|
||||
modelUpdateRequest,
|
||||
type VerifyConnectionResponse,
|
||||
verifyConnectionResponse,
|
||||
|
|
@ -97,6 +99,25 @@ class ModelConnectionsApiService {
|
|||
});
|
||||
};
|
||||
|
||||
bulkUpdateModels = async (
|
||||
connectionId: number,
|
||||
request: ModelsBulkUpdateRequest
|
||||
): Promise<ModelRead[]> => {
|
||||
const parsed = modelsBulkUpdateRequest.safeParse(request);
|
||||
if (!parsed.success) {
|
||||
throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", "));
|
||||
}
|
||||
return baseApiService.request(
|
||||
`/api/v1/model-connections/${connectionId}/models`,
|
||||
modelListResponse,
|
||||
{
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: parsed.data,
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
testModel = async (id: number): Promise<VerifyConnectionResponse> => {
|
||||
return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse);
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue