feat: enhance LLM response handling and token usage tracking in chat services and UI components

This commit is contained in:
Anish Sarkar 2026-04-14 15:29:02 +05:30
parent 5510c1de03
commit f21bdc0668
4 changed files with 67 additions and 37 deletions

View file

@ -820,7 +820,9 @@ class ChatLiteLLMRouter(BaseChatModel):
)
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
message = self._convert_response_to_message(
response.choices[0].message, response=response
)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@ -886,7 +888,9 @@ class ChatLiteLLMRouter(BaseChatModel):
)
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
message = self._convert_response_to_message(
response.choices[0].message, response=response
)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@ -1076,7 +1080,9 @@ class ChatLiteLLMRouter(BaseChatModel):
return result
def _convert_response_to_message(self, response_message: Any) -> AIMessage:
def _convert_response_to_message(
self, response_message: Any, response: Any = None
) -> AIMessage:
"""Convert a LiteLLM response message to a LangChain AIMessage."""
import json
@ -1099,9 +1105,22 @@ class ChatLiteLLMRouter(BaseChatModel):
tool_call["args"] = tc.function.arguments
tool_calls.append(tool_call)
extra_kwargs: dict[str, Any] = {}
if response:
usage = getattr(response, "usage", None)
if usage:
extra_kwargs["usage_metadata"] = {
"input_tokens": getattr(usage, "prompt_tokens", 0) or 0,
"output_tokens": getattr(usage, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
}
extra_kwargs["response_metadata"] = {
"model_name": getattr(response, "model", "unknown"),
}
if tool_calls:
return AIMessage(content=content, tool_calls=tool_calls)
return AIMessage(content=content)
return AIMessage(content=content, tool_calls=tool_calls, **extra_kwargs)
return AIMessage(content=content, **extra_kwargs)
def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
"""Convert a streaming delta to an AIMessageChunk."""

View file

@ -4,6 +4,10 @@ Token usage tracking via LiteLLM custom callback.
Uses a ContextVar-scoped accumulator to group all LLM calls within a single
async request/turn. The accumulated data is emitted via SSE and persisted
when the frontend calls appendMessage.
Agent LLM calls are captured automatically via the async callback.
Title-generation usage is added explicitly from the LangChain response
metadata to avoid callback-timing issues.
"""
from __future__ import annotations

View file

@ -1459,22 +1459,35 @@ async def stream_new_chat(
)
is_first_response = (assistant_count_result.scalar() or 0) == 0
title_task: asyncio.Task[str | None] | None = None
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None
if is_first_response:
async def _generate_title() -> str | None:
async def _generate_title() -> tuple[str | None, dict[str, int] | None]:
"""Return (title, usage_dict) where usage_dict has model/prompt/completion/total."""
try:
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
title_result = await title_chain.ainvoke(
{"user_query": user_query[:500]}
)
if title_result and hasattr(title_result, "content"):
raw_title = title_result.content.strip()
if raw_title and len(raw_title) <= 100:
return raw_title.strip("\"'")
usage_dict: dict[str, int] | None = None
if title_result:
um = getattr(title_result, "usage_metadata", None)
if um:
rm = getattr(title_result, "response_metadata", None) or {}
raw_model = rm.get("model_name", "unknown")
usage_dict = {
"model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model,
"prompt_tokens": um.get("input_tokens", 0),
"completion_tokens": um.get("output_tokens", 0),
"total_tokens": um.get("total_tokens", 0),
}
if hasattr(title_result, "content"):
raw_title = title_result.content.strip()
if raw_title and len(raw_title) <= 100:
return raw_title.strip("\"'"), usage_dict
return None, usage_dict
except Exception:
pass
return None
return None, None
title_task = asyncio.create_task(_generate_title())
@ -1506,7 +1519,7 @@ async def stream_new_chat(
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
generated_title = title_task.result()
generated_title, _title_usage = title_task.result()
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
@ -1532,7 +1545,6 @@ async def stream_new_chat(
if title_task is not None and not title_task.done():
title_task.cancel()
await asyncio.sleep(0.2)
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
@ -1554,7 +1566,7 @@ async def stream_new_chat(
# If the title task didn't finish during streaming, await it now
if title_task is not None and not title_emitted:
generated_title = await title_task
generated_title, _title_usage = await title_task
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
@ -1568,7 +1580,6 @@ async def stream_new_chat(
chat_id, generated_title
)
await asyncio.sleep(0.2)
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
@ -1807,7 +1818,6 @@ async def stream_resume_chat(
time.perf_counter() - _t_stream_start,
chat_id,
)
await asyncio.sleep(0.2)
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
_perf_log.info(

View file

@ -1,4 +1,5 @@
import {
ActionBarMorePrimitive,
ActionBarPrimitive,
AuiIf,
ErrorPrimitive,
@ -40,14 +41,7 @@ import {
DrawerHeader,
DrawerTitle,
} from "@/components/ui/drawer";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { DropdownMenuLabel } from "@/components/ui/dropdown-menu";
import { Button } from "@/components/ui/button";
import { useComments } from "@/hooks/use-comments";
import { useMediaQuery } from "@/hooks/use-media-query";
@ -397,14 +391,17 @@ const MessageInfoDropdown: FC = () => {
const hasUsage = usage && usage.total_tokens > 0;
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<ActionBarMorePrimitive.Root>
<ActionBarMorePrimitive.Trigger asChild>
<Button variant="ghost" size="icon" className="aui-button-icon size-6 p-1">
<MoreHorizontalIcon className="size-4" />
<span className="sr-only">More</span>
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="start" className="min-w-[180px]">
</ActionBarMorePrimitive.Trigger>
<ActionBarMorePrimitive.Content
align="start"
className="bg-muted text-popover-foreground z-50 max-h-(--radix-dropdown-menu-content-available-height) min-w-[180px] origin-(--radix-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border dark:border-neutral-700 p-1 shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2"
>
{createdAt && (
<DropdownMenuLabel className="text-xs text-muted-foreground font-normal select-none">
{formatMessageDate(createdAt)}
@ -412,27 +409,27 @@ const MessageInfoDropdown: FC = () => {
)}
{hasUsage && (
<>
<DropdownMenuSeparator />
<ActionBarMorePrimitive.Separator className="bg-border mx-2 my-1 h-px" />
{models.length > 0 ? (
models.map(([model, counts]) => (
<DropdownMenuItem key={model} className="flex-col items-start gap-0.5 cursor-default" onSelect={(e) => e.preventDefault()}>
<ActionBarMorePrimitive.Item key={model} className="focus:bg-neutral-200 dark:focus:bg-neutral-700 relative flex cursor-default flex-col items-start gap-0.5 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none" onSelect={(e) => e.preventDefault()}>
<span className="text-xs font-medium">{model}</span>
<span className="text-xs text-muted-foreground">
{counts.total_tokens.toLocaleString()} tokens
</span>
</DropdownMenuItem>
</ActionBarMorePrimitive.Item>
))
) : (
<DropdownMenuItem className="flex-col items-start gap-0.5 cursor-default" onSelect={(e) => e.preventDefault()}>
<ActionBarMorePrimitive.Item className="focus:bg-neutral-200 dark:focus:bg-neutral-700 relative flex cursor-default flex-col items-start gap-0.5 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none" onSelect={(e) => e.preventDefault()}>
<span className="text-xs text-muted-foreground">
{usage.total_tokens.toLocaleString()} tokens
</span>
</DropdownMenuItem>
</ActionBarMorePrimitive.Item>
)}
</>
)}
</DropdownMenuContent>
</DropdownMenu>
</ActionBarMorePrimitive.Content>
</ActionBarMorePrimitive.Root>
);
};