mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-14 20:55:15 +02:00
feat(model-connections): integrate model provider connections panel and connection card components
This commit is contained in:
parent
15d9983669
commit
5e86885a03
8 changed files with 461 additions and 382 deletions
|
|
@ -79,7 +79,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
Transport.OPENAI_COMPATIBLE,
|
||||
"openai",
|
||||
"openai_models",
|
||||
"http://localhost:1234/v1",
|
||||
"http://host.docker.internal:1234/v1",
|
||||
True,
|
||||
"bearer",
|
||||
"LM Studio",
|
||||
|
|
@ -88,7 +88,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
Transport.OLLAMA,
|
||||
"ollama_chat",
|
||||
"ollama",
|
||||
"http://localhost:11434",
|
||||
"http://ollama:11434",
|
||||
True,
|
||||
"none",
|
||||
"Ollama",
|
||||
|
|
|
|||
|
|
@ -7,11 +7,13 @@ import { toast } from "sonner";
|
|||
import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import {
|
||||
globalModelConnectionsAtom,
|
||||
modelConnectionsAtom,
|
||||
modelRolesAtom,
|
||||
} from "@/atoms/model-connections/model-connections-query.atoms";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { ModelProviderConnectionsPanel } from "@/components/settings/model-connections/model-provider-connections-panel";
|
||||
import { capability } from "@/components/settings/model-connections/model-utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
|
||||
|
||||
|
|
@ -22,6 +24,8 @@ export default function OnboardPage() {
|
|||
const { data: globalConnections = [], isFetching: globalLoading } = useAtomValue(
|
||||
globalModelConnectionsAtom
|
||||
);
|
||||
const { data: connections = [], isFetching: connectionsLoading } =
|
||||
useAtomValue(modelConnectionsAtom);
|
||||
const { data: roles = {}, isFetching: rolesLoading } = useAtomValue(modelRolesAtom);
|
||||
const { mutateAsync: updateRoles, isPending } = useAtomValue(updateModelRolesMutationAtom);
|
||||
const [isAutoConfiguring, setIsAutoConfiguring] = useState(false);
|
||||
|
|
@ -38,6 +42,15 @@ export default function OnboardPage() {
|
|||
}
|
||||
return null;
|
||||
}, [globalConnections]);
|
||||
const hasEnabledChatModel = useMemo(
|
||||
() =>
|
||||
connections.some(
|
||||
(connection) =>
|
||||
connection.enabled &&
|
||||
connection.models.some((model) => model.enabled && capability(model, "chat"))
|
||||
),
|
||||
[connections]
|
||||
);
|
||||
|
||||
const isComplete = (roles.chat_model_id ?? 0) !== 0 || Boolean(firstGlobalChatModel);
|
||||
|
||||
|
|
@ -73,28 +86,37 @@ export default function OnboardPage() {
|
|||
updateRoles,
|
||||
]);
|
||||
|
||||
const isLoading = globalLoading || rolesLoading || isAutoConfiguring || isPending;
|
||||
const isLoading =
|
||||
globalLoading || connectionsLoading || rolesLoading || isAutoConfiguring || isPending;
|
||||
useGlobalLoadingEffect(isLoading);
|
||||
|
||||
if (isLoading || isComplete) return null;
|
||||
|
||||
return (
|
||||
<div className="flex h-screen select-none flex-col items-center justify-center bg-main-panel p-4">
|
||||
<div className="w-full max-w-md space-y-6 rounded-xl border bg-main-panel p-8 text-center">
|
||||
<div className="flex min-h-screen select-none flex-col items-center justify-center bg-main-panel p-4">
|
||||
<div className="w-full max-w-3xl space-y-6 text-center">
|
||||
<Logo className="mx-auto h-12 w-12" />
|
||||
<div className="space-y-2">
|
||||
<h1 className="text-2xl font-semibold tracking-tight">Connect a Model</h1>
|
||||
<h1 className="text-2xl font-semibold tracking-tight">Choose a model</h1>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Add one connection, discover its models, then choose a chat model for this search space.
|
||||
Connect any supported provider, then enable the models you want SurfSense to use.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
className="min-w-[180px]"
|
||||
onClick={() => router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`)}
|
||||
>
|
||||
Open Models Settings
|
||||
</Button>
|
||||
{isPending ? <Spinner size="sm" /> : null}
|
||||
<ModelProviderConnectionsPanel
|
||||
searchSpaceId={searchSpaceId}
|
||||
connections={connections}
|
||||
className="flex flex-col gap-6 text-left"
|
||||
footerAction={
|
||||
<Button
|
||||
className="min-w-[112px]"
|
||||
disabled={!hasEnabledChatModel}
|
||||
onClick={() => router.push(`/dashboard/${searchSpaceId}/new-chat`)}
|
||||
>
|
||||
Start
|
||||
</Button>
|
||||
}
|
||||
showAddProviderHeader={false}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,35 +1,13 @@
|
|||
"use client";
|
||||
|
||||
import { useAtom, useAtomValue } from "jotai";
|
||||
import { Dot, Trash2 } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createModelConnectionMutationAtom,
|
||||
deleteModelConnectionMutationAtom,
|
||||
previewConnectionModelsMutationAtom,
|
||||
testPreviewModelMutationAtom,
|
||||
updateModelRolesMutationAtom,
|
||||
} from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import { Dot } from "lucide-react";
|
||||
import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import {
|
||||
globalModelConnectionsAtom,
|
||||
modelConnectionsAtom,
|
||||
modelProvidersAtom,
|
||||
modelRolesAtom,
|
||||
} from "@/atoms/model-connections/model-connections-query.atoms";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
|
|
@ -39,20 +17,10 @@ import {
|
|||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import type {
|
||||
ConnectionRead,
|
||||
ModelRead,
|
||||
ModelSelection,
|
||||
} from "@/contracts/types/model-connections.types";
|
||||
import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog";
|
||||
import { capability, modelLabel, type SelectableModel } from "./model-connections/model-utils";
|
||||
import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog";
|
||||
import {
|
||||
type ConnectionDraft,
|
||||
PROVIDER_ORDER,
|
||||
providerDisplay,
|
||||
providerIcon,
|
||||
} from "./model-connections/provider-metadata";
|
||||
import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types";
|
||||
import { ModelProviderConnectionsPanel } from "./model-connections/model-provider-connections-panel";
|
||||
import { capability, modelLabel } from "./model-connections/model-utils";
|
||||
import { providerDisplay, providerIcon } from "./model-connections/provider-metadata";
|
||||
|
||||
function flattenModels(connections: ConnectionRead[]) {
|
||||
return connections.flatMap((connection) =>
|
||||
|
|
@ -70,271 +38,18 @@ function roleSelectValue(modelId: number | null | undefined, models: Array<{ id:
|
|||
return models.some((model) => model.id === modelId) ? String(modelId) : "0";
|
||||
}
|
||||
|
||||
function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
||||
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
||||
|
||||
const providerMeta = providerDisplay(connection.provider);
|
||||
const providerLabel = providerMeta.name;
|
||||
|
||||
function deleteCurrentConnection() {
|
||||
deleteConnection.mutate(connection.id);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border border-border/60 overflow-hidden">
|
||||
<div className="flex items-center justify-between gap-3 p-4 transition-colors hover:bg-accent">
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-2 font-semibold">
|
||||
{providerIcon(connection.provider)}
|
||||
<span className="truncate">{providerLabel}</span>
|
||||
{connection.scope === "GLOBAL" ? (
|
||||
<Badge variant="outline" className="text-[10px]">
|
||||
Default
|
||||
</Badge>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="truncate text-sm text-muted-foreground">
|
||||
{connection.base_url || "Provider default endpoint"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-2">
|
||||
<ConnectionSettingsDialog connection={connection} providerLabel={providerLabel} />
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="text-muted-foreground hover:text-accent-foreground"
|
||||
disabled={deleteConnection.isPending}
|
||||
aria-label={`Delete ${providerLabel}`}
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</AlertDialogTrigger>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete this provider?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
<span className="font-medium text-foreground">{providerLabel}</span> and all of
|
||||
its models will be removed from this search space. This cannot be undone.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={deleteConnection.isPending}>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={deleteCurrentConnection}
|
||||
disabled={deleteConnection.isPending}
|
||||
className="bg-destructive text-white hover:bg-destructive/90"
|
||||
>
|
||||
Delete
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) {
|
||||
const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom);
|
||||
const [{ data: connections = [] }] = useAtom(modelConnectionsAtom);
|
||||
const [{ data: providers = [] }] = useAtom(modelProvidersAtom);
|
||||
const [{ data: roles }] = useAtom(modelRolesAtom);
|
||||
const createConnection = useAtomValue(createModelConnectionMutationAtom);
|
||||
const previewModels = useAtomValue(previewConnectionModelsMutationAtom);
|
||||
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
|
||||
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
|
||||
|
||||
const [isAddProviderOpen, setIsAddProviderOpen] = useState(false);
|
||||
const [provider, setProvider] = useState("openai_compatible");
|
||||
const [connectModels, setConnectModels] = useState<ModelSelection[]>([]);
|
||||
const selectedProvider = providers.find((item) => item.provider === provider);
|
||||
|
||||
const sortedProviders = [...providers].sort((left, right) => {
|
||||
const leftIndex = PROVIDER_ORDER.indexOf(left.provider);
|
||||
const rightIndex = PROVIDER_ORDER.indexOf(right.provider);
|
||||
if (leftIndex !== -1 || rightIndex !== -1) {
|
||||
return (
|
||||
(leftIndex === -1 ? Number.MAX_SAFE_INTEGER : leftIndex) -
|
||||
(rightIndex === -1 ? Number.MAX_SAFE_INTEGER : rightIndex)
|
||||
);
|
||||
}
|
||||
return providerDisplay(left.provider).name.localeCompare(providerDisplay(right.provider).name);
|
||||
});
|
||||
|
||||
const allConnections = [...globalConnections, ...connections];
|
||||
const enabledModels = flattenModels(allConnections).filter((model) => model.enabled);
|
||||
const chatModels = enabledModels.filter((model) => capability(model, "chat"));
|
||||
const visionModels = enabledModels.filter((model) => capability(model, "vision"));
|
||||
const imageModels = enabledModels.filter((model) => capability(model, "image_gen"));
|
||||
|
||||
function resetConnectState() {
|
||||
setConnectModels([]);
|
||||
}
|
||||
|
||||
function handleConnectOpenChange(open: boolean) {
|
||||
setIsAddProviderOpen(open);
|
||||
if (!open) {
|
||||
resetConnectState();
|
||||
}
|
||||
}
|
||||
|
||||
function toModelSelection(model: SelectableModel): ModelSelection {
|
||||
return {
|
||||
model_id: model.model_id,
|
||||
display_name: model.display_name,
|
||||
source: model.source || "DISCOVERED",
|
||||
supports_chat: model.supports_chat,
|
||||
max_input_tokens: model.max_input_tokens,
|
||||
supports_image_input: model.supports_image_input,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_image_generation: model.supports_image_generation,
|
||||
enabled: model.enabled,
|
||||
metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}),
|
||||
};
|
||||
}
|
||||
|
||||
function mergePreviewModels(fetchedModels: SelectableModel[]) {
|
||||
setConnectModels((current) => {
|
||||
const currentById = new Map(current.map((model) => [model.model_id, model]));
|
||||
return fetchedModels.map((model) => {
|
||||
const prior = currentById.get(model.model_id);
|
||||
return {
|
||||
...toModelSelection(model),
|
||||
enabled: prior ? prior.enabled : model.enabled,
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function connectionModelsForDraft(draft: ConnectionDraft) {
|
||||
const models = [...connectModels];
|
||||
if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) {
|
||||
models.push({
|
||||
model_id: draft.seedModelId,
|
||||
display_name: draft.seedModelId,
|
||||
source: "MANUAL",
|
||||
enabled: true,
|
||||
metadata: {},
|
||||
});
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
function representativeTestModel(models: ModelSelection[]) {
|
||||
const enabledModels = models.filter((model) => model.enabled);
|
||||
return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||
}
|
||||
|
||||
// Each provider connect form builds its own credential payload; the backend
|
||||
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
|
||||
function handleCreate(draft: ConnectionDraft) {
|
||||
const models = connectionModelsForDraft(draft);
|
||||
const testModel = representativeTestModel(models);
|
||||
if (!testModel) {
|
||||
toast.error("Select at least one model before connecting");
|
||||
return;
|
||||
}
|
||||
|
||||
const request = {
|
||||
provider,
|
||||
base_url: draft.base_url,
|
||||
api_key: draft.api_key,
|
||||
scope: "SEARCH_SPACE" as const,
|
||||
search_space_id: searchSpaceId,
|
||||
extra: draft.extra,
|
||||
enabled: true,
|
||||
models,
|
||||
};
|
||||
|
||||
testPreviewModel.mutate(
|
||||
{ ...request, model_id: testModel.model_id },
|
||||
{
|
||||
onSuccess: (result) => {
|
||||
if (!result.ok) return;
|
||||
createConnection.mutate(request, {
|
||||
onSuccess: () => {
|
||||
setIsAddProviderOpen(false);
|
||||
resetConnectState();
|
||||
},
|
||||
});
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function openProviderDialog(providerId: string) {
|
||||
resetConnectState();
|
||||
setProvider(providerId);
|
||||
setIsAddProviderOpen(true);
|
||||
if (providerId === "vertex_ai") {
|
||||
previewModels.mutate(
|
||||
{
|
||||
provider: providerId,
|
||||
base_url: null,
|
||||
api_key: null,
|
||||
scope: "SEARCH_SPACE",
|
||||
search_space_id: searchSpaceId,
|
||||
extra: {},
|
||||
enabled: true,
|
||||
models: [],
|
||||
},
|
||||
{
|
||||
onSuccess: mergePreviewModels,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function refreshConnectModels(draft: ConnectionDraft) {
|
||||
previewModels.mutate(
|
||||
{
|
||||
provider,
|
||||
base_url: draft.base_url,
|
||||
api_key: draft.api_key,
|
||||
scope: "SEARCH_SPACE",
|
||||
search_space_id: searchSpaceId,
|
||||
extra: draft.extra,
|
||||
enabled: true,
|
||||
models: [],
|
||||
},
|
||||
{
|
||||
onSuccess: mergePreviewModels,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function addConnectModel(modelId: string) {
|
||||
setConnectModels((current) => {
|
||||
if (current.some((model) => model.model_id === modelId)) return current;
|
||||
return [
|
||||
...current,
|
||||
{
|
||||
model_id: modelId,
|
||||
display_name: modelId,
|
||||
source: "MANUAL",
|
||||
enabled: true,
|
||||
metadata: {},
|
||||
},
|
||||
];
|
||||
});
|
||||
}
|
||||
|
||||
function toggleConnectModel(model: SelectableModel, enabled: boolean) {
|
||||
setConnectModels((current) =>
|
||||
current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item))
|
||||
);
|
||||
}
|
||||
|
||||
function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) {
|
||||
const modelIds = new Set(models.map((model) => model.model_id));
|
||||
setConnectModels((current) =>
|
||||
current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item))
|
||||
);
|
||||
}
|
||||
|
||||
function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) {
|
||||
return (
|
||||
<SelectItem key={model.id} value={String(model.id)}>
|
||||
|
|
@ -420,71 +135,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
|
||||
<Separator />
|
||||
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="flex flex-col gap-3">
|
||||
<div>
|
||||
<h3 className="text-base font-semibold">Add Provider</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
SurfSense supports popular providers and self-hosted model endpoints.
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-3 md:grid-cols-2">
|
||||
{sortedProviders.map((item) => {
|
||||
const meta = providerDisplay(item.provider);
|
||||
|
||||
return (
|
||||
<Button
|
||||
key={item.provider}
|
||||
variant="ghost"
|
||||
type="button"
|
||||
className="h-auto justify-between gap-3 rounded-lg border border-border/60 p-4 text-left whitespace-normal transition-colors hover:bg-accent hover:text-accent-foreground"
|
||||
onClick={() => openProviderDialog(item.provider)}
|
||||
>
|
||||
<span className="flex min-w-0 items-center gap-3">
|
||||
{providerIcon(item.provider, "size-5")}
|
||||
<span className="min-w-0">
|
||||
<span className="block truncate text-sm font-semibold">{meta.name}</span>
|
||||
<span className="block truncate text-xs text-muted-foreground">
|
||||
{meta.subtitle}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
<span className="shrink-0 text-sm font-medium text-muted-foreground">
|
||||
Connect
|
||||
</span>
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ProviderConnectDialog
|
||||
open={isAddProviderOpen}
|
||||
onOpenChange={handleConnectOpenChange}
|
||||
provider={provider}
|
||||
selectedProvider={selectedProvider}
|
||||
isPending={createConnection.isPending || testPreviewModel.isPending}
|
||||
onSubmit={handleCreate}
|
||||
previewModels={connectModels}
|
||||
isPreviewingModels={previewModels.isPending}
|
||||
onPreviewModels={refreshConnectModels}
|
||||
onAddPreviewModel={addConnectModel}
|
||||
onTogglePreviewModel={toggleConnectModel}
|
||||
onBulkTogglePreviewModels={bulkToggleConnectModels}
|
||||
/>
|
||||
|
||||
{connections.length > 0 ? (
|
||||
<div className="flex flex-col gap-3">
|
||||
<Separator />
|
||||
<h3 className="text-base font-semibold">Available Providers</h3>
|
||||
<div className="flex flex-col gap-3">
|
||||
{connections.map((connection) => (
|
||||
<ConnectionCard key={connection.id} connection={connection} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
<ModelProviderConnectionsPanel searchSpaceId={searchSpaceId} connections={connections} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import { Button } from "@/components/ui/button";
|
|||
import { DialogFooter } from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
|
||||
interface ApiBaseUrlFieldProps {
|
||||
value: string;
|
||||
|
|
@ -93,8 +94,13 @@ export function ConnectFormFooter({
|
|||
<Button variant="secondary" onClick={onCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={onSubmit} disabled={isPending || !canSubmit}>
|
||||
Connect
|
||||
<Button
|
||||
onClick={onSubmit}
|
||||
disabled={isPending || !canSubmit}
|
||||
className="relative min-w-[96px]"
|
||||
>
|
||||
<span className={isPending ? "opacity-0" : ""}>Connect</span>
|
||||
{isPending ? <Spinner size="sm" className="absolute" /> : null}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { Trash2 } from "lucide-react";
|
||||
import { deleteModelConnectionMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import type { ConnectionRead } from "@/contracts/types/model-connections.types";
|
||||
import { ConnectionSettingsDialog } from "./connection-settings-dialog";
|
||||
import { providerDisplay, providerIcon } from "./provider-metadata";
|
||||
|
||||
export function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
||||
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
||||
|
||||
const providerMeta = providerDisplay(connection.provider);
|
||||
const providerLabel = providerMeta.name;
|
||||
|
||||
function deleteCurrentConnection() {
|
||||
deleteConnection.mutate(connection.id);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="overflow-hidden rounded-lg border border-border/60">
|
||||
<div className="flex items-center justify-between gap-3 p-4 transition-colors hover:bg-accent">
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-2 font-semibold">
|
||||
{providerIcon(connection.provider)}
|
||||
<span className="truncate">{providerLabel}</span>
|
||||
{connection.scope === "GLOBAL" ? (
|
||||
<Badge variant="outline" className="text-[10px]">
|
||||
Default
|
||||
</Badge>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="truncate text-sm text-muted-foreground">
|
||||
{connection.base_url || "Provider default endpoint"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-2">
|
||||
<ConnectionSettingsDialog connection={connection} providerLabel={providerLabel} />
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="text-muted-foreground hover:text-accent-foreground"
|
||||
disabled={deleteConnection.isPending}
|
||||
aria-label={`Delete ${providerLabel}`}
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</AlertDialogTrigger>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete this provider?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
<span className="font-medium text-foreground">{providerLabel}</span> and all of
|
||||
its models will be removed from this search space. This cannot be undone.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={deleteConnection.isPending}>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={deleteCurrentConnection}
|
||||
disabled={deleteConnection.isPending}
|
||||
className="bg-destructive text-white hover:bg-destructive/90"
|
||||
>
|
||||
Delete
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ import {
|
|||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import type {
|
||||
ConnectionRead,
|
||||
ConnectionUpdateRequest,
|
||||
|
|
@ -54,6 +55,7 @@ export function ConnectionSettingsDialog({
|
|||
const [apiKeyDraft, setApiKeyDraft] = useState("");
|
||||
const [showApiKey, setShowApiKey] = useState(false);
|
||||
const [allowlistText, setAllowlistText] = useState(allowlist.join(", "));
|
||||
const [isSavingConnectionSettings, setIsSavingConnectionSettings] = useState(false);
|
||||
|
||||
const isLocal =
|
||||
connection.provider === "ollama_chat" ||
|
||||
|
|
@ -70,10 +72,13 @@ export function ConnectionSettingsDialog({
|
|||
setApiKeyDraft(connection.api_key ?? "");
|
||||
setShowApiKey(false);
|
||||
setAllowlistText(allowlist.join(", "));
|
||||
setIsSavingConnectionSettings(false);
|
||||
}
|
||||
}
|
||||
|
||||
function saveConnectionSettings() {
|
||||
if (isSavingConnectionSettings) return;
|
||||
|
||||
const data: ConnectionUpdateRequest = {
|
||||
base_url: baseUrlDraft.trim() || null,
|
||||
};
|
||||
|
|
@ -86,13 +91,14 @@ export function ConnectionSettingsDialog({
|
|||
: (connection.api_key ?? null);
|
||||
|
||||
const enabledModels = connection.models.filter((model) => model.enabled);
|
||||
const testModel =
|
||||
enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||
const testModel = enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||
setIsSavingConnectionSettings(true);
|
||||
if (!testModel) {
|
||||
updateConnection.mutate(
|
||||
{ id: connection.id, data },
|
||||
{
|
||||
onSuccess: () => setApiKeyDraft(""),
|
||||
onSettled: () => setIsSavingConnectionSettings(false),
|
||||
}
|
||||
);
|
||||
return;
|
||||
|
|
@ -112,14 +118,19 @@ export function ConnectionSettingsDialog({
|
|||
},
|
||||
{
|
||||
onSuccess: (result) => {
|
||||
if (!result.ok) return;
|
||||
if (!result.ok) {
|
||||
setIsSavingConnectionSettings(false);
|
||||
return;
|
||||
}
|
||||
updateConnection.mutate(
|
||||
{ id: connection.id, data },
|
||||
{
|
||||
onSuccess: () => setApiKeyDraft(""),
|
||||
onSettled: () => setIsSavingConnectionSettings(false),
|
||||
}
|
||||
);
|
||||
},
|
||||
onError: () => setIsSavingConnectionSettings(false),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
@ -257,18 +268,17 @@ export function ConnectionSettingsDialog({
|
|||
onToggleModel={handleToggleModel}
|
||||
onBulkToggle={handleBulkToggle}
|
||||
/>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
|
||||
<Button
|
||||
onClick={saveConnectionSettings}
|
||||
disabled={
|
||||
updateConnection.isPending || testPreviewModel.isPending || !hasConnectionChanges
|
||||
}
|
||||
disabled={isSavingConnectionSettings || !hasConnectionChanges}
|
||||
className="relative min-w-[96px]"
|
||||
>
|
||||
Update
|
||||
<span className={isSavingConnectionSettings ? "opacity-0" : ""}>Update</span>
|
||||
{isSavingConnectionSettings ? <Spinner size="sm" className="absolute" /> : null}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,299 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { type ReactNode, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
createModelConnectionMutationAtom,
|
||||
previewConnectionModelsMutationAtom,
|
||||
testPreviewModelMutationAtom,
|
||||
} from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||
import { modelProvidersAtom } from "@/atoms/model-connections/model-connections-query.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import type { ConnectionRead, ModelSelection } from "@/contracts/types/model-connections.types";
|
||||
import { ConnectionCard } from "./connection-card";
|
||||
import { capability, type SelectableModel } from "./model-utils";
|
||||
import { ProviderConnectDialog } from "./provider-connect-dialog";
|
||||
import {
|
||||
type ConnectionDraft,
|
||||
PROVIDER_ORDER,
|
||||
providerDisplay,
|
||||
providerIcon,
|
||||
} from "./provider-metadata";
|
||||
|
||||
interface ModelProviderConnectionsPanelProps {
|
||||
searchSpaceId: number;
|
||||
connections: ConnectionRead[];
|
||||
className?: string;
|
||||
addProviderTitle?: string;
|
||||
addProviderDescription?: string;
|
||||
availableProvidersTitle?: string;
|
||||
footerAction?: ReactNode;
|
||||
showAddProviderHeader?: boolean;
|
||||
}
|
||||
|
||||
function toModelSelection(model: SelectableModel): ModelSelection {
|
||||
return {
|
||||
model_id: model.model_id,
|
||||
display_name: model.display_name,
|
||||
source: model.source || "DISCOVERED",
|
||||
supports_chat: model.supports_chat,
|
||||
max_input_tokens: model.max_input_tokens,
|
||||
supports_image_input: model.supports_image_input,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_image_generation: model.supports_image_generation,
|
||||
enabled: model.enabled,
|
||||
metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}),
|
||||
};
|
||||
}
|
||||
|
||||
export function ModelProviderConnectionsPanel({
|
||||
searchSpaceId,
|
||||
connections,
|
||||
className,
|
||||
addProviderTitle = "Add Provider",
|
||||
addProviderDescription = "SurfSense supports popular providers and self-hosted model endpoints.",
|
||||
availableProvidersTitle = "Available Providers",
|
||||
footerAction,
|
||||
showAddProviderHeader = true,
|
||||
}: ModelProviderConnectionsPanelProps) {
|
||||
const { data: providers = [] } = useAtomValue(modelProvidersAtom);
|
||||
const createConnection = useAtomValue(createModelConnectionMutationAtom);
|
||||
const previewModels = useAtomValue(previewConnectionModelsMutationAtom);
|
||||
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
|
||||
|
||||
const [isAddProviderOpen, setIsAddProviderOpen] = useState(false);
|
||||
const [provider, setProvider] = useState("openai_compatible");
|
||||
const [connectModels, setConnectModels] = useState<ModelSelection[]>([]);
|
||||
const selectedProvider = providers.find((item) => item.provider === provider);
|
||||
|
||||
const sortedProviders = [...providers].sort((left, right) => {
|
||||
const leftIndex = PROVIDER_ORDER.indexOf(left.provider);
|
||||
const rightIndex = PROVIDER_ORDER.indexOf(right.provider);
|
||||
if (leftIndex !== -1 || rightIndex !== -1) {
|
||||
return (
|
||||
(leftIndex === -1 ? Number.MAX_SAFE_INTEGER : leftIndex) -
|
||||
(rightIndex === -1 ? Number.MAX_SAFE_INTEGER : rightIndex)
|
||||
);
|
||||
}
|
||||
return providerDisplay(left.provider).name.localeCompare(providerDisplay(right.provider).name);
|
||||
});
|
||||
|
||||
function resetConnectState() {
|
||||
setConnectModels([]);
|
||||
}
|
||||
|
||||
function handleConnectOpenChange(open: boolean) {
|
||||
setIsAddProviderOpen(open);
|
||||
if (!open) {
|
||||
resetConnectState();
|
||||
}
|
||||
}
|
||||
|
||||
function mergePreviewModels(fetchedModels: SelectableModel[]) {
|
||||
setConnectModels((current) => {
|
||||
const currentById = new Map(current.map((model) => [model.model_id, model]));
|
||||
return fetchedModels.map((model) => {
|
||||
const prior = currentById.get(model.model_id);
|
||||
return {
|
||||
...toModelSelection(model),
|
||||
enabled: prior ? prior.enabled : model.enabled,
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function connectionModelsForDraft(draft: ConnectionDraft) {
|
||||
const models = [...connectModels];
|
||||
if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) {
|
||||
models.push({
|
||||
model_id: draft.seedModelId,
|
||||
display_name: draft.seedModelId,
|
||||
source: "MANUAL",
|
||||
enabled: true,
|
||||
metadata: {},
|
||||
});
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
function representativeTestModel(models: ModelSelection[]) {
|
||||
const enabledModels = models.filter((model) => model.enabled);
|
||||
return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||
}
|
||||
|
||||
// Each provider connect form builds its own credential payload; the backend
|
||||
// resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM.
|
||||
function handleCreate(draft: ConnectionDraft) {
|
||||
const models = connectionModelsForDraft(draft);
|
||||
const testModel = representativeTestModel(models);
|
||||
if (!testModel) {
|
||||
toast.error("Select at least one model before connecting");
|
||||
return;
|
||||
}
|
||||
|
||||
const request = {
|
||||
provider,
|
||||
base_url: draft.base_url,
|
||||
api_key: draft.api_key,
|
||||
scope: "SEARCH_SPACE" as const,
|
||||
search_space_id: searchSpaceId,
|
||||
extra: draft.extra,
|
||||
enabled: true,
|
||||
models,
|
||||
};
|
||||
|
||||
testPreviewModel.mutate(
|
||||
{ ...request, model_id: testModel.model_id },
|
||||
{
|
||||
onSuccess: (result) => {
|
||||
if (!result.ok) return;
|
||||
createConnection.mutate(request, {
|
||||
onSuccess: () => {
|
||||
setIsAddProviderOpen(false);
|
||||
resetConnectState();
|
||||
},
|
||||
});
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function openProviderDialog(providerId: string) {
|
||||
resetConnectState();
|
||||
setProvider(providerId);
|
||||
setIsAddProviderOpen(true);
|
||||
if (providerId === "vertex_ai") {
|
||||
previewModels.mutate(
|
||||
{
|
||||
provider: providerId,
|
||||
base_url: null,
|
||||
api_key: null,
|
||||
scope: "SEARCH_SPACE",
|
||||
search_space_id: searchSpaceId,
|
||||
extra: {},
|
||||
enabled: true,
|
||||
models: [],
|
||||
},
|
||||
{
|
||||
onSuccess: mergePreviewModels,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function refreshConnectModels(draft: ConnectionDraft) {
|
||||
previewModels.mutate(
|
||||
{
|
||||
provider,
|
||||
base_url: draft.base_url,
|
||||
api_key: draft.api_key,
|
||||
scope: "SEARCH_SPACE",
|
||||
search_space_id: searchSpaceId,
|
||||
extra: draft.extra,
|
||||
enabled: true,
|
||||
models: [],
|
||||
},
|
||||
{
|
||||
onSuccess: mergePreviewModels,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
function addConnectModel(modelId: string) {
|
||||
setConnectModels((current) => {
|
||||
if (current.some((model) => model.model_id === modelId)) return current;
|
||||
return [
|
||||
...current,
|
||||
{
|
||||
model_id: modelId,
|
||||
display_name: modelId,
|
||||
source: "MANUAL",
|
||||
enabled: true,
|
||||
metadata: {},
|
||||
},
|
||||
];
|
||||
});
|
||||
}
|
||||
|
||||
function toggleConnectModel(model: SelectableModel, enabled: boolean) {
|
||||
setConnectModels((current) =>
|
||||
current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item))
|
||||
);
|
||||
}
|
||||
|
||||
function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) {
|
||||
const modelIds = new Set(models.map((model) => model.model_id));
|
||||
setConnectModels((current) =>
|
||||
current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item))
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={className ?? "flex flex-col gap-6"}>
|
||||
<div className="flex flex-col gap-3">
|
||||
{showAddProviderHeader ? (
|
||||
<div>
|
||||
<h3 className="text-base font-semibold">{addProviderTitle}</h3>
|
||||
<p className="text-sm text-muted-foreground">{addProviderDescription}</p>
|
||||
</div>
|
||||
) : null}
|
||||
<div className="grid gap-3 md:grid-cols-2">
|
||||
{sortedProviders.map((item) => {
|
||||
const meta = providerDisplay(item.provider);
|
||||
|
||||
return (
|
||||
<Button
|
||||
key={item.provider}
|
||||
variant="ghost"
|
||||
type="button"
|
||||
className="h-auto justify-between gap-3 whitespace-normal rounded-lg border border-border/60 p-4 text-left transition-colors hover:bg-accent hover:text-accent-foreground"
|
||||
onClick={() => openProviderDialog(item.provider)}
|
||||
>
|
||||
<span className="flex min-w-0 items-center gap-3">
|
||||
{providerIcon(item.provider, "size-5")}
|
||||
<span className="min-w-0">
|
||||
<span className="block truncate text-sm font-semibold">{meta.name}</span>
|
||||
<span className="block truncate text-xs text-muted-foreground">
|
||||
{meta.subtitle}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
<span className="shrink-0 text-sm font-medium text-muted-foreground">Connect</span>
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ProviderConnectDialog
|
||||
open={isAddProviderOpen}
|
||||
onOpenChange={handleConnectOpenChange}
|
||||
provider={provider}
|
||||
selectedProvider={selectedProvider}
|
||||
isPending={createConnection.isPending || testPreviewModel.isPending}
|
||||
onSubmit={handleCreate}
|
||||
previewModels={connectModels}
|
||||
isPreviewingModels={previewModels.isPending}
|
||||
onPreviewModels={refreshConnectModels}
|
||||
onAddPreviewModel={addConnectModel}
|
||||
onTogglePreviewModel={toggleConnectModel}
|
||||
onBulkTogglePreviewModels={bulkToggleConnectModels}
|
||||
/>
|
||||
|
||||
{connections.length > 0 ? (
|
||||
<div className="flex flex-col gap-3">
|
||||
<Separator />
|
||||
<h3 className="text-base font-semibold">{availableProvidersTitle}</h3>
|
||||
<div className="flex flex-col gap-3">
|
||||
{connections.map((connection) => (
|
||||
<ConnectionCard key={connection.id} connection={connection} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
{footerAction ? <div className="flex justify-center pt-2">{footerAction}</div> : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ import { Badge } from "@/components/ui/badge";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import {
|
||||
capability,
|
||||
capabilityLabels,
|
||||
|
|
@ -117,8 +118,10 @@ export function ModelsSelectionPanel({
|
|||
type="button"
|
||||
onClick={addModel}
|
||||
disabled={isAddingManual || !manualModelId.trim()}
|
||||
className="relative min-w-[88px]"
|
||||
>
|
||||
Add model
|
||||
<span className={isAddingManual ? "opacity-0" : ""}>Add model</span>
|
||||
{isAddingManual ? <Spinner size="xs" className="absolute" /> : null}
|
||||
</Button>
|
||||
</div>
|
||||
) : null}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue