From 8fe9c21e765267f56d13476b5c59292718cddd38 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 03:08:35 +0530 Subject: [PATCH] feat(token-tracking): add model metadata registration and enhance token usage tracking --- .../app/services/token_tracking_service.py | 57 ++++++++++- .../chat/streaming/flows/shared/llm_bundle.py | 27 +++++- .../assistant-ui/assistant-message.tsx | 95 +++++++++++++++---- .../assistant-ui/token-usage-context.tsx | 32 +++---- 4 files changed, 167 insertions(+), 44 deletions(-) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 3f07e6f9e..8383770a6 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -40,6 +40,10 @@ class TokenCallRecord: total_tokens: int cost_micros: int = 0 call_kind: str = "chat" + model_ref: str | None = None + model_id: str | None = None + display_name: str | None = None + provider: str | None = None @dataclass @@ -47,6 +51,24 @@ class TurnTokenAccumulator: """Accumulates token usage across all LLM calls within a single user turn.""" calls: list[TokenCallRecord] = field(default_factory=list) + model_metadata: dict[str, dict[str, str | None]] = field(default_factory=dict) + + def register_model_metadata( + self, + *, + model: str, + model_ref: str | None, + model_id: str | None, + display_name: str | None, + provider: str | None, + ) -> None: + """Attach resolved model metadata for later LiteLLM callback attribution.""" + self.model_metadata[model] = { + "model_ref": model_ref, + "model_id": model_id, + "display_name": display_name, + "provider": provider, + } def add( self, @@ -57,9 +79,14 @@ class TurnTokenAccumulator: cost_micros: int = 0, call_kind: str = "chat", ) -> None: + metadata = self.model_metadata.get(model, {}) self.calls.append( TokenCallRecord( model=model, + model_ref=metadata.get("model_ref"), + model_id=metadata.get("model_id"), + display_name=metadata.get("display_name"), + provider=metadata.get("provider"), prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, @@ -68,13 +95,18 @@ class TurnTokenAccumulator: ) ) - def per_message_summary(self) -> dict[str, dict[str, int]]: + def per_message_summary(self) -> dict[str, dict[str, Any]]: """Return token counts (and cost) grouped by model name.""" - by_model: dict[str, dict[str, int]] = {} + by_model: dict[str, dict[str, Any]] = {} for c in self.calls: entry = by_model.setdefault( c.model, { + "model": c.model, + "model_ref": c.model_ref, + "model_id": c.model_id, + "display_name": c.display_name, + "provider": c.provider, "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, @@ -142,6 +174,27 @@ def get_current_accumulator() -> TurnTokenAccumulator | None: return _turn_accumulator.get() +def register_model_usage_metadata( + *, + model: str, + model_ref: str | None, + model_id: str | None, + display_name: str | None, + provider: str | None, +) -> None: + """Register resolved model metadata with the current turn, if one exists.""" + acc = _turn_accumulator.get() + if acc is None: + return + acc.register_model_metadata( + model=model, + model_ref=model_ref, + model_id=model_id, + display_name=display_name, + provider=provider, + ) + + @asynccontextmanager async def scoped_turn() -> AsyncIterator[TurnTokenAccumulator]: """Async context manager that scopes a fresh ``TurnTokenAccumulator`` diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index cfd50950e..b318ee3ce 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -24,6 +24,7 @@ from app.config import config from app.db import Model, SearchSpace from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm +from app.services.token_tracking_service import register_model_usage_metadata def _agent_config_from_resolved( @@ -104,10 +105,19 @@ async def load_llm_bundle( f"Failed to load chat model with id {config_id}", ) model_string, litellm_kwargs = to_litellm(model.connection, model.model_id) + display_name = model.display_name or model.model_id + provider = model.connection.provider or "" + register_model_usage_metadata( + model=model_string, + model_ref=f"db:{model.id}", + model_id=model.model_id, + display_name=display_name, + provider=provider, + ) agent_config = _agent_config_from_resolved( config_id=config_id, - config_name=model.display_name or model.model_id, - provider=model.connection.provider or "", + config_name=display_name, + provider=provider, model_name=model.model_id, api_key=model.connection.api_key, api_base=model.connection.base_url, @@ -135,10 +145,19 @@ async def load_llm_bundle( if not global_connection: return None, None, f"Failed to load global connection for model {config_id}" model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"]) + display_name = global_model.get("display_name") or global_model.get("model_id") + provider = global_connection.get("provider") or "" + register_model_usage_metadata( + model=model_string, + model_ref=f"global:{config_id}", + model_id=global_model["model_id"], + display_name=display_name, + provider=provider, + ) agent_config = _agent_config_from_resolved( config_id=config_id, - config_name=global_model.get("display_name") or global_model.get("model_id"), - provider=global_connection.get("provider") or "", + config_name=display_name, + provider=provider, model_name=global_model["model_id"], api_key=global_connection.get("api_key"), api_base=global_connection.get("base_url"), diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index d084ac0fd..59006b26e 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -26,9 +26,9 @@ import type { FC } from "react"; import { useEffect, useMemo, useRef, useState } from "react"; import { commentsEnabledAtom, targetCommentIdAtom } from "@/atoms/chat/current-thread.atom"; import { - globalNewLLMConfigsAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { CitationMetadataProvider, @@ -37,7 +37,10 @@ import { import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; -import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; +import { + type TokenUsageModelBreakdown, + useTokenUsage, +} from "@/components/assistant-ui/token-usage-context"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { CommentPanelContainer } from "@/components/chat-comments/comment-panel-container/comment-panel-container"; import { CommentSheet } from "@/components/chat-comments/comment-sheet/comment-sheet"; @@ -268,29 +271,81 @@ function formatTurnCost(micros: number): string { return "$0"; } +function normalizeUsageModelKey(modelKey: string): string { + return modelKey.trim().replace(/^~/, ""); +} + +function bareModelKey(modelKey: string): string { + const normalized = normalizeUsageModelKey(modelKey); + const parts = normalized.split("/"); + return parts[parts.length - 1] || normalized; +} + +function inferProviderFromModelKey(modelKey: string) { + const normalized = normalizeUsageModelKey(modelKey); + const [provider] = normalized.split("/"); + return provider && provider !== normalized ? provider : null; +} + +function titleCaseModelPart(part: string) { + if (!part) return ""; + const upper = part.toUpperCase(); + if (/^\d+(\.\d+)?[BKM]$/.test(upper)) return upper; + if (["gpt", "oai", "api", "llm", "vlm"].includes(part.toLowerCase())) return upper; + return part.charAt(0).toUpperCase() + part.slice(1); +} + +function humanizeModelId(modelKey: string): string { + const bare = bareModelKey(modelKey) + .replace(/:latest$/i, "") + .replace(/[-_]+/g, " ") + .trim(); + if (!bare) return modelKey; + return bare.split(/\s+/).map(titleCaseModelPart).join(" "); +} + const MessageInfoDropdown: FC<{ chatTurnId: string | null | undefined }> = ({ chatTurnId }) => { const messageId = useAuiState(({ message }) => message?.id); const createdAt = useAuiState(({ message }) => message?.createdAt); const usage = useTokenUsage(messageId); - const { data: localConfigs } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); + const { data: globalConnections = [] } = useAtomValue(globalModelConnectionsAtom); + const { data: localConnections = [] } = useAtomValue(modelConnectionsAtom); - const configByModel = useMemo(() => { - const map = new Map(); - for (const c of [...(globalConfigs ?? []), ...(localConfigs ?? [])]) { - map.set(c.model_name, { name: c.name, provider: c.provider }); + const modelConnectionByKey = useMemo(() => { + const map = new Map(); + for (const connection of [...globalConnections, ...localConnections]) { + for (const model of connection.models) { + const normalizedModelId = normalizeUsageModelKey(model.model_id); + const entry = { + name: model.display_name || model.model_id, + provider: connection.provider, + modelId: model.model_id, + }; + map.set(model.model_id, entry); + map.set(normalizedModelId, entry); + map.set(bareModelKey(model.model_id), entry); + } } return map; - }, [localConfigs, globalConfigs]); + }, [globalConnections, localConnections]); - const resolveModel = (modelKey: string) => { - const parts = modelKey.split("/"); - const bare = parts[parts.length - 1] ?? modelKey; - const config = configByModel.get(modelKey) ?? configByModel.get(bare); - return config - ? { name: config.name, icon: getProviderIcon(config.provider, { className: "size-3.5" }) } - : { name: modelKey, icon: null }; + const resolveModel = (modelKey: string, counts: TokenUsageModelBreakdown) => { + const normalizedKey = normalizeUsageModelKey(counts.model_id || counts.model || modelKey); + const connectionModel = + modelConnectionByKey.get(modelKey) ?? + modelConnectionByKey.get(normalizeUsageModelKey(modelKey)) ?? + modelConnectionByKey.get(normalizedKey) ?? + modelConnectionByKey.get(bareModelKey(normalizedKey)); + const provider = + counts.provider || connectionModel?.provider || inferProviderFromModelKey(normalizedKey); + const modelId = counts.model_id || connectionModel?.modelId || modelKey; + const name = counts.display_name || connectionModel?.name || humanizeModelId(modelId); + return { + name, + modelId, + icon: provider ? getProviderIcon(provider, { className: "size-3.5 shrink-0" }) : null, + }; }; const modelBreakdown = usage ? (usage.usage ?? usage.model_breakdown) : undefined; @@ -319,12 +374,12 @@ const MessageInfoDropdown: FC<{ chatTurnId: string | null | undefined }> = ({ ch {models.length > 0 ? ( models.map(([model, counts]) => { - const { name, icon } = resolveModel(model); + const { name, icon } = resolveModel(model, counts); const costMicros = counts.cost_micros; return ( e.preventDefault()} > diff --git a/surfsense_web/components/assistant-ui/token-usage-context.tsx b/surfsense_web/components/assistant-ui/token-usage-context.tsx index dd80bcac3..8db8c2b50 100644 --- a/surfsense_web/components/assistant-ui/token-usage-context.tsx +++ b/surfsense_web/components/assistant-ui/token-usage-context.tsx @@ -9,6 +9,18 @@ import { useSyncExternalStore, } from "react"; +export interface TokenUsageModelBreakdown { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cost_micros?: number; + model?: string | null; + model_ref?: string | null; + model_id?: string | null; + display_name?: string | null; + provider?: string | null; +} + export interface TokenUsageData { prompt_tokens: number; completion_tokens: number; @@ -20,24 +32,8 @@ export interface TokenUsageData { * before the migration won't have it. */ cost_micros?: number; - usage?: Record< - string, - { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - cost_micros?: number; - } - >; - model_breakdown?: Record< - string, - { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - cost_micros?: number; - } - >; + usage?: Record; + model_breakdown?: Record; } type Listener = () => void;