mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
refactor(model-connections): update frontend for provider-based models
This commit is contained in:
parent
3dd54230e7
commit
610ff063d6
10 changed files with 177 additions and 124 deletions
|
|
@ -49,7 +49,7 @@ export function DashboardClientLayout({
|
|||
|
||||
const firstGlobalChatModel = useMemo(() => {
|
||||
for (const connection of globalConnections) {
|
||||
const model = connection.models.find((item) => item.enabled && item.capabilities?.chat);
|
||||
const model = connection.models.find((item) => item.enabled && item.supports_chat);
|
||||
if (model) return model;
|
||||
}
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ export default function OnboardPage() {
|
|||
|
||||
const firstGlobalChatModel = useMemo(() => {
|
||||
for (const connection of globalConnections) {
|
||||
const model = connection.models.find((item) => item.enabled && item.capabilities?.chat);
|
||||
const model = connection.models.find((item) => item.enabled && item.supports_chat);
|
||||
if (model) return model;
|
||||
}
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -11,6 +11,13 @@ export const globalModelConnectionsAtom = atomWithQuery(() => ({
|
|||
queryFn: () => modelConnectionsApiService.getGlobalConnections(),
|
||||
}));
|
||||
|
||||
export const modelProvidersAtom = atomWithQuery(() => ({
|
||||
queryKey: cacheKeys.modelConnections.providers(),
|
||||
enabled: !!getBearerToken(),
|
||||
staleTime: 60 * 60 * 1000,
|
||||
queryFn: () => modelConnectionsApiService.getModelProviders(),
|
||||
}));
|
||||
|
||||
export const modelConnectionsAtom = atomWithQuery((get) => {
|
||||
const searchSpaceId = Number(get(activeSearchSpaceIdAtom));
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -48,10 +48,10 @@ import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dial
|
|||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||
import {
|
||||
globalNewLLMConfigsAtom,
|
||||
llmPreferencesAtom,
|
||||
newLLMConfigsAtom,
|
||||
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
|
||||
globalModelConnectionsAtom,
|
||||
modelConnectionsAtom,
|
||||
modelRolesAtom,
|
||||
} from "@/atoms/model-connections/model-connections-query.atoms";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
||||
|
|
@ -976,9 +976,9 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
if (url) setPendingScreenImages((prev) => [...prev, url]);
|
||||
}, [electronAPI, setPendingScreenImages]);
|
||||
|
||||
const { data: userConfigs } = useAtomValue(newLLMConfigsAtom);
|
||||
const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom);
|
||||
const { data: preferences } = useAtomValue(llmPreferencesAtom);
|
||||
const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom);
|
||||
const { data: modelConnections } = useAtomValue(modelConnectionsAtom);
|
||||
const { data: modelRoles } = useAtomValue(modelRolesAtom);
|
||||
|
||||
const { data: agentTools } = useAtomValue(agentToolsAtom);
|
||||
const disabledTools = useAtomValue(disabledToolsAtom);
|
||||
|
|
@ -1065,15 +1065,18 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
|||
}, [hydrateDisabled]);
|
||||
|
||||
const hasModelConfigured = useMemo(() => {
|
||||
if (!preferences) return false;
|
||||
const agentLlmId = preferences.agent_llm_id;
|
||||
if (agentLlmId === null || agentLlmId === undefined) return false;
|
||||
|
||||
if (agentLlmId <= 0) {
|
||||
return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||
const chatModelId = modelRoles?.chat_model_id ?? 0;
|
||||
if (chatModelId === 0) {
|
||||
return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) =>
|
||||
connection.models.some((model) => model.enabled && Boolean(model.supports_chat))
|
||||
);
|
||||
}
|
||||
return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
|
||||
}, [preferences, globalConfigs, userConfigs]);
|
||||
return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) =>
|
||||
connection.models.some(
|
||||
(model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat)
|
||||
)
|
||||
);
|
||||
}, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]);
|
||||
|
||||
const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser;
|
||||
|
||||
|
|
|
|||
|
|
@ -56,18 +56,18 @@ function modelName(model: ModelRead) {
|
|||
|
||||
function connectionLabel(connection: ConnectionRead) {
|
||||
if (connection.scope === "GLOBAL") return "Hosted";
|
||||
return connection.litellm_provider || connection.protocol;
|
||||
return connection.provider;
|
||||
}
|
||||
|
||||
function flattenChatModels(connections: ConnectionRead[]) {
|
||||
return connections.flatMap((connection) =>
|
||||
connection.models
|
||||
.filter((model) => model.enabled && Boolean(model.capabilities?.chat))
|
||||
.filter((model) => model.enabled && Boolean(model.supports_chat))
|
||||
.map((model) => ({
|
||||
...model,
|
||||
connectionId: connection.id,
|
||||
connectionLabel: connectionLabel(connection),
|
||||
provider: connection.litellm_provider || connection.protocol,
|
||||
provider: connection.provider,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
|
@ -184,9 +184,14 @@ export function ModelSelector({
|
|||
<span className="truncate">{modelName(model)}</span>
|
||||
</div>
|
||||
<div className="truncate text-xs text-muted-foreground">{model.model_id}</div>
|
||||
{model.max_input_tokens ? (
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.max_input_tokens.toLocaleString()} context
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="ml-3 flex items-center gap-2">
|
||||
{!model.capabilities?.vision ? (
|
||||
{!model.supports_image_input ? (
|
||||
<Badge variant="outline" className="gap-1">
|
||||
<ImageOff className="h-3 w-3" /> No image
|
||||
</Badge>
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom, useAtomValue } from "jotai";
|
||||
import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react";
|
||||
import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
addManualModelMutationAtom,
|
||||
createModelConnectionMutationAtom,
|
||||
deleteModelConnectionMutationAtom,
|
||||
discoverConnectionModelsMutationAtom,
|
||||
testModelMutationAtom,
|
||||
updateModelConnectionMutationAtom,
|
||||
|
|
@ -16,6 +17,7 @@ import {
|
|||
import {
|
||||
globalModelConnectionsAtom,
|
||||
modelConnectionsAtom,
|
||||
modelProvidersAtom,
|
||||
modelRolesAtom,
|
||||
} from "@/atoms/model-connections/model-connections-query.atoms";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
|
@ -30,37 +32,9 @@ import {
|
|||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import type {
|
||||
ConnectionProtocol,
|
||||
ConnectionRead,
|
||||
ModelRead,
|
||||
} from "@/contracts/types/model-connections.types";
|
||||
import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types";
|
||||
import { getProviderIcon } from "@/lib/provider-icons";
|
||||
|
||||
const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [
|
||||
{
|
||||
value: "OPENAI_COMPATIBLE",
|
||||
label: "OpenAI-compatible",
|
||||
description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.",
|
||||
},
|
||||
{
|
||||
value: "ANTHROPIC",
|
||||
label: "Anthropic",
|
||||
description: "Use for Claude endpoints that require Anthropic headers.",
|
||||
},
|
||||
{
|
||||
value: "OLLAMA",
|
||||
label: "Ollama",
|
||||
description: "Use for Ollama's native API.",
|
||||
},
|
||||
];
|
||||
|
||||
function defaultLitellmProvider(protocol: ConnectionProtocol) {
|
||||
if (protocol === "OLLAMA") return "ollama_chat";
|
||||
if (protocol === "ANTHROPIC") return "anthropic";
|
||||
return "openai";
|
||||
}
|
||||
|
||||
// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict
|
||||
// what the user can type — any OpenAI-compatible endpoint works.
|
||||
const URL_SUGGESTIONS = [
|
||||
|
|
@ -82,9 +56,19 @@ function modelLabel(model: ModelRead) {
|
|||
}
|
||||
|
||||
function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") {
|
||||
return Boolean(model.capabilities?.[key]);
|
||||
if (key === "chat") return Boolean(model.supports_chat);
|
||||
if (key === "vision") return Boolean(model.supports_image_input);
|
||||
return Boolean(model.supports_image_generation);
|
||||
}
|
||||
|
||||
type ModelCapabilityFilter = "chat" | "vision" | "image_gen";
|
||||
|
||||
const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [
|
||||
{ key: "chat", label: "Chat" },
|
||||
{ key: "vision", label: "Vision" },
|
||||
{ key: "image_gen", label: "Image" },
|
||||
];
|
||||
|
||||
function StatusBadge({ connection }: { connection: ConnectionRead }) {
|
||||
if (connection.last_status === "OK") {
|
||||
return (
|
||||
|
|
@ -107,9 +91,9 @@ function flattenModels(connections: ConnectionRead[]) {
|
|||
return connections.flatMap((connection) =>
|
||||
connection.models.map((model) => ({
|
||||
...model,
|
||||
connectionName: connection.litellm_provider || connection.protocol,
|
||||
connectionName: connection.provider,
|
||||
connectionId: connection.id,
|
||||
provider: connection.litellm_provider || connection.protocol,
|
||||
provider: connection.provider,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
|
@ -118,6 +102,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom);
|
||||
const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom);
|
||||
const updateConnection = useAtomValue(updateModelConnectionMutationAtom);
|
||||
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
||||
const addManualModel = useAtomValue(addManualModelMutationAtom);
|
||||
const updateModel = useAtomValue(updateModelMutationAtom);
|
||||
const testModel = useAtomValue(testModelMutationAtom);
|
||||
|
|
@ -127,9 +112,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
: [];
|
||||
const [allowlistText, setAllowlistText] = useState(allowlist.join(", "));
|
||||
const [manualModelId, setManualModelId] = useState("");
|
||||
const [modelFilter, setModelFilter] = useState<ModelCapabilityFilter | null>(null);
|
||||
|
||||
const providerLabel = connection.litellm_provider || connection.protocol;
|
||||
const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https");
|
||||
const providerLabel = connection.provider;
|
||||
const isLocal =
|
||||
connection.provider === "ollama_chat" ||
|
||||
connection.provider === "lm_studio" ||
|
||||
!connection.base_url?.startsWith("https");
|
||||
const filteredModels = modelFilter
|
||||
? connection.models.filter((model) => capability(model, modelFilter))
|
||||
: connection.models;
|
||||
|
||||
function saveAllowlist() {
|
||||
const ids = allowlistText
|
||||
|
|
@ -151,6 +143,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
);
|
||||
}
|
||||
|
||||
function deleteCurrentConnection() {
|
||||
const confirmed = window.confirm(
|
||||
`Delete the ${providerLabel} connection and all of its models? This cannot be undone.`
|
||||
);
|
||||
if (!confirmed) return;
|
||||
deleteConnection.mutate(connection.id);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border p-4">
|
||||
<div className="flex flex-wrap items-center justify-between gap-3">
|
||||
|
|
@ -175,6 +175,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
<Button variant="outline" size="sm" onClick={() => discoverModels.mutate(connection.id)}>
|
||||
<RefreshCcw className="mr-2 h-4 w-4" /> Discover
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={deleteCurrentConnection}
|
||||
disabled={deleteConnection.isPending}
|
||||
>
|
||||
<Trash2 className="mr-2 h-4 w-4" /> Delete
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
@ -232,8 +240,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
</Button>
|
||||
</div>
|
||||
|
||||
{connection.models.length > 0 ? (
|
||||
<div className="mt-4 flex flex-wrap items-center gap-2">
|
||||
<span className="text-xs font-medium text-muted-foreground">Filter models</span>
|
||||
{MODEL_CAPABILITY_FILTERS.map((filter) => {
|
||||
const count = connection.models.filter((model) => capability(model, filter.key)).length;
|
||||
const isActive = modelFilter === filter.key;
|
||||
|
||||
return (
|
||||
<Button
|
||||
key={filter.key}
|
||||
type="button"
|
||||
variant={isActive ? "secondary" : "outline"}
|
||||
size="sm"
|
||||
className="h-7 rounded-full px-3 text-xs"
|
||||
onClick={() => setModelFilter(isActive ? null : filter.key)}
|
||||
>
|
||||
{filter.label}
|
||||
<span className="ml-1 text-muted-foreground">{count}</span>
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<div className="mt-4 grid gap-2">
|
||||
{connection.models.map((model) => (
|
||||
{filteredModels.length === 0 && modelFilter ? (
|
||||
<div className="rounded-md bg-muted/30 px-3 py-2 text-xs text-muted-foreground">
|
||||
No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "}
|
||||
models found on this connection.
|
||||
</div>
|
||||
) : null}
|
||||
{filteredModels.map((model) => (
|
||||
<div
|
||||
key={model.id}
|
||||
className="flex flex-wrap items-center justify-between gap-2 rounded-md bg-muted/40 px-3 py-2"
|
||||
|
|
@ -250,8 +288,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{["chat", "vision", "image_gen"]
|
||||
.filter((key) => Boolean(model.capabilities?.[key]))
|
||||
.join(", ") || "No verified capabilities"}
|
||||
.filter((key) => capability(model, key as "chat" | "vision" | "image_gen"))
|
||||
.join(", ") || "No discovered capabilities"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
|
|
@ -278,18 +316,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
|||
export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) {
|
||||
const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom);
|
||||
const [{ data: connections = [] }] = useAtom(modelConnectionsAtom);
|
||||
const [{ data: providers = [] }] = useAtom(modelProvidersAtom);
|
||||
const [{ data: roles }] = useAtom(modelRolesAtom);
|
||||
const createConnection = useAtomValue(createModelConnectionMutationAtom);
|
||||
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
|
||||
|
||||
const [protocol, setProtocol] = useState<ConnectionProtocol>("OPENAI_COMPATIBLE");
|
||||
const [provider, setProvider] = useState("openai_compatible");
|
||||
const [baseUrl, setBaseUrl] = useState("");
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [litellmProvider, setLitellmProvider] = useState("");
|
||||
const [showAdvancedProvider, setShowAdvancedProvider] = useState(false);
|
||||
const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol);
|
||||
const protocolDefaultProvider = defaultLitellmProvider(protocol);
|
||||
const isOllama = protocol === "OLLAMA";
|
||||
const selectedProvider = providers.find((item) => item.provider === provider);
|
||||
const isOllama = provider === "ollama_chat";
|
||||
|
||||
const allConnections = [...globalConnections, ...connections];
|
||||
const enabledModels = flattenModels(allConnections).filter((model) => model.enabled);
|
||||
|
|
@ -298,11 +334,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
const imageModels = enabledModels.filter((model) => capability(model, "image_gen"));
|
||||
|
||||
function handleCreate() {
|
||||
const explicitProvider = litellmProvider.trim();
|
||||
createConnection.mutate(
|
||||
{
|
||||
protocol,
|
||||
litellm_provider: explicitProvider ? explicitProvider : null,
|
||||
provider,
|
||||
base_url: baseUrl || null,
|
||||
api_key: apiKey || null,
|
||||
scope: "SEARCH_SPACE",
|
||||
|
|
@ -337,18 +371,22 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
<CardContent className="space-y-6">
|
||||
<div className="grid gap-3 md:grid-cols-[220px_1fr_1fr_auto]">
|
||||
<div className="space-y-2">
|
||||
<Label>Protocol</Label>
|
||||
<Label>Provider</Label>
|
||||
<Select
|
||||
value={protocol}
|
||||
onValueChange={(value) => setProtocol(value as ConnectionProtocol)}
|
||||
value={provider}
|
||||
onValueChange={(value) => {
|
||||
setProvider(value);
|
||||
const next = providers.find((item) => item.provider === value);
|
||||
if (next?.default_base_url) setBaseUrl(next.default_base_url);
|
||||
}}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{PROTOCOL_OPTIONS.map((item) => (
|
||||
<SelectItem key={item.value} value={item.value}>
|
||||
{item.label}
|
||||
{providers.map((item) => (
|
||||
<SelectItem key={item.provider} value={item.provider}>
|
||||
{item.provider}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
|
|
@ -382,7 +420,10 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
<div className="flex items-end">
|
||||
<Button
|
||||
onClick={handleCreate}
|
||||
disabled={createConnection.isPending || !baseUrl.trim()}
|
||||
disabled={
|
||||
createConnection.isPending ||
|
||||
Boolean(selectedProvider?.base_url_required && !baseUrl.trim())
|
||||
}
|
||||
>
|
||||
<PlugZap className="mr-2 h-4 w-4" /> Add
|
||||
</Button>
|
||||
|
|
@ -390,35 +431,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
</div>
|
||||
<div className="space-y-3">
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{selectedProtocol?.description} Base URL is explicit and editable; no provider presets
|
||||
are required. Local URLs are tested from the backend container, so use
|
||||
host.docker.internal instead of localhost.
|
||||
{selectedProvider
|
||||
? `${selectedProvider.transport} transport, ${selectedProvider.discovery} discovery.`
|
||||
: "Choose a provider preset."}{" "}
|
||||
Base URL is explicit and editable. Local URLs are tested from the backend container,
|
||||
so use host.docker.internal instead of localhost.
|
||||
</p>
|
||||
<div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-auto px-0 text-xs"
|
||||
onClick={() => setShowAdvancedProvider((current) => !current)}
|
||||
>
|
||||
Advanced: LiteLLM provider ({litellmProvider.trim() || protocolDefaultProvider})
|
||||
</Button>
|
||||
{showAdvancedProvider ? (
|
||||
<div className="mt-2 max-w-sm space-y-2">
|
||||
<Label>LiteLLM provider override</Label>
|
||||
<Input
|
||||
value={litellmProvider}
|
||||
onChange={(event) => setLitellmProvider(event.target.value)}
|
||||
placeholder={protocolDefaultProvider}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Leave empty to use the protocol default. Set this for more accurate LiteLLM
|
||||
capabilities/costs, for example openrouter, groq, gemini, or azure.
|
||||
</p>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
|
|
|
|||
|
|
@ -1,26 +1,19 @@
|
|||
import { z } from "zod";
|
||||
|
||||
export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]);
|
||||
export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]);
|
||||
export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]);
|
||||
|
||||
export const modelCapabilities = z.object({
|
||||
chat: z.boolean().optional(),
|
||||
vision: z.boolean().optional(),
|
||||
image_gen: z.boolean().optional(),
|
||||
embedding: z.boolean().optional(),
|
||||
tools: z.boolean().optional(),
|
||||
});
|
||||
|
||||
export const modelRead = z.object({
|
||||
id: z.number(),
|
||||
connection_id: z.number(),
|
||||
model_id: z.string(),
|
||||
display_name: z.string().nullable().optional(),
|
||||
source: z.union([modelSourceEnum, z.string()]),
|
||||
capabilities: z.record(z.string(), z.any()).default({}),
|
||||
capabilities_declared: z.record(z.string(), z.any()).default({}),
|
||||
capabilities_verified: z.record(z.string(), z.any()).default({}),
|
||||
supports_chat: z.boolean().nullable().optional(),
|
||||
max_input_tokens: z.number().nullable().optional(),
|
||||
supports_image_input: z.boolean().nullable().optional(),
|
||||
supports_tools: z.boolean().nullable().optional(),
|
||||
supports_image_generation: z.boolean().nullable().optional(),
|
||||
capabilities_override: z.record(z.string(), z.any()).default({}),
|
||||
embedding_dimension: z.number().nullable().optional(),
|
||||
enabled: z.boolean(),
|
||||
|
|
@ -31,8 +24,7 @@ export const modelRead = z.object({
|
|||
|
||||
export const connectionRead = z.object({
|
||||
id: z.number(),
|
||||
protocol: z.union([connectionProtocolEnum, z.string()]),
|
||||
litellm_provider: z.string().nullable().optional(),
|
||||
provider: z.string(),
|
||||
base_url: z.string().nullable().optional(),
|
||||
extra: z.record(z.string(), z.any()).default({}),
|
||||
scope: z.union([connectionScopeEnum, z.string()]),
|
||||
|
|
@ -48,8 +40,7 @@ export const connectionRead = z.object({
|
|||
});
|
||||
|
||||
export const connectionCreateRequest = z.object({
|
||||
protocol: connectionProtocolEnum,
|
||||
litellm_provider: z.string().nullable().optional(),
|
||||
provider: z.string().min(1),
|
||||
base_url: z.string().nullable().optional(),
|
||||
api_key: z.string().nullable().optional(),
|
||||
extra: z.record(z.string(), z.any()).default({}),
|
||||
|
|
@ -59,7 +50,7 @@ export const connectionCreateRequest = z.object({
|
|||
});
|
||||
|
||||
export const connectionUpdateRequest = z.object({
|
||||
litellm_provider: z.string().nullable().optional(),
|
||||
provider: z.string().nullable().optional(),
|
||||
base_url: z.string().nullable().optional(),
|
||||
api_key: z.string().nullable().optional(),
|
||||
extra: z.record(z.string(), z.any()).optional(),
|
||||
|
|
@ -74,6 +65,11 @@ export const modelCreateRequest = z.object({
|
|||
export const modelUpdateRequest = z.object({
|
||||
display_name: z.string().nullable().optional(),
|
||||
enabled: z.boolean().optional(),
|
||||
supports_chat: z.boolean().nullable().optional(),
|
||||
max_input_tokens: z.number().nullable().optional(),
|
||||
supports_image_input: z.boolean().nullable().optional(),
|
||||
supports_tools: z.boolean().nullable().optional(),
|
||||
supports_image_generation: z.boolean().nullable().optional(),
|
||||
capabilities_override: z.record(z.string(), z.any()).optional(),
|
||||
});
|
||||
|
||||
|
|
@ -89,10 +85,21 @@ export const modelRoles = z.object({
|
|||
image_gen_model_id: z.number().nullable().optional(),
|
||||
});
|
||||
|
||||
export const modelProviderRead = z.object({
|
||||
provider: z.string(),
|
||||
transport: z.string(),
|
||||
discovery: z.string(),
|
||||
default_base_url: z.string().nullable().optional(),
|
||||
base_url_required: z.boolean(),
|
||||
auth_style: z.string(),
|
||||
local_only: z.boolean().default(false),
|
||||
});
|
||||
|
||||
export const modelProviderListResponse = z.array(modelProviderRead);
|
||||
|
||||
export const connectionListResponse = z.array(connectionRead);
|
||||
export const modelListResponse = z.array(modelRead);
|
||||
|
||||
export type ConnectionProtocol = z.infer<typeof connectionProtocolEnum>;
|
||||
export type ConnectionScope = z.infer<typeof connectionScopeEnum>;
|
||||
export type ModelRead = z.infer<typeof modelRead>;
|
||||
export type ConnectionRead = z.infer<typeof connectionRead>;
|
||||
|
|
@ -102,3 +109,4 @@ export type ModelCreateRequest = z.infer<typeof modelCreateRequest>;
|
|||
export type ModelUpdateRequest = z.infer<typeof modelUpdateRequest>;
|
||||
export type ModelRoles = z.infer<typeof modelRoles>;
|
||||
export type VerifyConnectionResponse = z.infer<typeof verifyConnectionResponse>;
|
||||
export type ModelProviderRead = z.infer<typeof modelProviderRead>;
|
||||
|
|
|
|||
|
|
@ -47,11 +47,16 @@ function buildKind(
|
|||
capability: "chat" | "image_gen" | "vision",
|
||||
prefId: number | null | undefined
|
||||
): EligibleModelKind {
|
||||
const supportsCapability = (model: ModelRead) => {
|
||||
if (capability === "chat") return Boolean(model.supports_chat);
|
||||
if (capability === "vision") return Boolean(model.supports_image_input);
|
||||
return Boolean(model.supports_image_generation);
|
||||
};
|
||||
const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({
|
||||
id: model.id,
|
||||
name: model.display_name || model.model_id,
|
||||
modelName: model.model_id,
|
||||
provider: connection.litellm_provider || connection.protocol,
|
||||
provider: connection.provider,
|
||||
isBYOK,
|
||||
});
|
||||
|
||||
|
|
@ -60,7 +65,7 @@ function buildKind(
|
|||
.filter(
|
||||
(model) =>
|
||||
model.enabled &&
|
||||
Boolean(model.capabilities?.[capability]) &&
|
||||
supportsCapability(model) &&
|
||||
String(model.billing_tier ?? "").toLowerCase() === "premium"
|
||||
)
|
||||
.map((model) => toOption(connection, model, false))
|
||||
|
|
@ -68,7 +73,7 @@ function buildKind(
|
|||
|
||||
const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) =>
|
||||
connection.models
|
||||
.filter((model) => model.enabled && Boolean(model.capabilities?.[capability]))
|
||||
.filter((model) => model.enabled && supportsCapability(model))
|
||||
.map((model) => toOption(connection, model, true))
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@ import {
|
|||
connectionRead,
|
||||
connectionUpdateRequest,
|
||||
type ModelCreateRequest,
|
||||
type ModelProviderRead,
|
||||
type ModelRead,
|
||||
type ModelRoles,
|
||||
type ModelUpdateRequest,
|
||||
modelCreateRequest,
|
||||
modelProviderListResponse,
|
||||
modelListResponse,
|
||||
modelRead,
|
||||
modelRoles,
|
||||
|
|
@ -26,6 +28,10 @@ class ModelConnectionsApiService {
|
|||
return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse);
|
||||
};
|
||||
|
||||
getModelProviders = async (): Promise<ModelProviderRead[]> => {
|
||||
return baseApiService.get(`/api/v1/model-providers`, modelProviderListResponse);
|
||||
};
|
||||
|
||||
getConnections = async (searchSpaceId: number): Promise<ConnectionRead[]> => {
|
||||
return baseApiService.get(
|
||||
`/api/v1/model-connections?search_space_id=${searchSpaceId}`,
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ export const cacheKeys = {
|
|||
modelConnections: {
|
||||
all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const,
|
||||
global: () => ["model-connections", "global"] as const,
|
||||
providers: () => ["model-connections", "providers"] as const,
|
||||
roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const,
|
||||
},
|
||||
imageGenConfigs: {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue