feat: implement token usage tracking for LLM calls with new accumulator and callback

This commit is contained in:
Anish Sarkar 2026-04-14 13:40:32 +05:30
parent 917f35eb33
commit 3cfe53fb7f
6 changed files with 223 additions and 4 deletions

View file

@ -30,6 +30,7 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
TokenUsage,
User,
get_async_session,
shielded_async_session,
@ -45,6 +46,7 @@ from app.schemas.new_chat import (
NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
TokenUsageSummary,
RegenerateRequest,
ResumeRequest,
ThreadHistoryLoadResponse,
@ -473,10 +475,13 @@ async def get_thread_messages(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
# Get messages with their authors loaded
# Get messages with their authors and token usage loaded
messages_result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.author))
.options(
selectinload(NewChatMessage.author),
selectinload(NewChatMessage.token_usage),
)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
)
@ -493,6 +498,7 @@ async def get_thread_messages(
author_id=msg.author_id,
author_display_name=msg.author.display_name if msg.author else None,
author_avatar_url=msg.author.avatar_url if msg.author else None,
token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
)
for msg in db_messages
]
@ -530,7 +536,11 @@ async def get_thread_full(
try:
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.options(
selectinload(NewChatThread.messages).selectinload(
NewChatMessage.token_usage
),
)
.filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
@ -935,6 +945,24 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Persist token usage if provided (for assistant messages)
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
token_usage_record = TokenUsage(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
usage_type="chat",
thread_id=thread_id,
message_id=db_message.id,
search_space_id=thread.search_space_id,
user_id=user.id,
)
session.add(token_usage_record)
await session.commit()
# Return the in-memory object (already has id from flush) instead of

View file

@ -34,6 +34,14 @@ class NewChatMessageCreate(NewChatMessageBase):
thread_id: int
class TokenUsageSummary(BaseModel):
    """Per-message token usage summary exposed through the chat API."""

    # Aggregate counts across all LLM calls that produced the message.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Per-model breakdown keyed by model name — presumably the shape emitted by
    # TurnTokenAccumulator.per_message_summary(); confirm against the producer.
    model_breakdown: dict | None = None

    # from_attributes lets model_validate() accept the ORM TokenUsage row directly.
    model_config = ConfigDict(from_attributes=True)
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
"""Schema for reading a message."""
@ -41,6 +49,7 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
author_id: UUID | None = None
author_display_name: str | None = None
author_avatar_url: str | None = None
token_usage: TokenUsageSummary | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -970,6 +970,7 @@ class ChatLiteLLMRouter(BaseChatModel):
messages=formatted_messages,
stop=stop,
stream=True,
stream_options={"include_usage": True},
**call_kwargs,
)
except ContextWindowExceededError as e:

View file

@ -22,10 +22,13 @@ litellm.drop_params = True
# Memory controls: prevent unbounded internal accumulation
litellm.telemetry = False
litellm.cache = None
litellm.success_callback = []
litellm.failure_callback = []
litellm.input_callback = []
from app.services.token_tracking_service import token_tracker
litellm.callbacks = [token_tracker]
logger = logging.getLogger(__name__)

View file

@ -0,0 +1,129 @@
"""
Token usage tracking via LiteLLM custom callback.
Uses a ContextVar-scoped accumulator to group all LLM calls within a single
async request/turn. The accumulated data is emitted via SSE and persisted
when the frontend calls appendMessage.
"""
from __future__ import annotations
import dataclasses
import logging
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from litellm.integrations.custom_logger import CustomLogger
logger = logging.getLogger(__name__)
@dataclass
class TokenCallRecord:
    """Token counts for one individual LLM call within a turn."""

    # Model name as reported by LiteLLM (taken from kwargs["model"] in the callback).
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
@dataclass
class TurnTokenAccumulator:
    """Collects token usage for every LLM call made during one user turn."""

    # One record per completed LLM call, in arrival order.
    calls: list[TokenCallRecord] = field(default_factory=list)

    def add(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
    ) -> None:
        """Record the usage of a single finished LLM call."""
        record = TokenCallRecord(
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        self.calls.append(record)

    def per_message_summary(self) -> dict[str, dict[str, int]]:
        """Return token counts aggregated per model name."""
        summary: dict[str, dict[str, int]] = {}
        for record in self.calls:
            if record.model not in summary:
                summary[record.model] = {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                }
            bucket = summary[record.model]
            bucket["prompt_tokens"] += record.prompt_tokens
            bucket["completion_tokens"] += record.completion_tokens
            bucket["total_tokens"] += record.total_tokens
        return summary

    @property
    def grand_total(self) -> int:
        """Total tokens consumed across every call in the turn."""
        return sum(record.total_tokens for record in self.calls)

    @property
    def total_prompt_tokens(self) -> int:
        """Prompt tokens consumed across every call in the turn."""
        return sum(record.prompt_tokens for record in self.calls)

    @property
    def total_completion_tokens(self) -> int:
        """Completion tokens produced across every call in the turn."""
        return sum(record.completion_tokens for record in self.calls)

    def serialized_calls(self) -> list[dict[str, Any]]:
        """Return each call record as a plain JSON-serializable dict."""
        return [dataclasses.asdict(record) for record in self.calls]
# Per-async-context accumulator for the current turn. ContextVar scoping keeps
# concurrent requests from seeing each other's usage; the None default means
# "no turn in progress", so the callback is a no-op outside a tracked turn.
_turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
    "_turn_accumulator", default=None
)
def start_turn() -> TurnTokenAccumulator:
    """Install a fresh accumulator in the current async context and hand it back."""
    _turn_accumulator.set(fresh := TurnTokenAccumulator())
    return fresh
