Merge pull request #1491 from AnishSarkar22/feat/unified-model-connections

feat: Fix model attribution for prefix-stripped token usage callbacks
This commit is contained in:
Rohan Verma 2026-06-14 17:50:48 -07:00 committed by GitHub
commit 69bdcf5946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 190 additions and 32 deletions

View file

@ -32,6 +32,23 @@ from app.db import TokenUsage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _bare_model_name(model: str) -> str:
"""Return a model identifier with any provider routing prefix stripped.
LiteLLM's ``get_llm_provider`` consumes the provider prefix we add in
``to_litellm`` (e.g. ``azure/gpt-5.2-chat`` ``gpt-5.2-chat`` because
``azure`` is in ``litellm.provider_list``). The token-tracking success
callback therefore reports ``kwargs["model"]`` *without* that prefix,
while model metadata is registered under the *prefixed* string. Normalising
both sides to the last path segment lets the two reconcile so the per-model
breakdown carries provider/display_name and the UI attributes the turn to
the correct connection instead of falling back to a bare-name collision.
"""
if not model:
return model
return model.split("/")[-1]
@dataclass @dataclass
class TokenCallRecord: class TokenCallRecord:
model: str model: str
@ -52,6 +69,12 @@ class TurnTokenAccumulator:
calls: list[TokenCallRecord] = field(default_factory=list) calls: list[TokenCallRecord] = field(default_factory=list)
model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict) model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict)
# Secondary index keyed by the bare model name (provider prefix stripped) so
# the LiteLLM callback — which never sees our routing prefix — can still
# reconcile its ``kwargs["model"]`` back to the registered metadata.
model_metadata_by_bare: dict[str, dict[str, str | None]] = field(
default_factory=dict
)
def register_model_metadata( def register_model_metadata(
self, self,
@ -63,12 +86,28 @@ class TurnTokenAccumulator:
provider: str | None, provider: str | None,
) -> None: ) -> None:
"""Attach resolved model metadata for later LiteLLM callback attribution.""" """Attach resolved model metadata for later LiteLLM callback attribution."""
self.model_metadata[model] = { metadata = {
"model_ref": model_ref, "model_ref": model_ref,
"model_id": model_id, "model_id": model_id,
"display_name": display_name, "display_name": display_name,
"provider": provider, "provider": provider,
} }
self.model_metadata[model] = metadata
# Index every reconcilable alias: the prefixed string's bare form and
# the resolved ``model_id`` (which for some providers is itself the bare
# deployment LiteLLM reports). Exact lookups always take precedence.
self.model_metadata_by_bare[_bare_model_name(model)] = metadata
if model_id:
self.model_metadata_by_bare.setdefault(_bare_model_name(model_id), metadata)
def _lookup_metadata(self, model: str) -> dict[str, str | None]:
"""Resolve registered metadata for a callback model, tolerating the
provider-prefix stripping LiteLLM applies before the success callback
fires (see :func:`_bare_model_name`)."""
exact = self.model_metadata.get(model)
if exact is not None:
return exact
return self.model_metadata_by_bare.get(_bare_model_name(model), {})
def add( def add(
self, self,
@ -79,7 +118,7 @@ class TurnTokenAccumulator:
cost_micros: int = 0, cost_micros: int = 0,
call_kind: str = "chat", call_kind: str = "chat",
) -> None: ) -> None:
metadata = self.model_metadata.get(model, {}) metadata = self._lookup_metadata(model)
self.calls.append( self.calls.append(
TokenCallRecord( TokenCallRecord(
model=model, model=model,

View file

@ -112,6 +112,77 @@ def test_per_message_summary_groups_cost_by_model():
assert summary["gpt-4o-mini"]["cost_micros"] == 200 assert summary["gpt-4o-mini"]["cost_micros"] == 200
def test_add_reconciles_metadata_when_litellm_strips_provider_prefix():
"""Regression: LiteLLM's ``get_llm_provider`` strips the provider prefix we
add in ``to_litellm`` (``azure/gpt-5.2-chat`` ``gpt-5.2-chat`` because
``azure`` is in ``litellm.provider_list``), so the success callback reports
the bare model. Metadata registered under the *prefixed* string must still
attach to the call so the per-model breakdown carries provider/display_name
otherwise the UI falls back to a bare-name collision and mis-attributes an
Azure turn to an OpenRouter model (e.g. shows "OpenAI: GPT-5.2 Chat").
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.register_model_metadata(
model="azure/gpt-5.2-chat",
model_ref="global:-1",
model_id="gpt-5.2-chat",
display_name="Azure GPT 5.2",
provider="azure",
)
# LiteLLM callback fires with the prefix-stripped model name.
acc.add(
model="gpt-5.2-chat",
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_micros=4_000,
)
summary = acc.per_message_summary()
entry = summary["gpt-5.2-chat"]
assert entry["provider"] == "azure"
assert entry["display_name"] == "Azure GPT 5.2"
assert entry["model_id"] == "gpt-5.2-chat"
assert entry["model_ref"] == "global:-1"
def test_add_prefers_exact_metadata_over_bare_alias():
"""When the callback model matches a registered key exactly, the exact
metadata wins even if another model shares the same bare name so a turn
that legitimately used two same-named deployments stays correctly
attributed."""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.register_model_metadata(
model="azure/gpt-5.2-chat",
model_ref="global:-1",
model_id="gpt-5.2-chat",
display_name="Azure GPT 5.2",
provider="azure",
)
acc.register_model_metadata(
model="openai/gpt-5.2-chat",
model_ref="db:7",
model_id="gpt-5.2-chat",
display_name="OpenAI GPT 5.2",
provider="openai",
)
acc.add(
model="openai/gpt-5.2-chat",
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
cost_micros=100,
)
entry = acc.per_message_summary()["openai/gpt-5.2-chat"]
assert entry["provider"] == "openai"
assert entry["display_name"] == "OpenAI GPT 5.2"
def test_serialized_calls_includes_cost_micros(): def test_serialized_calls_includes_cost_micros():
"""``serialized_calls`` is what flows into the SSE ``call_details`` """``serialized_calls`` is what flows into the SSE ``call_details``
payload; cost_micros must be present on each entry so the FE message-info payload; cost_micros must be present on each entry so the FE message-info

View file

@ -46,7 +46,7 @@ export const metadata: Metadata = {
alternates: { alternates: {
canonical: "https://www.surfsense.com", canonical: "https://www.surfsense.com",
}, },
title: "SurfSense Open Source, Privacy-Focused NotebookLM Alternative for Teams", title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams",
description: description:
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.", "Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.",
keywords: [ keywords: [
@ -88,7 +88,7 @@ export const metadata: Metadata = {
"SurfSense", "SurfSense",
], ],
openGraph: { openGraph: {
title: "SurfSense Open Source, Privacy-Focused NotebookLM Alternative for Teams", title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams",
description: description:
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude, and any AI model for free.", "Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude, and any AI model for free.",
url: "https://www.surfsense.com", url: "https://www.surfsense.com",
@ -106,7 +106,7 @@ export const metadata: Metadata = {
}, },
twitter: { twitter: {
card: "summary_large_image", card: "summary_large_image",
title: "SurfSense Open Source, Privacy-Focused NotebookLM Alternative for Teams", title: "SurfSense - Open Source, Privacy-Focused NotebookLM Alternative for Teams",
description: description:
"Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.", "Open source NotebookLM alternative for teams with no data limits. Use ChatGPT, Claude AI, and any AI model for free.",
creator: "@SurfSenseAI", creator: "@SurfSenseAI",

View file

@ -522,6 +522,11 @@ const Composer: FC = () => {
editorRef.current?.focus(); editorRef.current?.focus();
}, [isDesktop, showDocumentPopover, showPromptPicker, threadId]); }, [isDesktop, showDocumentPopover, showPromptPicker, threadId]);
const handleChatModelSelected = useCallback(() => {
if (!isDesktop) return;
editorRef.current?.focus();
}, [isDesktop]);
// Close document picker when a sidebar slide-out panel (inbox, etc.) opens. // Close document picker when a sidebar slide-out panel (inbox, etc.) opens.
// React only on changes to the tick — comparing against the previously-seen // React only on changes to the tick — comparing against the previously-seen
// value preserves the one-shot semantics of the prior window-event approach // value preserves the one-shot semantics of the prior window-event approach
@ -935,6 +940,7 @@ const Composer: FC = () => {
<ComposerAction <ComposerAction
isBlockedByOtherUser={isBlockedByOtherUser} isBlockedByOtherUser={isBlockedByOtherUser}
searchSpaceId={Number(search_space_id)} searchSpaceId={Number(search_space_id)}
onChatModelSelected={handleChatModelSelected}
/> />
<ConnectorIndicator showTrigger={false} /> <ConnectorIndicator showTrigger={false} />
</div> </div>
@ -955,11 +961,13 @@ const Composer: FC = () => {
interface ComposerActionProps { interface ComposerActionProps {
isBlockedByOtherUser?: boolean; isBlockedByOtherUser?: boolean;
searchSpaceId: number; searchSpaceId: number;
onChatModelSelected?: () => void;
} }
const ComposerAction: FC<ComposerActionProps> = ({ const ComposerAction: FC<ComposerActionProps> = ({
isBlockedByOtherUser = false, isBlockedByOtherUser = false,
searchSpaceId, searchSpaceId,
onChatModelSelected,
}) => { }) => {
const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); const mentionedDocuments = useAtomValue(mentionedDocumentsAtom);
const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom); const setConnectorDialogOpen = useSetAtom(connectorDialogOpenAtom);
@ -1573,6 +1581,7 @@ const ComposerAction: FC<ComposerActionProps> = ({
<ChatHeader <ChatHeader
searchSpaceId={searchSpaceId} searchSpaceId={searchSpaceId}
className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3" className="h-9 max-w-[44vw] px-2 sm:max-w-[220px] sm:px-3"
onChatModelSelected={onChatModelSelected}
/> />
<AuiIf condition={({ thread }) => !thread.isRunning}> <AuiIf condition={({ thread }) => !thread.isRunning}>
<ComposerPrimitive.Send asChild disabled={isSendDisabled}> <ComposerPrimitive.Send asChild disabled={isSendDisabled}>

View file

@ -1149,6 +1149,7 @@ function AuthenticatedDocumentsSidebarBase({
const showCloudSkeleton = const showCloudSkeleton =
currentFilesystemTab === "cloud" && currentFilesystemTab === "cloud" &&
(zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete"); (zeroFoldersResult.type !== "complete" || zeroAllDocsResult.type !== "complete");
const connectorButtonLabel = connectorCount > 0 ? "Manage connectors" : "Connect your connectors";
const cloudContent = ( const cloudContent = (
<> <>
@ -1161,9 +1162,7 @@ function AuthenticatedDocumentsSidebarBase({
className="shrink-0 mx-4 mt-6 mb-2.5 h-auto select-none justify-start gap-2 bg-muted px-3 py-1.5 text-xs text-muted-foreground" className="shrink-0 mx-4 mt-6 mb-2.5 h-auto select-none justify-start gap-2 bg-muted px-3 py-1.5 text-xs text-muted-foreground"
> >
<Unplug className="size-4 shrink-0" /> <Unplug className="size-4 shrink-0" />
<span className="truncate"> <span className="truncate">{connectorButtonLabel}</span>
{connectorCount > 0 ? "Manage connectors" : "Connect your connectors"}
</span>
{connectorCount > 0 && ( {connectorCount > 0 && (
<span className="shrink-0 rounded-full bg-muted-foreground/15 px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground"> <span className="shrink-0 rounded-full bg-muted-foreground/15 px-1.5 py-0.5 text-[10px] font-medium text-muted-foreground">
{connectorCount} {connectorCount}

View file

@ -5,12 +5,17 @@ import { ModelSelector } from "./model-selector";
interface ChatHeaderProps { interface ChatHeaderProps {
searchSpaceId: number; searchSpaceId: number;
className?: string; className?: string;
onChatModelSelected?: () => void;
} }
export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { export function ChatHeader({ searchSpaceId, className, onChatModelSelected }: ChatHeaderProps) {
return ( return (
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<ModelSelector searchSpaceId={searchSpaceId} className={className} /> <ModelSelector
searchSpaceId={searchSpaceId}
className={className}
onChatModelSelected={onChatModelSelected}
/>
</div> </div>
); );
} }

View file

@ -1,7 +1,7 @@
"use client"; "use client";
import { useAtom, useAtomValue } from "jotai"; import { useAtom, useAtomValue } from "jotai";
import { Check, ChevronDown, Search, Settings2 } from "lucide-react"; import { Check, ChevronDown, Search, SlidersHorizontal } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import type { UIEvent } from "react"; import type { UIEvent } from "react";
import { useCallback, useMemo, useState } from "react"; import { useCallback, useMemo, useState } from "react";
@ -33,6 +33,7 @@ import { providerDisplay } from "../settings/model-connections/provider-metadata
interface ModelSelectorProps { interface ModelSelectorProps {
searchSpaceId: number; searchSpaceId: number;
className?: string; className?: string;
onChatModelSelected?: () => void;
} }
type ChatModel = ModelRead & { type ChatModel = ModelRead & {
@ -42,6 +43,8 @@ type ChatModel = ModelRead & {
provider: string; provider: string;
}; };
const AUTO_CHAT_MODEL_ID = 0;
function connectionLabel(connection: ConnectionRead) { function connectionLabel(connection: ConnectionRead) {
if (connection.scope === "GLOBAL") return "Global"; if (connection.scope === "GLOBAL") return "Global";
return providerDisplay(connection.provider).name; return providerDisplay(connection.provider).name;
@ -73,6 +76,17 @@ function modelName(model: ChatModel) {
return name; return name;
} }
function filterChatModels(models: ChatModel[], search: string) {
const normalized = search.trim().toLowerCase();
if (!normalized) return models;
return models.filter((model) =>
[modelName(model), model.model_id, model.connectionLabel]
.join(" ")
.toLowerCase()
.includes(normalized)
);
}
function groupedModels(models: ChatModel[]) { function groupedModels(models: ChatModel[]) {
return models.reduce<Record<string, ChatModel[]>>((groups, model) => { return models.reduce<Record<string, ChatModel[]>>((groups, model) => {
const key = model.connectionLabel; const key = model.connectionLabel;
@ -82,7 +96,11 @@ function groupedModels(models: ChatModel[]) {
}, {}); }, {});
} }
export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps) { export function ModelSelector({
searchSpaceId,
className,
onChatModelSelected,
}: ModelSelectorProps) {
const router = useRouter(); const router = useRouter();
const isMobile = useIsMobile(); const isMobile = useIsMobile();
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
@ -95,25 +113,37 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
const [{ data: roles }] = useAtom(modelRolesAtom); const [{ data: roles }] = useAtom(modelRolesAtom);
const updateRoles = useAtomValue(updateModelRolesMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom);
const chatModels = useMemo(() => { const allChatModels = useMemo(
const normalized = search.trim().toLowerCase(); () => flattenChatModels([...globalConnections, ...connections]),
const models = flattenChatModels([...globalConnections, ...connections]); [globalConnections, connections]
if (!normalized) return models; );
return models.filter((model) =>
[modelName(model), model.model_id, model.connectionLabel]
.join(" ")
.toLowerCase()
.includes(normalized)
);
}, [globalConnections, connections, search]);
const selected = chatModels.find((model) => model.id === roles?.chat_model_id); const visibleChatModels = useMemo(
const groups = groupedModels(chatModels); () => filterChatModels(allChatModels, search),
[allChatModels, search]
);
const chatModelsById = useMemo(
() => new Map(allChatModels.map((model) => [model.id, model])),
[allChatModels]
);
const selectedModelId = roles?.chat_model_id ?? AUTO_CHAT_MODEL_ID;
const selected = chatModelsById.get(selectedModelId);
const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]);
const loading = globalLoading || connectionsLoading; const loading = globalLoading || connectionsLoading;
const hasSearchQuery = search.trim().length > 0;
function handleOpenChange(nextOpen: boolean) {
if (!nextOpen) setSearch("");
setOpen(nextOpen);
}
function selectModel(modelId: number) { function selectModel(modelId: number) {
updateRoles.mutate({ chat_model_id: modelId }); updateRoles.mutate({ chat_model_id: modelId });
setSearch("");
setOpen(false); setOpen(false);
requestAnimationFrame(() => {
onChatModelSelected?.();
});
} }
function manageModelConnections() { function manageModelConnections() {
@ -152,7 +182,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
<button <button
type="button" type="button"
className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground" className="flex w-full items-center justify-between rounded-md px-3 py-2 text-left transition-colors hover:bg-accent hover:text-accent-foreground"
onClick={() => selectModel(0)} onClick={() => selectModel(AUTO_CHAT_MODEL_ID)}
> >
<div className="min-w-0 flex-1"> <div className="min-w-0 flex-1">
<div className="flex min-w-0 items-center gap-2 font-medium"> <div className="flex min-w-0 items-center gap-2 font-medium">
@ -160,7 +190,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
<span className="truncate">Auto</span> <span className="truncate">Auto</span>
</div> </div>
</div> </div>
{(roles?.chat_model_id ?? 0) === 0 ? <Check className="h-4 w-4" /> : null} {selectedModelId === AUTO_CHAT_MODEL_ID ? <Check className="h-4 w-4" /> : null}
</button> </button>
{loading ? ( {loading ? (
<div className="flex items-center justify-center py-8"> <div className="flex items-center justify-center py-8">
@ -168,7 +198,9 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
</div> </div>
) : Object.keys(groups).length === 0 ? ( ) : Object.keys(groups).length === 0 ? (
<div className="px-3 py-8 text-center text-sm text-muted-foreground"> <div className="px-3 py-8 text-center text-sm text-muted-foreground">
No enabled chat models. Add or enable models in Settings. {hasSearchQuery
? "No matching chat models."
: "No enabled chat models. Add or enable models in Settings."}
</div> </div>
) : ( ) : (
Object.entries(groups).map(([connection, models]) => ( Object.entries(groups).map(([connection, models]) => (
@ -228,7 +260,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
className="w-full justify-start rounded-md bg-foreground/5 hover:bg-foreground/10 hover:text-foreground" className="w-full justify-start rounded-md bg-foreground/5 hover:bg-foreground/10 hover:text-foreground"
onClick={manageModelConnections} onClick={manageModelConnections}
> >
<Settings2 className="mr-2 h-4 w-4" /> Manage models <SlidersHorizontal className="h-4 w-4" /> Manage models
</Button> </Button>
</div> </div>
</div> </div>
@ -259,7 +291,7 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
if (isMobile) { if (isMobile) {
return ( return (
<Drawer open={open} onOpenChange={setOpen}> <Drawer open={open} onOpenChange={handleOpenChange}>
<DrawerTrigger asChild>{trigger}</DrawerTrigger> <DrawerTrigger asChild>{trigger}</DrawerTrigger>
<DrawerContent className="max-h-[85vh]"> <DrawerContent className="max-h-[85vh]">
<DrawerHandle /> <DrawerHandle />
@ -273,9 +305,12 @@ export function ModelSelector({ searchSpaceId, className }: ModelSelectorProps)
} }
return ( return (
<Popover open={open} onOpenChange={setOpen}> <Popover open={open} onOpenChange={handleOpenChange}>
<PopoverTrigger asChild>{trigger}</PopoverTrigger> <PopoverTrigger asChild>{trigger}</PopoverTrigger>
<PopoverContent align="start" className="w-[340px] p-0"> <PopoverContent
align="start"
className="w-[340px] border border-popover-border bg-popover p-0 text-popover-foreground shadow-md"
>
{content} {content}
</PopoverContent> </PopoverContent>
</Popover> </Popover>