mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-02 22:01:05 +02:00
refactor(model-selector, connection-settings): streamline model name handling and enhance connection settings dialog with improved state management for enabled models
This commit is contained in:
parent
01cb4f281e
commit
212c8af682
2 changed files with 107 additions and 57 deletions
|
|
@ -42,10 +42,6 @@ type ChatModel = ModelRead & {
|
||||||
provider: string;
|
provider: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
function modelName(model: ModelRead) {
|
|
||||||
return (model.display_name || model.model_id).replace(/\s+\(free\)$/i, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
function connectionLabel(connection: ConnectionRead) {
|
function connectionLabel(connection: ConnectionRead) {
|
||||||
if (connection.scope === "GLOBAL") return "Global";
|
if (connection.scope === "GLOBAL") return "Global";
|
||||||
return providerDisplay(connection.provider).name;
|
return providerDisplay(connection.provider).name;
|
||||||
|
|
@ -69,6 +65,14 @@ function isFreeGlobalModel(model: ChatModel) {
|
||||||
return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free";
|
return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function modelName(model: ChatModel) {
|
||||||
|
const name = model.display_name || model.model_id;
|
||||||
|
if (model.connectionScope === "GLOBAL") {
|
||||||
|
return name.replace(/\s+\(free\)$/i, "");
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
function groupedModels(models: ChatModel[]) {
|
function groupedModels(models: ChatModel[]) {
|
||||||
return models.reduce<Record<string, ChatModel[]>>((groups, model) => {
|
return models.reduce<Record<string, ChatModel[]>>((groups, model) => {
|
||||||
const key = model.connectionLabel;
|
const key = model.connectionLabel;
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
import { Eye, EyeOff, Settings } from "lucide-react";
|
import { Eye, EyeOff, Settings } from "lucide-react";
|
||||||
import { useState } from "react";
|
import { useMemo, useState } from "react";
|
||||||
import {
|
import {
|
||||||
addManualModelMutationAtom,
|
addManualModelMutationAtom,
|
||||||
bulkUpdateModelsMutationAtom,
|
bulkUpdateModelsMutationAtom,
|
||||||
discoverConnectionModelsMutationAtom,
|
discoverConnectionModelsMutationAtom,
|
||||||
testPreviewModelMutationAtom,
|
testPreviewModelMutationAtom,
|
||||||
updateModelConnectionMutationAtom,
|
updateModelConnectionMutationAtom,
|
||||||
updateModelMutationAtom,
|
|
||||||
} from "@/atoms/model-connections/model-connections-mutation.atoms";
|
} from "@/atoms/model-connections/model-connections-mutation.atoms";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import {
|
import {
|
||||||
|
|
@ -36,6 +35,14 @@ interface ConnectionSettingsDialogProps {
|
||||||
providerLabel: string;
|
providerLabel: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function enabledModelIds(models: SelectableModel[]) {
|
||||||
|
return new Set(
|
||||||
|
models
|
||||||
|
.filter((model) => typeof model.id === "number" && model.enabled)
|
||||||
|
.map((model) => Number(model.id))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function ConnectionSettingsDialog({
|
export function ConnectionSettingsDialog({
|
||||||
connection,
|
connection,
|
||||||
providerLabel,
|
providerLabel,
|
||||||
|
|
@ -44,7 +51,6 @@ export function ConnectionSettingsDialog({
|
||||||
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
|
const testPreviewModel = useAtomValue(testPreviewModelMutationAtom);
|
||||||
const updateConnection = useAtomValue(updateModelConnectionMutationAtom);
|
const updateConnection = useAtomValue(updateModelConnectionMutationAtom);
|
||||||
const addManualModel = useAtomValue(addManualModelMutationAtom);
|
const addManualModel = useAtomValue(addManualModelMutationAtom);
|
||||||
const updateModel = useAtomValue(updateModelMutationAtom);
|
|
||||||
const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom);
|
const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom);
|
||||||
|
|
||||||
const allowlist = Array.isArray(connection.extra?.model_ids)
|
const allowlist = Array.isArray(connection.extra?.model_ids)
|
||||||
|
|
@ -56,6 +62,9 @@ export function ConnectionSettingsDialog({
|
||||||
const [showApiKey, setShowApiKey] = useState(false);
|
const [showApiKey, setShowApiKey] = useState(false);
|
||||||
const [allowlistText, setAllowlistText] = useState(allowlist.join(", "));
|
const [allowlistText, setAllowlistText] = useState(allowlist.join(", "));
|
||||||
const [isSavingConnectionSettings, setIsSavingConnectionSettings] = useState(false);
|
const [isSavingConnectionSettings, setIsSavingConnectionSettings] = useState(false);
|
||||||
|
const [draftEnabledModelIds, setDraftEnabledModelIds] = useState(() =>
|
||||||
|
enabledModelIds(connection.models)
|
||||||
|
);
|
||||||
|
|
||||||
const isLocal =
|
const isLocal =
|
||||||
connection.provider === "ollama_chat" ||
|
connection.provider === "ollama_chat" ||
|
||||||
|
|
@ -64,6 +73,19 @@ export function ConnectionSettingsDialog({
|
||||||
const hasConnectionChanges =
|
const hasConnectionChanges =
|
||||||
baseUrlDraft.trim() !== (connection.base_url ?? "") ||
|
baseUrlDraft.trim() !== (connection.base_url ?? "") ||
|
||||||
apiKeyDraft.trim() !== (connection.api_key ?? "");
|
apiKeyDraft.trim() !== (connection.api_key ?? "");
|
||||||
|
const draftModels = useMemo(
|
||||||
|
() =>
|
||||||
|
connection.models.map((model) =>
|
||||||
|
typeof model.id === "number"
|
||||||
|
? { ...model, enabled: draftEnabledModelIds.has(model.id) }
|
||||||
|
: model
|
||||||
|
),
|
||||||
|
[connection.models, draftEnabledModelIds]
|
||||||
|
);
|
||||||
|
const hasModelChanges = connection.models.some(
|
||||||
|
(model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id) !== model.enabled
|
||||||
|
);
|
||||||
|
const canUpdate = hasConnectionChanges || hasModelChanges;
|
||||||
|
|
||||||
function handleOpenChange(open: boolean) {
|
function handleOpenChange(open: boolean) {
|
||||||
setIsOpen(open);
|
setIsOpen(open);
|
||||||
|
|
@ -73,10 +95,35 @@ export function ConnectionSettingsDialog({
|
||||||
setShowApiKey(false);
|
setShowApiKey(false);
|
||||||
setAllowlistText(allowlist.join(", "));
|
setAllowlistText(allowlist.join(", "));
|
||||||
setIsSavingConnectionSettings(false);
|
setIsSavingConnectionSettings(false);
|
||||||
|
setDraftEnabledModelIds(enabledModelIds(connection.models));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function saveConnectionSettings() {
|
async function saveModelChanges() {
|
||||||
|
const toEnable = connection.models
|
||||||
|
.filter((model) => typeof model.id === "number" && draftEnabledModelIds.has(model.id))
|
||||||
|
.filter((model) => !model.enabled)
|
||||||
|
.map((model) => Number(model.id));
|
||||||
|
const toDisable = connection.models
|
||||||
|
.filter((model) => typeof model.id === "number" && !draftEnabledModelIds.has(model.id))
|
||||||
|
.filter((model) => model.enabled)
|
||||||
|
.map((model) => Number(model.id));
|
||||||
|
|
||||||
|
if (toEnable.length > 0) {
|
||||||
|
await bulkUpdateModels.mutateAsync({
|
||||||
|
connectionId: connection.id,
|
||||||
|
data: { model_ids: toEnable, enabled: true },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (toDisable.length > 0) {
|
||||||
|
await bulkUpdateModels.mutateAsync({
|
||||||
|
connectionId: connection.id,
|
||||||
|
data: { model_ids: toDisable, enabled: false },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function saveConnectionSettings() {
|
||||||
if (isSavingConnectionSettings) return;
|
if (isSavingConnectionSettings) return;
|
||||||
|
|
||||||
const data: ConnectionUpdateRequest = {
|
const data: ConnectionUpdateRequest = {
|
||||||
|
|
@ -90,49 +137,35 @@ export function ConnectionSettingsDialog({
|
||||||
? (data.api_key ?? null)
|
? (data.api_key ?? null)
|
||||||
: (connection.api_key ?? null);
|
: (connection.api_key ?? null);
|
||||||
|
|
||||||
const enabledModels = connection.models.filter((model) => model.enabled);
|
const enabledModels = draftModels.filter((model) => model.enabled);
|
||||||
const testModel = enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
const testModel = enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0];
|
||||||
setIsSavingConnectionSettings(true);
|
setIsSavingConnectionSettings(true);
|
||||||
if (!testModel) {
|
try {
|
||||||
updateConnection.mutate(
|
if (hasConnectionChanges) {
|
||||||
{ id: connection.id, data },
|
if (testModel) {
|
||||||
{
|
const result = await testPreviewModel.mutateAsync({
|
||||||
onSuccess: () => setApiKeyDraft(""),
|
provider: connection.provider,
|
||||||
onSettled: () => setIsSavingConnectionSettings(false),
|
base_url: data.base_url,
|
||||||
|
api_key: apiKeyForTest,
|
||||||
|
scope: "SEARCH_SPACE",
|
||||||
|
search_space_id: connection.search_space_id,
|
||||||
|
extra: connection.extra ?? {},
|
||||||
|
enabled: connection.enabled,
|
||||||
|
models: [],
|
||||||
|
model_id: testModel.model_id,
|
||||||
|
});
|
||||||
|
if (!result.ok) return;
|
||||||
}
|
}
|
||||||
);
|
await updateConnection.mutateAsync({ id: connection.id, data });
|
||||||
return;
|
setApiKeyDraft("");
|
||||||
}
|
|
||||||
|
|
||||||
testPreviewModel.mutate(
|
|
||||||
{
|
|
||||||
provider: connection.provider,
|
|
||||||
base_url: data.base_url,
|
|
||||||
api_key: apiKeyForTest,
|
|
||||||
scope: "SEARCH_SPACE",
|
|
||||||
search_space_id: connection.search_space_id,
|
|
||||||
extra: connection.extra ?? {},
|
|
||||||
enabled: connection.enabled,
|
|
||||||
models: [],
|
|
||||||
model_id: testModel.model_id,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
onSuccess: (result) => {
|
|
||||||
if (!result.ok) {
|
|
||||||
setIsSavingConnectionSettings(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
updateConnection.mutate(
|
|
||||||
{ id: connection.id, data },
|
|
||||||
{
|
|
||||||
onSuccess: () => setApiKeyDraft(""),
|
|
||||||
onSettled: () => setIsSavingConnectionSettings(false),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
},
|
|
||||||
onError: () => setIsSavingConnectionSettings(false),
|
|
||||||
}
|
}
|
||||||
);
|
|
||||||
|
if (hasModelChanges) {
|
||||||
|
await saveModelChanges();
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setIsSavingConnectionSettings(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function saveAllowlist() {
|
function saveAllowlist() {
|
||||||
|
|
@ -148,9 +181,15 @@ export function ConnectionSettingsDialog({
|
||||||
|
|
||||||
function handleToggleModel(model: SelectableModel, enabled: boolean) {
|
function handleToggleModel(model: SelectableModel, enabled: boolean) {
|
||||||
if (typeof model.id !== "number") return;
|
if (typeof model.id !== "number") return;
|
||||||
updateModel.mutate({
|
const modelId = model.id;
|
||||||
id: model.id,
|
setDraftEnabledModelIds((current) => {
|
||||||
data: { enabled },
|
const next = new Set(current);
|
||||||
|
if (enabled) {
|
||||||
|
next.add(modelId);
|
||||||
|
} else {
|
||||||
|
next.delete(modelId);
|
||||||
|
}
|
||||||
|
return next;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -159,9 +198,16 @@ export function ConnectionSettingsDialog({
|
||||||
.map((model) => model.id)
|
.map((model) => model.id)
|
||||||
.filter((id): id is number => typeof id === "number");
|
.filter((id): id is number => typeof id === "number");
|
||||||
if (modelIds.length === 0) return;
|
if (modelIds.length === 0) return;
|
||||||
bulkUpdateModels.mutate({
|
setDraftEnabledModelIds((current) => {
|
||||||
connectionId: connection.id,
|
const next = new Set(current);
|
||||||
data: { model_ids: modelIds, enabled },
|
for (const id of modelIds) {
|
||||||
|
if (enabled) {
|
||||||
|
next.add(id);
|
||||||
|
} else {
|
||||||
|
next.delete(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -252,11 +298,11 @@ export function ConnectionSettingsDialog({
|
||||||
<Separator className="bg-muted-foreground/20" />
|
<Separator className="bg-muted-foreground/20" />
|
||||||
|
|
||||||
<ModelsSelectionPanel
|
<ModelsSelectionPanel
|
||||||
models={connection.models}
|
models={draftModels}
|
||||||
isRefreshing={discoverModels.isPending}
|
isRefreshing={discoverModels.isPending}
|
||||||
isAddingManual={addManualModel.isPending}
|
isAddingManual={addManualModel.isPending}
|
||||||
isUpdatingModel={updateModel.isPending}
|
isUpdatingModel={isSavingConnectionSettings}
|
||||||
isBulkUpdating={bulkUpdateModels.isPending}
|
isBulkUpdating={isSavingConnectionSettings || bulkUpdateModels.isPending}
|
||||||
refreshLabel={`Refresh ${providerLabel} models`}
|
refreshLabel={`Refresh ${providerLabel} models`}
|
||||||
onRefresh={() => discoverModels.mutate(connection.id)}
|
onRefresh={() => discoverModels.mutate(connection.id)}
|
||||||
onAddManual={(modelId) =>
|
onAddManual={(modelId) =>
|
||||||
|
|
@ -274,7 +320,7 @@ export function ConnectionSettingsDialog({
|
||||||
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
|
<DialogFooter className="shrink-0 border-t bg-popover px-6 py-4">
|
||||||
<Button
|
<Button
|
||||||
onClick={saveConnectionSettings}
|
onClick={saveConnectionSettings}
|
||||||
disabled={isSavingConnectionSettings || !hasConnectionChanges}
|
disabled={isSavingConnectionSettings || !canUpdate}
|
||||||
className="relative min-w-[96px]"
|
className="relative min-w-[96px]"
|
||||||
>
|
>
|
||||||
<span className={isSavingConnectionSettings ? "opacity-0" : ""}>Update</span>
|
<span className={isSavingConnectionSettings ? "opacity-0" : ""}>Update</span>
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue