feat: improve token usage tracking and response handling in chat routes and services

This commit is contained in:
Anish Sarkar 2026-04-14 14:28:31 +05:30
parent 55099a20ac
commit 5af6005163
3 changed files with 39 additions and 4 deletions

View file

@ -46,12 +46,12 @@ from app.schemas.new_chat import (
NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
TokenUsageSummary,
RegenerateRequest,
ResumeRequest,
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
)
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
@ -965,9 +965,17 @@ async def append_message(
await session.commit()
# Return the in-memory object (already has id from flush) instead of
# doing an extra refresh() SELECT.
return db_message
# Build response manually to avoid lazy-loading the token_usage
# relationship after commit (which would trigger MissingGreenlet).
return NewChatMessageRead(
id=db_message.id,
thread_id=db_message.thread_id,
role=db_message.role,
content=db_message.content,
created_at=db_message.created_at,
author_id=db_message.author_id,
token_usage=None,
)
except HTTPException:
raise
@ -1031,6 +1039,7 @@ async def list_messages(
# Get messages
query = (
select(NewChatMessage)
.options(selectinload(NewChatMessage.token_usage))
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
.offset(skip)

View file

@ -87,6 +87,7 @@ def start_turn() -> TurnTokenAccumulator:
"""Create a fresh accumulator for the current async context and return it."""
acc = TurnTokenAccumulator()
_turn_accumulator.set(acc)
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
return acc
@ -106,10 +107,12 @@ class TokenTrackingCallback(CustomLogger):
) -> None:
acc = _turn_accumulator.get()
if acc is None:
logger.debug("[TokenTracking] async_log_success_event fired but no accumulator in context")
return
usage = getattr(response_obj, "usage", None)
if not usage:
logger.debug("[TokenTracking] async_log_success_event fired but response has no usage data")
return
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
@ -124,6 +127,10 @@ class TokenTrackingCallback(CustomLogger):
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
logger.info(
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
model, prompt_tokens, completion_tokens, total_tokens, len(acc.calls),
)
token_tracker = TokenTrackingCallback()

View file

@ -1532,7 +1532,12 @@ async def stream_new_chat(
if title_task is not None and not title_task.done():
title_task.cancel()
await asyncio.sleep(0.2)
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
len(accumulator.calls), accumulator.grand_total, usage_summary,
)
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
@ -1563,7 +1568,12 @@ async def stream_new_chat(
chat_id, generated_title
)
await asyncio.sleep(0.2)
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
len(accumulator.calls), accumulator.grand_total, usage_summary,
)
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
@ -1797,8 +1807,13 @@ async def stream_resume_chat(
time.perf_counter() - _t_stream_start,
chat_id,
)
await asyncio.sleep(0.2)
if stream_result.is_interrupted:
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
len(accumulator.calls), accumulator.grand_total, usage_summary,
)
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,
@ -1814,6 +1829,10 @@ async def stream_resume_chat(
return
usage_summary = accumulator.per_message_summary()
_perf_log.info(
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
len(accumulator.calls), accumulator.grand_total, usage_summary,
)
if usage_summary:
yield streaming_service.format_data("token-usage", {
"usage": usage_summary,