feat: implement token usage recording in chat routes and enhance title generation handling

This commit is contained in:
Anish Sarkar 2026-04-14 20:56:07 +05:30
parent 292fcb1a2c
commit f01ddf3f0a
3 changed files with 121 additions and 36 deletions

View file

@ -30,7 +30,6 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
TokenUsage,
User,
get_async_session,
shielded_async_session,
@ -53,6 +52,7 @@ from app.schemas.new_chat import (
ThreadListResponse,
TokenUsageSummary,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -949,19 +949,19 @@ async def append_message(
# Persist token usage if provided (for assistant messages)
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
token_usage_record = TokenUsage(
await record_token_usage(
session,
usage_type="chat",
search_space_id=thread.search_space_id,
user_id=user.id,
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
usage_type="chat",
thread_id=thread_id,
message_id=db_message.id,
search_space_id=thread.search_space_id,
user_id=user.id,
)
session.add(token_usage_record)
await session.commit()

View file

@ -5,9 +5,11 @@ Uses a ContextVar-scoped accumulator to group all LLM calls within a single
async request/turn. The accumulated data is emitted via SSE and persisted
when the frontend calls appendMessage.
Agent LLM calls are captured automatically via the async callback.
Title-generation usage is added explicitly from the LangChain response
metadata to avoid callback-timing issues.
The module also provides ``record_token_usage``, a thin async helper that
creates a ``TokenUsage`` row for *any* usage type (chat, indexing, image
generation, podcasts, ). Call sites should prefer this helper over
constructing ``TokenUsage`` manually so that logging and error handling
stay consistent.
"""
from __future__ import annotations
@ -17,8 +19,12 @@ import logging
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import TokenUsage
logger = logging.getLogger(__name__)
@ -138,3 +144,54 @@ class TokenTrackingCallback(CustomLogger):
token_tracker = TokenTrackingCallback()
# ---------------------------------------------------------------------------
# Persistence helper
# ---------------------------------------------------------------------------
async def record_token_usage(
session: AsyncSession,
*,
usage_type: str,
search_space_id: int,
user_id: UUID,
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
model_breakdown: dict[str, Any] | None = None,
call_details: dict[str, Any] | None = None,
thread_id: int | None = None,
message_id: int | None = None,
) -> TokenUsage | None:
"""Persist a single ``TokenUsage`` row.
Returns the record on success, ``None`` if persistence failed (the
failure is logged but never propagated so callers don't need to
wrap this in try/except).
"""
try:
record = TokenUsage(
usage_type=usage_type,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model_breakdown=model_breakdown,
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
search_space_id=search_space_id,
user_id=user_id,
)
session.add(record)
logger.debug(
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
usage_type, prompt_tokens, completion_tokens, total_tokens,
)
return record
except Exception:
logger.warning(
"[TokenTracking] failed to record %s token usage", usage_type, exc_info=True,
)
return None

View file

@ -51,7 +51,7 @@ from app.db import (
async_session_maker,
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@ -1460,34 +1460,58 @@ async def stream_new_chat(
)
is_first_response = (assistant_count_result.scalar() or 0) == 0
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
if is_first_response:
async def _generate_title() -> tuple[str | None, dict[str, int] | None]:
"""Return (title, usage_dict) where usage_dict has model/prompt/completion/total."""
async def _generate_title() -> tuple[str | None, dict | None]:
"""Generate a short title via litellm.acompletion.
Returns (title, usage_dict). Usage is extracted directly from
the response object because litellm fires its async callback
via fire-and-forget ``create_task``, so the
``TokenTrackingCallback`` would run too late. We also blank
the accumulator in this child-task context so the late callback
doesn't double-count.
"""
try:
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
title_result = await title_chain.ainvoke(
{"user_query": user_query[:500]}
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
from app.services.token_tracking_service import _turn_accumulator
_turn_accumulator.set(None)
prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
messages = [{"role": "user", "content": prompt}]
if getattr(llm, "model", None) == "auto":
router = LLMRouterService.get_router()
response = await router.acompletion(model="auto", messages=messages)
else:
response = await acompletion(
model=llm.model,
messages=messages,
api_key=getattr(llm, "api_key", None),
api_base=getattr(llm, "api_base", None),
)
usage_dict: dict[str, int] | None = None
if title_result:
um = getattr(title_result, "usage_metadata", None)
if um:
rm = getattr(title_result, "response_metadata", None) or {}
raw_model = rm.get("model_name", "unknown")
usage_dict = {
"model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model,
"prompt_tokens": um.get("input_tokens", 0),
"completion_tokens": um.get("output_tokens", 0),
"total_tokens": um.get("total_tokens", 0),
usage_info = None
usage = getattr(response, "usage", None)
if usage:
raw_model = getattr(llm, "model", "") or ""
model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown")
usage_info = {
"model": model_name,
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
}
if hasattr(title_result, "content"):
raw_title = title_result.content.strip()
raw_title = response.choices[0].message.content.strip()
if raw_title and len(raw_title) <= 100:
return raw_title.strip("\"'"), usage_dict
return None, usage_dict
return raw_title.strip("\"'"), usage_info
return None, usage_info
except Exception:
logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
return None, None
title_task = asyncio.create_task(_generate_title())
@ -1520,7 +1544,9 @@ async def stream_new_chat(
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
generated_title, _title_usage = title_task.result()
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
@ -1567,7 +1593,9 @@ async def stream_new_chat(
# If the title task didn't finish during streaming, await it now
if title_task is not None and not title_emitted:
generated_title, _title_usage = await title_task
generated_title, title_usage = await title_task
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(