def get_current_accumulator() -> TurnTokenAccumulator | None:
    """Return the accumulator for the current async context, or None if no turn is active."""
    return _turn_accumulator.get()
class TokenTrackingCallback(CustomLogger):
    """LiteLLM callback that captures token usage into the turn accumulator.

    Registered globally (via ``litellm.callbacks``); it is a no-op unless
    ``start_turn()`` has installed an accumulator in the current async context.
    """

    async def async_log_success_event(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: Any,
        end_time: Any,
    ) -> None:
        """Record prompt/completion/total token counts from one successful call."""
        acc = _turn_accumulator.get()
        if acc is None:
            # No turn in progress (e.g. a background job) — nothing to track.
            return

        # LiteLLM may deliver either a response object exposing a `.usage`
        # attribute or a plain dict carrying a "usage" key, depending on the
        # call path. The original attribute-only lookup silently dropped the
        # dict case; handle both here.
        usage = getattr(response_obj, "usage", None)
        if usage is None and isinstance(response_obj, dict):
            usage = response_obj.get("usage")
        if not usage:
            return

        acc.add(
            model=kwargs.get("model", "unknown"),
            prompt_tokens=self._usage_field(usage, "prompt_tokens"),
            completion_tokens=self._usage_field(usage, "completion_tokens"),
            total_tokens=self._usage_field(usage, "total_tokens"),
        )

    @staticmethod
    def _usage_field(usage: Any, name: str) -> int:
        """Read an int count from a usage object or dict; missing/None -> 0."""
        if isinstance(usage, dict):
            value = usage.get(name, 0)
        else:
            value = getattr(usage, name, 0)
        return value or 0


# Module-level singleton registered with litellm.callbacks at import time
# (see the litellm configuration module).
token_tracker = TokenTrackingCallback()

View file

@ -1170,6 +1170,10 @@ async def stream_new_chat(
_t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START")
from app.services.token_tracking_service import start_turn
accumulator = start_turn()
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
@ -1527,6 +1531,17 @@ async def stream_new_chat(
if stream_result.is_interrupted:
if title_task is not None and not title_task.done():
title_task.cancel()
usage_summary = accumulator.per_message_summary()
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"call_details": accumulator.serialized_calls(),
})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -1548,6 +1563,16 @@ async def stream_new_chat(
chat_id, generated_title
)
usage_summary = accumulator.per_message_summary()
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"call_details": accumulator.serialized_calls(),
})
# Fire background memory extraction if the agent didn't handle it.
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
@ -1646,6 +1671,10 @@ async def stream_resume_chat(
stream_result = StreamResult()
_t_total = time.perf_counter()
from app.services.token_tracking_service import start_turn
accumulator = start_turn()
session = async_session_maker()
try:
if user_id:
@ -1769,11 +1798,31 @@ async def stream_resume_chat(
chat_id,
)
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"call_details": accumulator.serialized_calls(),
})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
usage_summary = accumulator.per_message_summary()
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
"prompt_tokens": accumulator.total_prompt_tokens,
"completion_tokens": accumulator.total_completion_tokens,
"total_tokens": accumulator.grand_total,
"call_details": accumulator.serialized_calls(),
})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()