mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
Merge pull request #1491 from AnishSarkar22/feat/unified-model-connections
feat: Fix model attribution for prefix-stripped token usage callbacks
This commit is contained in:
commit
69bdcf5946
7 changed files with 190 additions and 32 deletions
|
|
@ -32,6 +32,23 @@ from app.db import TokenUsage
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _bare_model_name(model: str) -> str:
|
||||||
|
"""Return a model identifier with any provider routing prefix stripped.
|
||||||
|
|
||||||
|
LiteLLM's ``get_llm_provider`` consumes the provider prefix we add in
|
||||||
|
``to_litellm`` (e.g. ``azure/gpt-5.2-chat`` → ``gpt-5.2-chat`` because
|
||||||
|
``azure`` is in ``litellm.provider_list``). The token-tracking success
|
||||||
|
callback therefore reports ``kwargs["model"]`` *without* that prefix,
|
||||||
|
while model metadata is registered under the *prefixed* string. Normalising
|
||||||
|
both sides to the last path segment lets the two reconcile so the per-model
|
||||||
|
breakdown carries provider/display_name and the UI attributes the turn to
|
||||||
|
the correct connection instead of falling back to a bare-name collision.
|
||||||
|
"""
|
||||||
|
if not model:
|
||||||
|
return model
|
||||||
|
return model.split("/")[-1]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenCallRecord:
|
class TokenCallRecord:
|
||||||
model: str
|
model: str
|
||||||
|
|
@ -52,6 +69,12 @@ class TurnTokenAccumulator:
|
||||||
|
|
||||||
calls: list[TokenCallRecord] = field(default_factory=list)
|
calls: list[TokenCallRecord] = field(default_factory=list)
|
||||||
model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict)
|
model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict)
|
||||||
|
# Secondary index keyed by the bare model name (provider prefix stripped) so
|
||||||
|
# the LiteLLM callback — which never sees our routing prefix — can still
|
||||||
|
# reconcile its ``kwargs["model"]`` back to the registered metadata.
|
||||||
|
model_metadata_by_bare: dict[str, dict[str, str | None]] = field(
|
||||||
|
default_factory=dict
|
||||||
|
)
|
||||||
|
|
||||||
def register_model_metadata(
|
def register_model_metadata(
|
||||||
self,
|
self,
|
||||||
|
|
@ -63,12 +86,28 @@ class TurnTokenAccumulator:
|
||||||
provider: str | None,
|
provider: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Attach resolved model metadata for later LiteLLM callback attribution."""
|
"""Attach resolved model metadata for later LiteLLM callback attribution."""
|
||||||
self.model_metadata[model] = {
|
metadata = {
|
||||||
"model_ref": model_ref,
|
"model_ref": model_ref,
|
||||||
"model_id": model_id,
|
"model_id": model_id,
|
||||||
"display_name": display_name,
|
"display_name": display_name,
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
}
|
}
|
||||||
|
self.model_metadata[model] = metadata
|
||||||
|
# Index every reconcilable alias: the prefixed string's bare form and
|
||||||
|
# the resolved ``model_id`` (which for some providers is itself the bare
|
||||||
|
# deployment LiteLLM reports). Exact lookups always take precedence.
|
||||||
|
self.model_metadata_by_bare[_bare_model_name(model)] = metadata
|
||||||
|
if model_id:
|
||||||
|
self.model_metadata_by_bare.setdefault(_bare_model_name(model_id), metadata)
|
||||||
|
|
||||||
|
def _lookup_metadata(self, model: str) -> dict[str, str | None]:
|
||||||
|
"""Resolve registered metadata for a callback model, tolerating the
|
||||||
|
provider-prefix stripping LiteLLM applies before the success callback
|
||||||
|
fires (see :func:`_bare_model_name`)."""
|
||||||
|
exact = self.model_metadata.get(model)
|
||||||
|
if exact is not None:
|
||||||
|
return exact
|
||||||
|
return self.model_metadata_by_bare.get(_bare_model_name(model), {})
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
|
|
@ -79,7 +118,7 @@ class TurnTokenAccumulator:
|
||||||
cost_micros: int = 0,
|
cost_micros: int = 0,
|
||||||
call_kind: str = "chat",
|
call_kind: str = "chat",
|
||||||
) -> None:
|
) -> None:
|
||||||
metadata = self.model_metadata.get(model, {})
|
metadata = self._lookup_metadata(model)
|
||||||
self.calls.append(
|
self.calls.append(
|
||||||
TokenCallRecord(
|
TokenCallRecord(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
||||||
|
|
@ -112,6 +112,77 @@ def test_per_message_summary_groups_cost_by_model():
|
||||||
assert summary["gpt-4o-mini"]["cost_micros"] == 200
|
assert summary["gpt-4o-mini"]["cost_micros"] == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_reconciles_metadata_when_litellm_strips_provider_prefix():
|
||||||
|
"""Regression: LiteLLM's ``get_llm_provider`` strips the provider prefix we
|
||||||
|
add in ``to_litellm`` (``azure/gpt-5.2-chat`` → ``gpt-5.2-chat`` because
|
||||||
|
``azure`` is in ``litellm.provider_list``), so the success callback reports
|
||||||
|
the bare model. Metadata registered under the *prefixed* string must still
|
||||||
|
attach to the call so the per-model breakdown carries provider/display_name
|
||||||
|
— otherwise the UI falls back to a bare-name collision and mis-attributes an
|
||||||
|
Azure turn to an OpenRouter model (e.g. shows "OpenAI: GPT-5.2 Chat").
|
||||||
|
"""
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
acc.register_model_metadata(
|
||||||
|
model="azure/gpt-5.2-chat",
|
||||||
|
model_ref="global:-1",
|
||||||
|
model_id="gpt-5.2-chat",
|
||||||
|
display_name="Azure GPT 5.2",
|
||||||
|
provider="azure",
|
||||||
|
)
|
||||||
|
# LiteLLM callback fires with the prefix-stripped model name.
|
||||||
|
acc.add(
|
||||||
|
model="gpt-5.2-chat",
|
||||||
|
prompt_tokens=100,
|
||||||
|
completion_tokens=50,
|
||||||
|
total_tokens=150,
|
||||||
|
cost_micros=4_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = acc.per_message_summary()
|
||||||
|
entry = summary["gpt-5.2-chat"]
|
||||||
|
assert entry["provider"] == "azure"
|
||||||
|
assert entry["display_name"] == "Azure GPT 5.2"
|
||||||
|
assert entry["model_id"] == "gpt-5.2-chat"
|
||||||
|
assert entry["model_ref"] == "global:-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_prefers_exact_metadata_over_bare_alias():
|
||||||
|
"""When the callback model matches a registered key exactly, the exact
|
||||||
|
metadata wins even if another model shares the same bare name — so a turn
|
||||||
|
that legitimately used two same-named deployments stays correctly
|
||||||
|
attributed."""
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
acc.register_model_metadata(
|
||||||
|
model="azure/gpt-5.2-chat",
|
||||||
|
model_ref="global:-1",
|
||||||
|
model_id="gpt-5.2-chat",
|
||||||
|
display_name="Azure GPT 5.2",
|
||||||
|
provider="azure",
|
||||||
|
)
|
||||||
|
acc.register_model_metadata(
|
||||||
|
model="openai/gpt-5.2-chat",
|
||||||
|
model_ref="db:7",
|
||||||
|
model_id="gpt-5.2-chat",
|
||||||
|
display_name="OpenAI GPT 5.2",
|
||||||
|
provider="openai",
|
||||||
|
)
|
||||||
|
acc.add(
|
||||||
|
model="openai/gpt-5.2-chat",
|
||||||
|
prompt_tokens=10,
|
||||||
|
completion_tokens=5,
|
||||||
|
total_tokens=15,
|
||||||
|
cost_micros=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = acc.per_message_summary()["openai/gpt-5.2-chat"]
|
||||||
|
assert entry["provider"] == "openai"
|
||||||
|
assert entry["display_name"] == "OpenAI GPT 5.2"
|
||||||
|
|
||||||
|
|
||||||
def test_serialized_calls_includes_cost_micros():
|
def test_serialized_calls_includes_cost_micros():
|
||||||
"""``serialized_calls`` is what flows into the SSE ``call_details``
|
"""``serialized_calls`` is what flows into the SSE ``call_details``
|
||||||
payload; cost_micros must be present on each entry so the FE message-info
|
payload; cost_micros must be present on each entry so the FE message-info
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ export const metadata: Metadata = {
|
||||||
alternates: {
|
alternates: {
|
||||||
canonical: "https://www.surfsense.com",
|
canonical: "https://www.surfsense.com",
|
||||||
},
|
},
|
||||||
title: "SurfSense – Open Source, Privacy-Focused NotebookLM Alternative for Teams",
|
title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams",
|
||||||
description:
|
description:
|
||||||
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.",
|
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.",
|
||||||
keywords: [
|
keywords: [
|
||||||
|
|
@ -88,7 +88,7 @@ export const metadata: Metadata = {
|
||||||
"SurfSense",
|
"SurfSense",
|
||||||
],
|
],
|
||||||
openGraph: {
|
openGraph: {
|
||||||
title: "SurfSense – Open Source, Privacy-Focused NotebookLM Alternative for Teams",
|
title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams",
|
||||||
description:
|
description:
|
||||||
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude, and any AI model for free.",
|
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude, and any AI model for free.",
|
||||||
url: "https://www.surfsense.com",
|
url: "https://www.surfsense.com",
|
||||||
|
|
@ -106,7 +106,7 @@ export const metadata: Metadata = {
|
||||||
},
|
},
|
||||||
twitter: {
|
twitter: {
|
||||||
card: "summary_large_image",
|
card: "summary_large_image",
|
||||||
title: "SurfSense – Open Source, Privacy-Focused NotebookLM Alternative for Teams",
|
title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams",
|
||||||
description:
|
description:
|
||||||
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.",
|
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.",
|
||||||
creator: "@SurfSenseAI",
|
creator: "@SurfSenseAI",
|
||||||
|
|
|
||||||
|
|
@ -522,6 +522,11 @@ const Composer: FC = () => {
|
||||||
editorRef.current?.focus();
|
editorRef.current?.focus();
|
||||||
}, [isDesktop, showDocumentPopover, showPromptPicker, threadId]);
|
}, [isDesktop, showDocumentPopover, showPromptPicker, threadId]);
|
||||||
|
|
||||||
|
const handleChatModelSelected = useCallback(() => {
|
||||||
|
if (!isDesktop) return;
|
||||||
|
editorRef.current?.focus();
|
||||||
|
}, [isDesktop]);
|
||||||
|
|
||||||
// Close document picker when a sidebar slide-out panel (inbox, etc.) opens.
|
// Close document picker when a sidebar slide-out panel (inbox, etc.) opens.
|
||||||
// React only on changes to the tick — comparing against the previously-seen
|
// React only on changes to the tick — comparing against the previously-seen
|
||||||
// value preserves the one-shot semantics of the prior window-event approach
|
// value preserves the one-shot semantics of the prior window-event approach
|
||||||
|
|
@ -935,6 +940,7 @@ const Composer: FC = () => {
|
||||||
<ComposerAction
|
<ComposerAction
|
||||||
isBlockedByOtherUser={isBlockedByOtherUser}
|
isBlockedByOtherUser={isBlockedByOtherUser}
|
||||||
searchSpaceId={Number(search_space_id)}
|
searchSpaceId={Number(search_space_id)}
|
||||||
|
onChatModelSelected={handleChatModelSelected}
|
||||||
/>
|
/>
|
||||||
<ConnectorIndicator showTrigger={false} />
|
<ConnectorIndicator showTrigger={false} />
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -955,11 +961,13 @@ const Composer: FC = () => {
|
||||||
interface ComposerActionProps {
|
interface ComposerActionProps {
|
||||||
isBlockedByOtherUser?: boolean;
|
isBlockedByOtherUser?: boolean;
|
||||||
searchSpaceId: number;
|
searchSpaceId: number;
|
||||||
|
onChatModelSelected?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ComposerAction: FC<ComposerActionProps> = ({
|
const ComposerAction: FC<ComposerActionProps> = ({
|
||||||
isBlockedByOtherUser = false,
|
isBlockedByOtherUser = false,
|
||||||
searchSpaceId,
|
searchSpaceId,
|
||||||
|
onChatModelSelected,
|
||||||
}) => {
|
}) => {
|
||||||
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
|
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
|
||||||
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
|
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
|
||||||
|
|
@ -1573,6 +1581,7 @@ const ComposerAction: FC<ComposerActionProps> = ({
|
||||||
<ChatHeader
|
<ChatHeader
|
||||||
searchSpaceId={searchSpaceId}
|
searchSpaceId={searchSpaceId}
|
||||||
className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3"
|
className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3"
|
||||||
|
onChatModelSelected={onChatModelSelected}
|
||||||
/>
|
/>
|
||||||
<AuiIf condition={({ thread }) => !thread.isRunning}>
|
<AuiIf condition={({ thread }) => !thread.isRunning}>
|
||||||
<ComposerPrimitive.Send asChild disabled={isSendDisabled}>
|
<ComposerPrimitive.Send asChild disabled={isSendDisabled}>
|
||||||
|
|
|
||||||
|
|
@ -1149,6 +1149,7 @@ function AuthenticatedDocumentsSidebarBase({
|
||||||
const showCloudSkeleton =
|
const showCloudSkeleton =
|
||||||
currentFilesystemTab === "cloud" &&
|
currentFilesystemTab === "cloud" &&
|
||||||
(zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete");
|
(zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete");
|
||||||
|
const connectorButtonLabel = connectorCount > 0 ? "Manage connectors" : "Connect your connectors";
|
||||||
|
|
||||||
const cloudContent = (
|
const cloudContent = (
|
||||||
<>
|
<>
|
||||||
|
|
@ -1161,9 +1162,7 @@ function AuthenticatedDocumentsSidebarBase({
|
||||||
className="shrink-0 mx-4 mt-6 mb-2.5 h-auto select-none justify-start gap-2 bg-muted px-3 py-1.5 text-xs text-muted-foreground"
|
className="shrink-0 mx-4 mt-6 mb-2.5 h-auto select-none justify-start gap-2 bg-muted px-3 py-1.5 text-xs text-muted-foreground"
|
||||||
>
|
>
|
||||||
<Unplug className="size-4 shrink-0" />
|
<Unplug className="size-4 shrink-0" />
|
||||||
<span className="truncate">
|
<span className="truncate">{connectorButtonLabel}</span>
|
||||||
{connectorCount > 0 ? "Manage connectors" : "Connect your connectors"}
|
|
||||||
</span>
|
|
||||||
{connectorCount > 0 && (
|
{connectorCount > 0 && (
|
||||||
<span className="shrink-0 rounded-full bg-muted-foreground/15 px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground">
|
<span className="shrink-0 rounded-full bg-muted-foreground/15 px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground">
|
||||||
{connectorCount}
|
{connectorCount}
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,17 @@ import { ModelSelector } from "./model-selector";
|
||||||
interface ChatHeaderProps {
|
interface ChatHeaderProps {
|
||||||
searchSpaceId: number;
|
searchSpaceId: number;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
onChatModelSelected?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) {
|
export function ChatHeader({ searchSpaceId, className, onChatModelSelected }: ChatHeaderProps) {
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<ModelSelector searchSpaceId={searchSpaceId} className={className} />
|
<ModelSelector
|
||||||
|
searchSpaceId={searchSpaceId}
|
||||||
|
className={className}
|
||||||
|
onChatModelSelected={onChatModelSelected}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useAtom, useAtomValue } from "jotai";
|
import { useAtom, useAtomValue } from "jotai";
|
||||||
import { Check, ChevronDown, Search, Settings2 } from "lucide-react";
|
import { Check, ChevronDown, Search, SlidersHorizontal } from "lucide-react";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import type { UIEvent } from "react";
|
import type { UIEvent } from "react";
|
||||||
import { useCallback, useMemo, useState } from "react";
|
import { useCallback, useMemo, useState } from "react";
|
||||||
|
|
@ -33,6 +33,7 @@ import { providerDisplay } from "../settings/model-connections/provider-metadata
|
||||||
interface ModelSelectorProps {
|
interface ModelSelectorProps {
|
||||||
searchSpaceId: number;
|
searchSpaceId: number;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
onChatModelSelected?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatModel = ModelRead & {
|
type ChatModel = ModelRead & {
|
||||||
|
|
@ -42,6 +43,8 @@ type ChatModel = ModelRead & {
|
||||||
provider: string;
|
provider: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const AUTO_CHAT_MODEL_ID = 0;
|
||||||
|
|
||||||
function connectionLabel(connection: ConnectionRead) {
|
function connectionLabel(connection: ConnectionRead) {
|
||||||
if (connection.scope === "GLOBAL") return "Global";
|
if (connection.scope === "GLOBAL") return "Global";
|
||||||
return providerDisplay(connection.provider).name;
|
return providerDisplay(connection.provider).name;
|
||||||
|
|
@ -73,6 +76,17 @@ function modelName(model: ChatModel) {
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function filterChatModels(models: ChatModel[], search: string) {
|
||||||
|
const normalized = search.trim().toLowerCase();
|
||||||
|
if (!normalized) return models;
|
||||||
|
return models.filter((model) =>
|
||||||
|
[modelName(model), model.model_id, model.connectionLabel]
|
||||||
|
.join(" ")
|
||||||
|
.toLowerCase()
|
||||||
|
.includes(normalized)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function groupedModels(models: ChatModel[]) {
|
function groupedModels(models: ChatModel[]) {
|
||||||
return models.reduce<Record<string, ChatModel[]>>((groups, model) => {
|
return models.reduce<Record<string, ChatModel[]>>((groups, model) => {
|
||||||
const key = model.connectionLabel;
|
const key = model.connectionLabel;
|
||||||
|
|
@ -82,7 +96,11 @@ function groupedModels(models: ChatModel[]) {
|
||||||
}, {});
|
}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps) {
|
export function ModelSelector({
|
||||||
|
searchSpaceId,
|
||||||
|
className,
|
||||||
|
onChatModelSelected,
|
||||||
|
}: ModelSelectorProps) {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const isMobile = useIsMobile();
|
const isMobile = useIsMobile();
|
||||||
const [open, setOpen] = useState(false);
|
const [open, setOpen] = useState(false);
|
||||||
|
|
@ -95,25 +113,37 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
const [{ data: roles }] = useAtom(modelRolesAtom);
|
const [{ data: roles }] = useAtom(modelRolesAtom);
|
||||||
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
|
const updateRoles = useAtomValue(updateModelRolesMutationAtom);
|
||||||
|
|
||||||
const chatModels = useMemo(() => {
|
const allChatModels = useMemo(
|
||||||
const normalized = search.trim().toLowerCase();
|
() => flattenChatModels([...globalConnections, ...connections]),
|
||||||
const models = flattenChatModels([...globalConnections, ...connections]);
|
[globalConnections, connections]
|
||||||
if (!normalized) return models;
|
);
|
||||||
return models.filter((model) =>
|
|
||||||
[modelName(model), model.model_id, model.connectionLabel]
|
|
||||||
.join(" ")
|
|
||||||
.toLowerCase()
|
|
||||||
.includes(normalized)
|
|
||||||
);
|
|
||||||
}, [globalConnections, connections, search]);
|
|
||||||
|
|
||||||
const selected = chatModels.find((model) => model.id === roles?.chat_model_id);
|
const visibleChatModels = useMemo(
|
||||||
const groups = groupedModels(chatModels);
|
() => filterChatModels(allChatModels, search),
|
||||||
|
[allChatModels, search]
|
||||||
|
);
|
||||||
|
const chatModelsById = useMemo(
|
||||||
|
() => new Map(allChatModels.map((model) => [model.id, model])),
|
||||||
|
[allChatModels]
|
||||||
|
);
|
||||||
|
const selectedModelId = roles?.chat_model_id ?? AUTO_CHAT_MODEL_ID;
|
||||||
|
const selected = chatModelsById.get(selectedModelId);
|
||||||
|
const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]);
|
||||||
const loading = globalLoading || connectionsLoading;
|
const loading = globalLoading || connectionsLoading;
|
||||||
|
const hasSearchQuery = search.trim().length > 0;
|
||||||
|
|
||||||
|
function handleOpenChange(nextOpen: boolean) {
|
||||||
|
if (!nextOpen) setSearch("");
|
||||||
|
setOpen(nextOpen);
|
||||||
|
}
|
||||||
|
|
||||||
function selectModel(modelId: number) {
|
function selectModel(modelId: number) {
|
||||||
updateRoles.mutate({ chat_model_id: modelId });
|
updateRoles.mutate({ chat_model_id: modelId });
|
||||||
|
setSearch("");
|
||||||
setOpen(false);
|
setOpen(false);
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
onChatModelSelected?.();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
function manageModelConnections() {
|
function manageModelConnections() {
|
||||||
|
|
@ -152,7 +182,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground"
|
className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground"
|
||||||
onClick={() => selectModel(0)}
|
onClick={() => selectModel(AUTO_CHAT_MODEL_ID)}
|
||||||
>
|
>
|
||||||
<div className="min-w-0 flex-1">
|
<div className="min-w-0 flex-1">
|
||||||
<div className="flex min-w-0 items-center gap-2 font-medium">
|
<div className="flex min-w-0 items-center gap-2 font-medium">
|
||||||
|
|
@ -160,7 +190,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
<span className="truncate">Auto</span>
|
<span className="truncate">Auto</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{(roles?.chat_model_id ?? 0) === 0 ? <Check className="h-4 w-4" /> : null}
|
{selectedModelId === AUTO_CHAT_MODEL_ID ? <Check className="h-4 w-4" /> : null}
|
||||||
</button>
|
</button>
|
||||||
{loading ? (
|
{loading ? (
|
||||||
<div className="flex items-center justify-center py-8">
|
<div className="flex items-center justify-center py-8">
|
||||||
|
|
@ -168,7 +198,9 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
</div>
|
</div>
|
||||||
) : Object.keys(groups).length === 0 ? (
|
) : Object.keys(groups).length === 0 ? (
|
||||||
<div className="px-3 py-8 text-center text-sm text-muted-foreground">
|
<div className="px-3 py-8 text-center text-sm text-muted-foreground">
|
||||||
No enabled chat models. Add or enable models in Settings.
|
{hasSearchQuery
|
||||||
|
? "No matching chat models."
|
||||||
|
: "No enabled chat models. Add or enable models in Settings."}
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
Object.entries(groups).map(([connection, models]) => (
|
Object.entries(groups).map(([connection, models]) => (
|
||||||
|
|
@ -228,7 +260,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
className="w-full justify-start rounded-md bg-foreground/5 hover:bg-foreground/10 hover:text-foreground"
|
className="w-full justify-start rounded-md bg-foreground/5 hover:bg-foreground/10 hover:text-foreground"
|
||||||
onClick={manageModelConnections}
|
onClick={manageModelConnections}
|
||||||
>
|
>
|
||||||
<Settings2 className="mr-2 h-4 w-4" /> Manage models
|
<SlidersHorizontal className="h-4 w-4" /> Manage models
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -259,7 +291,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
|
|
||||||
if (isMobile) {
|
if (isMobile) {
|
||||||
return (
|
return (
|
||||||
<Drawer open={open} onOpenChange={setOpen}>
|
<Drawer open={open} onOpenChange={handleOpenChange}>
|
||||||
<DrawerTrigger asChild>{trigger}</DrawerTrigger>
|
<DrawerTrigger asChild>{trigger}</DrawerTrigger>
|
||||||
<DrawerContent className="max-h-[85vh]">
|
<DrawerContent className="max-h-[85vh]">
|
||||||
<DrawerHandle />
|
<DrawerHandle />
|
||||||
|
|
@ -273,9 +305,12 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover open={open} onOpenChange={setOpen}>
|
<Popover open={open} onOpenChange={handleOpenChange}>
|
||||||
<PopoverTrigger asChild>{trigger}</PopoverTrigger>
|
<PopoverTrigger asChild>{trigger}</PopoverTrigger>
|
||||||
<PopoverContent align="start" className="w-[340px] p-0">
|
<PopoverContent
|
||||||
|
align="start"
|
||||||
|
className="w-[340px] border border-popover-border bg-popover p-0 text-popover-foreground shadow-md"
|
||||||
|
>
|
||||||
{content}
|
{content}
|
||||||
</PopoverContent>
|
</PopoverContent>
|
||||||
</Popover>
|
</Popover>
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue