feat: enhance token usage tracking in chat messages with UI integration and dropdown display

This commit is contained in:
Anish Sarkar 2026-04-14 13:40:46 +05:30
parent 3cfe53fb7f
commit 55099a20ac
5 changed files with 137 additions and 8 deletions

View file

@ -624,6 +624,7 @@ export default function NewChatPage() {
}; };
const { contentParts, toolCallIndices } = contentPartsState; const { contentParts, toolCallIndices } = contentPartsState;
let wasInterrupted = false; let wasInterrupted = false;
let tokenUsageData: Record<string, unknown> | null = null;
// Add placeholder assistant message // Add placeholder assistant message
setMessages((prev) => [ setMessages((prev) => [
@ -821,6 +822,10 @@ export default function NewChatPage() {
break; break;
} }
case "data-token-usage":
tokenUsageData = parsed.data;
break;
case "error": case "error":
throw new Error(parsed.errorText || "Server error"); throw new Error(parsed.errorText || "Server error");
} }
@ -828,6 +833,16 @@ export default function NewChatPage() {
batcher.flush(); batcher.flush();
if (tokenUsageData) {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, metadata: { ...m.metadata, custom: { ...(m.metadata?.custom as Record<string, unknown> ?? {}), usage: tokenUsageData } } }
: m
)
);
}
// Skip persistence for interrupted messages -- handleResume will persist the final version // Skip persistence for interrupted messages -- handleResume will persist the final version
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI); const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
if (contentParts.length > 0 && !wasInterrupted) { if (contentParts.length > 0 && !wasInterrupted) {
@ -835,6 +850,7 @@ export default function NewChatPage() {
const savedMessage = await appendMessage(currentThreadId, { const savedMessage = await appendMessage(currentThreadId, {
role: "assistant", role: "assistant",
content: finalContent, content: finalContent,
token_usage: tokenUsageData ?? undefined,
}); });
// Update message ID from temporary to database ID so comments work immediately // Update message ID from temporary to database ID so comments work immediately
@ -965,6 +981,7 @@ export default function NewChatPage() {
toolCallIndices: new Map(), toolCallIndices: new Map(),
}; };
const { contentParts, toolCallIndices } = contentPartsState; const { contentParts, toolCallIndices } = contentPartsState;
let tokenUsageData: Record<string, unknown> | null = null;
const existingMsg = messages.find((m) => m.id === assistantMsgId); const existingMsg = messages.find((m) => m.id === assistantMsgId);
if (existingMsg && Array.isArray(existingMsg.content)) { if (existingMsg && Array.isArray(existingMsg.content)) {
@ -1149,6 +1166,10 @@ export default function NewChatPage() {
break; break;
} }
case "data-token-usage":
tokenUsageData = parsed.data;
break;
case "error": case "error":
throw new Error(parsed.errorText || "Server error"); throw new Error(parsed.errorText || "Server error");
} }
@ -1156,12 +1177,23 @@ export default function NewChatPage() {
batcher.flush(); batcher.flush();
if (tokenUsageData) {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, metadata: { ...m.metadata, custom: { ...(m.metadata?.custom as Record<string, unknown> ?? {}), usage: tokenUsageData } } }
: m
)
);
}
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI); const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
if (contentParts.length > 0) { if (contentParts.length > 0) {
try { try {
const savedMessage = await appendMessage(resumeThreadId, { const savedMessage = await appendMessage(resumeThreadId, {
role: "assistant", role: "assistant",
content: finalContent, content: finalContent,
token_usage: tokenUsageData ?? undefined,
}); });
const newMsgId = `msg-${savedMessage.id}`; const newMsgId = `msg-${savedMessage.id}`;
setMessages((prev) => setMessages((prev) =>
@ -1319,6 +1351,7 @@ export default function NewChatPage() {
}; };
const { contentParts, toolCallIndices } = contentPartsState; const { contentParts, toolCallIndices } = contentPartsState;
const batcher = new FrameBatchedUpdater(); const batcher = new FrameBatchedUpdater();
let tokenUsageData: Record<string, unknown> | null = null;
// Add placeholder messages to UI // Add placeholder messages to UI
// Always add back the user message (with new query for edit, or original content for reload) // Always add back the user message (with new query for edit, or original content for reload)
@ -1428,6 +1461,10 @@ export default function NewChatPage() {
break; break;
} }
case "data-token-usage":
tokenUsageData = parsed.data;
break;
case "error": case "error":
throw new Error(parsed.errorText || "Server error"); throw new Error(parsed.errorText || "Server error");
} }
@ -1435,6 +1472,16 @@ export default function NewChatPage() {
batcher.flush(); batcher.flush();
if (tokenUsageData) {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? { ...m, metadata: { ...m.metadata, custom: { ...(m.metadata?.custom as Record<string, unknown> ?? {}), usage: tokenUsageData } } }
: m
)
);
}
// Persist messages after streaming completes // Persist messages after streaming completes
const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI); const finalContent = buildContentForPersistence(contentPartsState, TOOLS_WITH_UI);
if (contentParts.length > 0) { if (contentParts.length > 0) {
@ -1459,6 +1506,7 @@ export default function NewChatPage() {
const savedMessage = await appendMessage(threadId, { const savedMessage = await appendMessage(threadId, {
role: "assistant", role: "assistant",
content: finalContent, content: finalContent,
token_usage: tokenUsageData ?? undefined,
}); });
// Update assistant message ID to database ID // Update assistant message ID to database ID

View file

@ -15,6 +15,7 @@ import {
ExternalLink, ExternalLink,
Globe, Globe,
MessageSquare, MessageSquare,
MoreHorizontalIcon,
RefreshCwIcon, RefreshCwIcon,
} from "lucide-react"; } from "lucide-react";
import dynamic from "next/dynamic"; import dynamic from "next/dynamic";
@ -39,6 +40,14 @@ import {
DrawerHeader, DrawerHeader,
DrawerTitle, DrawerTitle,
} from "@/components/ui/drawer"; } from "@/components/ui/drawer";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Button } from "@/components/ui/button";
import { useComments } from "@/hooks/use-comments"; import { useComments } from "@/hooks/use-comments";
import { useMediaQuery } from "@/hooks/use-media-query"; import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform"; import { useElectronAPI } from "@/hooks/use-platform";
@ -366,6 +375,56 @@ export const MessageError: FC = () => {
); );
}; };
/**
 * Normalizes an untyped payload field into a token count.
 *
 * The usage object is read out of `metadata.custom`, which is typed as
 * `unknown`, so nothing guarantees the field is actually a number. Anything
 * that is not a finite number is reported as 0 instead of being trusted.
 */
