refactor(model-connections): update frontend for provider-based models

This commit is contained in:
Anish Sarkar 2026-06-12 02:17:51 +05:30
parent 3dd54230e7
commit 610ff063d6
10 changed files with 177 additions and 124 deletions

View file

@ -49,7 +49,7 @@ export function DashboardClientLayout({
const firstGlobalChatModel = useMemo(() => {
for (const connection of globalConnections) {
const model = connection.models.find((item) => item.enabled && item.capabilities?.chat);
const model = connection.models.find((item) => item.enabled && item.supports_chat);
if (model) return model;
}
return null;

View file

@ -33,7 +33,7 @@ export default function OnboardPage() {
const firstGlobalChatModel = useMemo(() => {
for (const connection of globalConnections) {
const model = connection.models.find((item) => item.enabled && item.capabilities?.chat);
const model = connection.models.find((item) => item.enabled && item.supports_chat);
if (model) return model;
}
return null;

View file

@ -11,6 +11,13 @@ export const globalModelConnectionsAtom = atomWithQuery(() => ({
queryFn: () => modelConnectionsApiService.getGlobalConnections(),
}));
export const modelProvidersAtom = atomWithQuery(() => ({
queryKey: cacheKeys.modelConnections.providers(),
enabled: !!getBearerToken(),
staleTime: 60 * 60 * 1000,
queryFn: () => modelConnectionsApiService.getModelProviders(),
}));
export const modelConnectionsAtom = atomWithQuery((get) => {
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
return {

View file

@ -48,10 +48,10 @@ import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dial
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { membersAtom } from "@/atoms/members/members-query.atoms";
import {
globalNewLLMConfigsAtom,
llmPreferencesAtom,
newLLMConfigsAtom,
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
globalModelConnectionsAtom,
modelConnectionsAtom,
modelRolesAtom,
} from "@/atoms/model-connections/model-connections-query.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
@ -976,9 +976,9 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
if (url) setPendingScreenImages((prev) => [...prev, url]);
}, [electronAPI, setPendingScreenImages]);
const { data: userConfigs } = useAtomValue(newLLMConfigsAtom);
const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom);
const { data: preferences } = useAtomValue(llmPreferencesAtom);
const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom);
const { data: modelConnections } = useAtomValue(modelConnectionsAtom);
const { data: modelRoles } = useAtomValue(modelRolesAtom);
const { data: agentTools } = useAtomValue(agentToolsAtom);
const disabledTools = useAtomValue(disabledToolsAtom);
@ -1065,15 +1065,18 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
}, [hydrateDisabled]);
const hasModelConfigured = useMemo(() => {
if (!preferences) return false;
const agentLlmId = preferences.agent_llm_id;
if (agentLlmId === null || agentLlmId === undefined) return false;
if (agentLlmId <= 0) {
return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
const chatModelId = modelRoles?.chat_model_id ?? 0;
if (chatModelId === 0) {
return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) =>
connection.models.some((model) => model.enabled && Boolean(model.supports_chat))
);
}
return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
}, [preferences, globalConfigs, userConfigs]);
return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) =>
connection.models.some(
(model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat)
)
);
}, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]);
const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser;

View file

