diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 58eb58f4b..34bf0c09e 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -624,6 +624,7 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; + let tokenUsageData: Record | null = null; // Add placeholder assistant message setMessages((prev) => [ @@ -821,6 +822,10 @@ export default function NewChatPage() { break; } + case "data-token-usage": + tokenUsageData = parsed.data; + break; + case "error": throw new Error(parsed.errorText || "Server error"); } @@ -828,6 +833,16 @@ export default function NewChatPage() { batcher.flush(); + if (tokenUsageData) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, metadata: { ...m.metadata, custom: { ...(m.metadata?.custom as Record ?? {}), usage: tokenUsageData } } } + : m + ) + ); + } + // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI); if (contentParts.length > 0 && !wasInterrupted) { @@ -835,6 +850,7 @@ export default function NewChatPage() { const savedMessage = await appendMessage(currentThreadId, { role: "assistant", content: finalContent, + token_usage: tokenUsageData ?? undefined, }); // Update message ID from temporary to database ID so comments work immediately @@ -965,6 +981,7 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; + let tokenUsageData: Record | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1149,6 +1166,10 @@ export default function NewChatPage() { break; } + case "data-token-usage": + tokenUsageData = parsed.data; + break; + case "error": throw new Error(parsed.errorText || "Server error"); } @@ -1156,12 +1177,23 @@ export default function NewChatPage() { batcher.flush(); + if (tokenUsageData) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, metadata: { ...m.metadata, custom: { ...(m.metadata?.custom as Record ?? {}), usage: tokenUsageData } } } + : m + ) + ); + } + const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI); if (contentParts.length > 0) { try { const savedMessage = await appendMessage(resumeThreadId, { role: "assistant", content: finalContent, + token_usage: tokenUsageData ?? undefined, }); const newMsgId = `msg-${savedMessage.id}`; setMessages((prev) => @@ -1319,6 +1351,7 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; const batcher = new FrameBatchedUpdater(); + let tokenUsageData: Record | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1428,6 +1461,10 @@ export default function NewChatPage() { break; } + case "data-token-usage": + tokenUsageData = parsed.data; + break; + case "error": throw new Error(parsed.errorText || "Server error"); } @@ -1435,6 +1472,16 @@ export default function NewChatPage() { batcher.flush(); + if (tokenUsageData) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { ...m, metadata: { ...m.metadata, custom: { ...(m.metadata?.custom as Record ?? {}), usage: tokenUsageData } } } + : m + ) + ); + } + // Persist messages after streaming completes const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI); if (contentParts.length > 0) { @@ -1459,6 +1506,7 @@ export default function NewChatPage() { const savedMessage = await appendMessage(threadId, { role: "assistant", content: finalContent, + token_usage: tokenUsageData ?? undefined, }); // Update assistant message ID to database ID diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 764acabba..25a579947 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -15,6 +15,7 @@ import { ExternalLink, Globe, MessageSquare, + MoreHorizontalIcon, RefreshCwIcon, } from "lucide-react"; import dynamic from "next/dynamic"; @@ -39,6 +40,14 @@ import { DrawerHeader, DrawerTitle, } from "@/components/ui/drawer"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Button } from "@/components/ui/button"; import { useComments } from "@/hooks/use-comments"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; @@ -366,6 +375,56 @@ export const MessageError: FC = () => { ); }; +const TokenUsageDropdown: FC = () => { + const usage = useAuiState(({ message }) => { + const custom = message?.metadata?.custom as Record | undefined; + return custom?.usage as Record | undefined; + }); + + if (!usage) return null; + + const totalTokens = (usage.total_tokens as number) ?? 0; + if (totalTokens === 0) return null; + + const modelBreakdown = (usage.usage ?? usage.model_breakdown) as + | Record + | undefined; + + const models = modelBreakdown ? Object.entries(modelBreakdown) : []; + + return ( + + + + + + + Token Usage + + {models.length > 0 ? ( + models.map(([model, counts]) => ( + e.preventDefault()}> + {model} + + {counts.total_tokens.toLocaleString()} tokens + + + )) + ) : ( + e.preventDefault()}> + + {totalTokens.toLocaleString()} tokens + + + )} + + + ); +}; + const AssistantMessageInner: FC = () => { const isMobile = !useMediaQuery("(min-width: 768px)"); @@ -427,7 +486,7 @@ const AssistantMessageInner: FC = () => { )} -
+
@@ -624,6 +683,7 @@ const AssistantActionBar: FC = () => { )} + ); }; diff --git a/surfsense_web/lib/chat/message-utils.ts b/surfsense_web/lib/chat/message-utils.ts index 7c0da03c4..6ec5bd53d 100644 --- a/surfsense_web/lib/chat/message-utils.ts +++ b/surfsense_web/lib/chat/message-utils.ts @@ -39,13 +39,16 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { content = [{ type: "text", text: String(msg.content) }]; } - const metadata = msg.author_id + const metadata = (msg.author_id || msg.token_usage) ? { custom: { - author: { - displayName: msg.author_display_name ?? null, - avatarUrl: msg.author_avatar_url ?? null, - }, + ...(msg.author_id && { + author: { + displayName: msg.author_display_name ?? null, + avatarUrl: msg.author_avatar_url ?? null, + }, + }), + ...(msg.token_usage && { usage: msg.token_usage }), }, } : undefined; diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index d54650d40..e5d77672f 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -238,6 +238,16 @@ export type SSEEvent = | { type: "data-thread-title-update"; data: { threadId: number; title: string } } | { type: "data-interrupt-request"; data: Record } | { type: "data-documents-updated"; data: Record } + | { + type: "data-token-usage"; + data: { + usage: Record; + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + call_details: Array<{ model: string; prompt_tokens: number; completion_tokens: number; total_tokens: number }>; + }; + } | { type: "error"; errorText: string }; /** diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index 08c08ba78..de9827c32 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -26,6 +26,13 @@ export interface ThreadRecord { has_comments?: boolean; } +export interface TokenUsageSummary { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + model_breakdown?: Record | null; +} + export interface MessageRecord { id: number; thread_id: number; @@ -35,6 +42,7 @@ export interface MessageRecord { author_id?: string | null; author_display_name?: string | null; author_avatar_url?: string | null; + token_usage?: TokenUsageSummary | null; } export interface ThreadListResponse { @@ -111,11 +119,11 @@ export async function getThreadMessages(threadId: number): Promise { return baseApiService.post(`/api/v1/threads/${threadId}/messages`, undefined, { body: message,