const toTokenCount = (value: unknown): number =>
	typeof value === "number" && Number.isFinite(value) ? value : 0;

/**
 * Extracts `[model, totalTokens]` rows from a usage payload.
 *
 * The per-model map is looked up under `usage` first and `model_breakdown`
 * second. A missing, null, or non-object map yields no rows; an entry that is
 * not an object or has no numeric `total_tokens` yields a row with 0 tokens
 * rather than throwing while rendering.
 */
const readModelTokenRows = (usage: Record<string, unknown>): Array<[string, number]> => {
	const breakdown = usage.usage ?? usage.model_breakdown;
	if (typeof breakdown !== "object" || breakdown === null) return [];

	return Object.entries(breakdown as Record<string, unknown>).map(
		([model, counts]): [string, number] => {
			const total =
				typeof counts === "object" && counts !== null
					? (counts as { total_tokens?: unknown }).total_tokens
					: undefined;
			return [model, toTokenCount(total)];
		}
	);
};

/**
 * Dropdown that shows token usage for the current message.
 *
 * Reads `metadata.custom.usage` from the message in the surrounding
 * assistant-ui state. Renders nothing when there is no usage object or when
 * its `total_tokens` is missing, non-numeric, or 0. Otherwise lists one row
 * per model from the breakdown map, falling back to a single row with the
 * overall total when no per-model rows are available.
 */