@ -56,18 +56,18 @@ function modelName(model: ModelRead) {
function connectionLabel(connection: ConnectionRead) {
if (connection.scope === "GLOBAL") return "Hosted";
return connection.litellm_provider || connection.protocol;
return connection.provider;
}
function flattenChatModels(connections: ConnectionRead[]) {
return connections.flatMap((connection) =>
connection.models
.filter((model) => model.enabled && Boolean(model.capabilities?.chat))
.filter((model) => model.enabled && Boolean(model.supports_chat))
.map((model) => ({
...model,
connectionId: connection.id,
connectionLabel: connectionLabel(connection),
provider: connection.litellm_provider || connection.protocol,
provider: connection.provider,
}))
);
}
@ -184,9 +184,14 @@ export function ModelSelector({
<span className="truncate">{modelName(model)}</span>
</div>
<div className="truncate text-xs text-muted-foreground">{model.model_id}</div>
{model.max_input_tokens ? (
<div className="text-xs text-muted-foreground">
{model.max_input_tokens.toLocaleString()} context
</div>
) : null}
</div>
<div className="ml-3 flex items-center gap-2">
{!model.capabilities?.vision ? (
{!model.supports_image_input ? (
<Badge variant="outline" className="gap-1">
<ImageOff className="h-3 w-3" /> No image
</Badge>

View file

@ -1,11 +1,12 @@
"use client";
import { useAtom, useAtomValue } from "jotai";
import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react";
import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react";
import { useState } from "react";
import {
addManualModelMutationAtom,
createModelConnectionMutationAtom,
deleteModelConnectionMutationAtom,
discoverConnectionModelsMutationAtom,
testModelMutationAtom,
updateModelConnectionMutationAtom,
@ -16,6 +17,7 @@ import {
import {
globalModelConnectionsAtom,
modelConnectionsAtom,
modelProvidersAtom,
modelRolesAtom,
} from "@/atoms/model-connections/model-connections-query.atoms";
import { Badge } from "@/components/ui/badge";
@ -30,37 +32,9 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import type {
ConnectionProtocol,
ConnectionRead,
ModelRead,
} from "@/contracts/types/model-connections.types";
import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types";
import { getProviderIcon } from "@/lib/provider-icons";
const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [
{
value: "OPENAI_COMPATIBLE",
label: "OpenAI-compatible",
description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.",
},
{
value: "ANTHROPIC",
label: "Anthropic",
description: "Use for Claude endpoints that require Anthropic headers.",
},
{
value: "OLLAMA",
label: "Ollama",
description: "Use for Ollama's native API.",
},
];
function defaultLitellmProvider(protocol: ConnectionProtocol) {
if (protocol === "OLLAMA") return "ollama_chat";
if (protocol === "ANTHROPIC") return "anthropic";
return "openai";
}
// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict
// what the user can type — any OpenAI-compatible endpoint works.
const URL_SUGGESTIONS = [
@ -82,9 +56,19 @@ function modelLabel(model: ModelRead) {
}
function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") {
return Boolean(model.capabilities?.[key]);
if (key === "chat") return Boolean(model.supports_chat);
if (key === "vision") return Boolean(model.supports_image_input);
return Boolean(model.supports_image_generation);
}
type ModelCapabilityFilter = "chat" | "vision" | "image_gen";
const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [
{ key: "chat", label: "Chat" },
{ key: "vision", label: "Vision" },
{ key: "image_gen", label: "Image" },
];
function StatusBadge({ connection }: { connection: ConnectionRead }) {
if (connection.last_status === "OK") {
return (
@ -107,9 +91,9 @@ function flattenModels(connections: ConnectionRead[]) {
return connections.flatMap((connection) =>
connection.models.map((model) => ({
...model,
connectionName: connection.litellm_provider || connection.protocol,
connectionName: connection.provider,
connectionId: connection.id,
provider: connection.litellm_provider || connection.protocol,
provider: connection.provider,
}))
);
}
@ -118,6 +102,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom);
const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom);
const updateConnection = useAtomValue(updateModelConnectionMutationAtom);
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
const addManualModel = useAtomValue(addManualModelMutationAtom);
const updateModel = useAtomValue(updateModelMutationAtom);
const testModel = useAtomValue(testModelMutationAtom);
@ -127,9 +112,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
: [];
const [allowlistText, setAllowlistText] = useState(allowlist.join(", "));
const [manualModelId, setManualModelId] = useState("");
const [modelFilter, setModelFilter] = useState<ModelCapabilityFilter | null>(null);
const providerLabel = connection.litellm_provider || connection.protocol;
const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https");
const providerLabel = connection.provider;
const isLocal =
connection.provider === "ollama_chat" ||
connection.provider === "lm_studio" ||
!connection.base_url?.startsWith("https");
const filteredModels = modelFilter
? connection.models.filter((model) => capability(model, modelFilter))
: connection.models;
function saveAllowlist() {
const ids = allowlistText
@ -151,6 +143,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
);
}
function deleteCurrentConnection() {
const confirmed = window.confirm(
`Delete the ${providerLabel} connection and all of its models? This cannot be undone.`
);
if (!confirmed) return;
deleteConnection.mutate(connection.id);
}
return (
<div className="rounded-lg border p-4">
<div className="flex flex-wrap items-center justify-between gap-3">
@ -175,6 +175,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
<Button variant="outline" size="sm" onClick={() => discoverModels.mutate(connection.id)}>
<RefreshCcw className="mr-2 h-4 w-4" /> Discover
</Button>
<Button
variant="destructive"
size="sm"
onClick={deleteCurrentConnection}
disabled={deleteConnection.isPending}
>
<Trash2 className="mr-2 h-4 w-4" /> Delete
</Button>
</div>
</div>
@ -232,8 +240,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
</Button>
</div>
{connection.models.length > 0 ? (
<div className="mt-4 flex flex-wrap items-center gap-2">
<span className="text-xs font-medium text-muted-foreground">Filter models</span>
{MODEL_CAPABILITY_FILTERS.map((filter) => {
const count = connection.models.filter((model) => capability(model, filter.key)).length;
const isActive = modelFilter === filter.key;
return (
<Button
key={filter.key}
type="button"
variant={isActive ? "secondary" : "outline"}
size="sm"
className="h-7 rounded-full px-3 text-xs"
onClick={() => setModelFilter(isActive ? null : filter.key)}
>
{filter.label}
<span className="ml-1 text-muted-foreground">{count}</span>
</Button>
);
})}
</div>
) : null}
<div className="mt-4 grid gap-2">
{connection.models.map((model) => (
{filteredModels.length === 0 && modelFilter ? (
<div className="rounded-md bg-muted/30 px-3 py-2 text-xs text-muted-foreground">
No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "}
models found on this connection.
</div>
) : null}
{filteredModels.map((model) => (
<div
key={model.id}
className="flex flex-wrap items-center justify-between gap-2 rounded-md bg-muted/40 px-3 py-2"
@ -250,8 +288,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
</div>
<div className="text-xs text-muted-foreground">
{["chat", "vision", "image_gen"]
.filter((key) => Boolean(model.capabilities?.[key]))
.join(", ") || "No verified capabilities"}
.filter((key) => capability(model, key as "chat" | "vision" | "image_gen"))
.join(", ") || "No discovered capabilities"}
</div>
</div>
<div className="flex gap-2">
@ -278,18 +316,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) {
const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom);
const [{ data: connections = [] }] = useAtom(modelConnectionsAtom);
const [{ data: providers = [] }] = useAtom(modelProvidersAtom);
const [{ data: roles }] = useAtom(modelRolesAtom);
const createConnection = useAtomValue(createModelConnectionMutationAtom);
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
const [protocol, setProtocol] = useState<ConnectionProtocol>("OPENAI_COMPATIBLE");
const [provider, setProvider] = useState("openai_compatible");
const [baseUrl, setBaseUrl] = useState("");
const [apiKey, setApiKey] = useState("");
const [litellmProvider, setLitellmProvider] = useState("");
const [showAdvancedProvider, setShowAdvancedProvider] = useState(false);
const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol);
const protocolDefaultProvider = defaultLitellmProvider(protocol);
const isOllama = protocol === "OLLAMA";
const selectedProvider = providers.find((item) => item.provider === provider);
const isOllama = provider === "ollama_chat";
const allConnections = [...globalConnections, ...connections];
const enabledModels = flattenModels(allConnections).filter((model) => model.enabled);
@ -298,11 +334,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
const imageModels = enabledModels.filter((model) => capability(model, "image_gen"));
function handleCreate() {
const explicitProvider = litellmProvider.trim();
createConnection.mutate(
{
protocol,
litellm_provider: explicitProvider ? explicitProvider : null,
provider,
base_url: baseUrl || null,
api_key: apiKey || null,
scope: "SEARCH_SPACE",
@ -337,18 +371,22 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
<CardContent className="space-y-6">
<div className="grid gap-3 md:grid-cols-[220px_1fr_1fr_auto]">
<div className="space-y-2">
<Label>Protocol</Label>
<Label>Provider</Label>
<Select
value={protocol}
onValueChange={(value) => setProtocol(value as ConnectionProtocol)}
value={provider}
onValueChange={(value) => {
setProvider(value);
const next = providers.find((item) => item.provider === value);
if (next?.default_base_url) setBaseUrl(next.default_base_url);
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{PROTOCOL_OPTIONS.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
{providers.map((item) => (
<SelectItem key={item.provider} value={item.provider}>
{item.provider}
</SelectItem>
))}
</SelectContent>
@ -382,7 +420,10 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
<div className="flex items-end">
<Button
onClick={handleCreate}
disabled={createConnection.isPending || !baseUrl.trim()}
disabled={
createConnection.isPending ||
Boolean(selectedProvider?.base_url_required && !baseUrl.trim())
}
>
<PlugZap className="mr-2 h-4 w-4" /> Add
</Button>
@ -390,35 +431,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
</div>
<div className="space-y-3">
<p className="text-xs text-muted-foreground">
{selectedProtocol?.description} Base URL is explicit and editable; no provider presets
are required. Local URLs are tested from the backend container, so use
host.docker.internal instead of localhost.
{selectedProvider
? `${selectedProvider.transport} transport, ${selectedProvider.discovery} discovery.`
: "Choose a provider preset."}{" "}
Base URL is explicit and editable. Local URLs are tested from the backend container,
so use host.docker.internal instead of localhost.
</p>
<div>
<Button
type="button"
variant="ghost"
size="sm"
className="h-auto px-0 text-xs"
onClick={() => setShowAdvancedProvider((current) => !current)}
>
Advanced: LiteLLM provider ({litellmProvider.trim() || protocolDefaultProvider})
</Button>
{showAdvancedProvider ? (
<div className="mt-2 max-w-sm space-y-2">
<Label>LiteLLM provider override</Label>
<Input
value={litellmProvider}
onChange={(event) => setLitellmProvider(event.target.value)}
placeholder={protocolDefaultProvider}
/>
<p className="text-xs text-muted-foreground">
Leave empty to use the protocol default. Set this for more accurate LiteLLM
capabilities/costs, for example openrouter, groq, gemini, or azure.
</p>
</div>
) : null}
</div>
</div>
<div className="space-y-3">

View file

@ -1,26 +1,19 @@
import { z } from "zod";
export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]);
export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]);
export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]);
export const modelCapabilities = z.object({
chat: z.boolean().optional(),
vision: z.boolean().optional(),
image_gen: z.boolean().optional(),
embedding: z.boolean().optional(),
tools: z.boolean().optional(),
});
export const modelRead = z.object({
id: z.number(),
connection_id: z.number(),
model_id: z.string(),
display_name: z.string().nullable().optional(),
source: z.union([modelSourceEnum, z.string()]),
capabilities: z.record(z.string(), z.any()).default({}),
capabilities_declared: z.record(z.string(), z.any()).default({}),
capabilities_verified: z.record(z.string(), z.any()).default({}),
supports_chat: z.boolean().nullable().optional(),
max_input_tokens: z.number().nullable().optional(),
supports_image_input: z.boolean().nullable().optional(),
supports_tools: z.boolean().nullable().optional(),
supports_image_generation: z.boolean().nullable().optional(),
capabilities_override: z.record(z.string(), z.any()).default({}),
embedding_dimension: z.number().nullable().optional(),
enabled: z.boolean(),
@ -31,8 +24,7 @@ export const modelRead = z.object({
export const connectionRead = z.object({
id: z.number(),
protocol: z.union([connectionProtocolEnum, z.string()]),
litellm_provider: z.string().nullable().optional(),
provider: z.string(),
base_url: z.string().nullable().optional(),
extra: z.record(z.string(), z.any()).default({}),
scope: z.union([connectionScopeEnum, z.string()]),
@ -48,8 +40,7 @@ export const connectionRead = z.object({
});
export const connectionCreateRequest = z.object({
protocol: connectionProtocolEnum,
litellm_provider: z.string().nullable().optional(),
provider: z.string().min(1),
base_url: z.string().nullable().optional(),
api_key: z.string().nullable().optional(),
extra: z.record(z.string(), z.any()).default({}),
@ -59,7 +50,7 @@ export const connectionCreateRequest = z.object({
});
export const connectionUpdateRequest = z.object({
litellm_provider: z.string().nullable().optional(),
provider: z.string().nullable().optional(),
base_url: z.string().nullable().optional(),
api_key: z.string().nullable().optional(),
extra: z.record(z.string(), z.any()).optional(),
@ -74,6 +65,11 @@ export const modelCreateRequest = z.object({
export const modelUpdateRequest = z.object({
display_name: z.string().nullable().optional(),
enabled: z.boolean().optional(),
supports_chat: z.boolean().nullable().optional(),
max_input_tokens: z.number().nullable().optional(),
supports_image_input: z.boolean().nullable().optional(),
supports_tools: z.boolean().nullable().optional(),
supports_image_generation: z.boolean().nullable().optional(),
capabilities_override: z.record(z.string(), z.any()).optional(),
});
@ -89,10 +85,21 @@ export const modelRoles = z.object({
image_gen_model_id: z.number().nullable().optional(),
});
export const modelProviderRead = z.object({
provider: z.string(),
transport: z.string(),
discovery: z.string(),
default_base_url: z.string().nullable().optional(),
base_url_required: z.boolean(),
auth_style: z.string(),
local_only: z.boolean().default(false),
});
export const modelProviderListResponse = z.array(modelProviderRead);
export const connectionListResponse = z.array(connectionRead);
export const modelListResponse = z.array(modelRead);
export type ConnectionProtocol = z.infer<typeof connectionProtocolEnum>;
export type ConnectionScope = z.infer<typeof connectionScopeEnum>;
export type ModelRead = z.infer<typeof modelRead>;
export type ConnectionRead = z.infer<typeof connectionRead>;
@ -102,3 +109,4 @@ export type ModelCreateRequest = z.infer<typeof modelCreateRequest>;
export type ModelUpdateRequest = z.infer<typeof modelUpdateRequest>;
export type ModelRoles = z.infer<typeof modelRoles>;
export type VerifyConnectionResponse = z.infer<typeof verifyConnectionResponse>;
export type ModelProviderRead = z.infer<typeof modelProviderRead>;

View file

@ -47,11 +47,16 @@ function buildKind(
capability: "chat" | "image_gen" | "vision",
prefId: number | null | undefined
): EligibleModelKind {
const supportsCapability = (model: ModelRead) => {
if (capability === "chat") return Boolean(model.supports_chat);
if (capability === "vision") return Boolean(model.supports_image_input);
return Boolean(model.supports_image_generation);
};
const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({
id: model.id,
name: model.display_name || model.model_id,
modelName: model.model_id,
provider: connection.litellm_provider || connection.protocol,
provider: connection.provider,
isBYOK,
});
@ -60,7 +65,7 @@ function buildKind(
.filter(
(model) =>
model.enabled &&
Boolean(model.capabilities?.[capability]) &&
supportsCapability(model) &&
String(model.billing_tier ?? "").toLowerCase() === "premium"
)
.map((model) => toOption(connection, model, false))
@ -68,7 +73,7 @@ function buildKind(
const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) =>
connection.models
.filter((model) => model.enabled && Boolean(model.capabilities?.[capability]))
.filter((model) => model.enabled && supportsCapability(model))
.map((model) => toOption(connection, model, true))
);

View file

@ -7,10 +7,12 @@ import {
connectionRead,
connectionUpdateRequest,
type ModelCreateRequest,
type ModelProviderRead,
type ModelRead,
type ModelRoles,
type ModelUpdateRequest,
modelCreateRequest,
modelProviderListResponse,
modelListResponse,
modelRead,
modelRoles,
@ -26,6 +28,10 @@ class ModelConnectionsApiService {
return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse);
};
getModelProviders = async (): Promise<ModelProviderRead[]> => {
return baseApiService.get(`/api/v1/model-providers`, modelProviderListResponse);
};
getConnections = async (searchSpaceId: number): Promise<ConnectionRead[]> => {
return baseApiService.get(
`/api/v1/model-connections?search_space_id=${searchSpaceId}`,

View file

@ -47,6 +47,7 @@ export const cacheKeys = {
modelConnections: {
all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const,
global: () => ["model-connections", "global"] as const,
providers: () => ["model-connections", "providers"] as const,
roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const,
},
imageGenConfigs: {