const TokenUsageDropdown: FC = () => {
	const usage = useAuiState(({ message }) => {
		const custom = message?.metadata?.custom as Record<string, unknown> | undefined;
		return custom?.usage as Record<string, unknown> | undefined;
	});

	if (!usage) return null;

	const totalTokens = toTokenCount(usage.total_tokens);
	if (totalTokens === 0) return null;

	const models = readModelTokenRows(usage);

	return (
		<DropdownMenu>
			<DropdownMenuTrigger asChild>
				<Button variant="ghost" size="icon" className="aui-button-icon size-6 p-1">
					<MoreHorizontalIcon className="size-4" />
					<span className="sr-only">More</span>
				</Button>
			</DropdownMenuTrigger>
			<DropdownMenuContent align="start" className="min-w-[180px]">
				<DropdownMenuLabel className="text-xs text-muted-foreground font-normal">
					Token Usage
				</DropdownMenuLabel>
				{models.length > 0 ? (
					models.map(([model, modelTokens]) => (
						// preventDefault keeps the menu open: the rows are informational, not actions.
						<DropdownMenuItem
							key={model}
							className="flex-col items-start gap-0.5 cursor-default"
							onSelect={(e) => e.preventDefault()}
						>
							<span className="text-xs font-medium">{model}</span>
							<span className="text-xs text-muted-foreground">
								{modelTokens.toLocaleString()} tokens
							</span>
						</DropdownMenuItem>
					))
				) : (
					<DropdownMenuItem
						className="flex-col items-start gap-0.5 cursor-default"
						onSelect={(e) => e.preventDefault()}
					>
						<span className="text-xs text-muted-foreground">
							{totalTokens.toLocaleString()} tokens
						</span>
					</DropdownMenuItem>
				)}
			</DropdownMenuContent>
		</DropdownMenu>
	);
};
const AssistantMessageInner: FC = () => { const AssistantMessageInner: FC = () => {
const isMobile = !useMediaQuery("(min-width: 768px)"); const isMobile = !useMediaQuery("(min-width: 768px)");
@ -427,7 +486,7 @@ const AssistantMessageInner: FC = () => {
</div> </div>
)} )}
<div className="aui-assistant-message-footer mt-1 mb-5 ml-2 flex"> <div className="aui-assistant-message-footer mt-1 mb-5 ml-2 flex items-center gap-2">
<AssistantActionBar /> <AssistantActionBar />
</div> </div>
</CitationMetadataProvider> </CitationMetadataProvider>
@ -624,6 +683,7 @@ const AssistantActionBar: FC = () => {
<ClipboardPaste /> <ClipboardPaste />
</TooltipIconButton> </TooltipIconButton>
)} )}
<TokenUsageDropdown />
</ActionBarPrimitive.Root> </ActionBarPrimitive.Root>
); );
}; };

View file

@ -39,13 +39,16 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
content = [{ type: "text", text: String(msg.content) }]; content = [{ type: "text", text: String(msg.content) }];
} }
const metadata = msg.author_id const metadata = (msg.author_id || msg.token_usage)
? { ? {
custom: { custom: {
author: { ...(msg.author_id && {
displayName: msg.author_display_name ?? null, author: {
avatarUrl: msg.author_avatar_url ?? null, displayName: msg.author_display_name ?? null,
}, avatarUrl: msg.author_avatar_url ?? null,
},
}),
...(msg.token_usage && { usage: msg.token_usage }),
}, },
} }
: undefined; : undefined;

View file

@ -238,6 +238,16 @@ export type SSEEvent =
| { type: "data-thread-title-update"; data: { threadId: number; title: string } } | { type: "data-thread-title-update"; data: { threadId: number; title: string } }
| { type: "data-interrupt-request"; data: Record<string, unknown> } | { type: "data-interrupt-request"; data: Record<string, unknown> }
| { type: "data-documents-updated"; data: Record<string, unknown> } | { type: "data-documents-updated"; data: Record<string, unknown> }
| {
type: "data-token-usage";
data: {
usage: Record<string, { prompt_tokens: number; completion_tokens: number; total_tokens: number }>;
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
call_details: Array<{ model: string; prompt_tokens: number; completion_tokens: number; total_tokens: number }>;
};
}
| { type: "error"; errorText: string }; | { type: "error"; errorText: string };
/** /**

View file

@ -26,6 +26,13 @@ export interface ThreadRecord {
has_comments?: boolean; has_comments?: boolean;
} }
/**
 * The three token counters that appear both at the top level of a usage
 * summary and inside each breakdown entry.
 */
interface TokenCountTotals {
	prompt_tokens: number;
	completion_tokens: number;
	total_tokens: number;
}

/**
 * Token usage attached to a message record via its `token_usage` field.
 *
 * Carries the overall counters plus an optional breakdown keyed by string
 * (presumably the model name — confirm against the API). The breakdown may be
 * omitted or explicitly `null`.
 */
export interface TokenUsageSummary extends TokenCountTotals {
	model_breakdown?: Record<string, TokenCountTotals> | null;
}
export interface MessageRecord { export interface MessageRecord {
id: number; id: number;
thread_id: number; thread_id: number;
@ -35,6 +42,7 @@ export interface MessageRecord {
author_id?: string | null; author_id?: string | null;
author_display_name?: string | null; author_display_name?: string | null;
author_avatar_url?: string | null; author_avatar_url?: string | null;
token_usage?: TokenUsageSummary | null;
} }
export interface ThreadListResponse { export interface ThreadListResponse {
@ -111,11 +119,11 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory
} }
/** /**
* Append a message to a thread * Append a message to a thread.
*/ */
export async function appendMessage( export async function appendMessage(
threadId: number, threadId: number,
message: { role: "user" | "assistant" | "system"; content: unknown } message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown }
): Promise<MessageRecord> { ): Promise<MessageRecord> {
return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, { return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, {
body: message, body: message,