feat: implement token usage recording in chat routes and enhance title generation handling

This commit is contained in:
Anish Sarkar 2026-04-14 20:56:07 +05:30
parent 292fcb1a2c
commit f01ddf3f0a
3 changed files with 121 additions and 36 deletions

View file

@ -30,7 +30,6 @@ from app.db import (
NewChatThread, NewChatThread,
Permission, Permission,
SearchSpace, SearchSpace,
TokenUsage,
User, User,
get_async_session, get_async_session,
shielded_async_session, shielded_async_session,
@ -53,6 +52,7 @@ from app.schemas.new_chat import (
ThreadListResponse, ThreadListResponse,
TokenUsageSummary, TokenUsageSummary,
) )
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user from app.users import current_active_user
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
@ -949,19 +949,19 @@ async def append_message(
# Persist token usage if provided (for assistant messages) # Persist token usage if provided (for assistant messages)
token_usage_data = raw_body.get("token_usage") token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
token_usage_record = TokenUsage( await record_token_usage(
session,
usage_type="chat",
search_space_id=thread.search_space_id,
user_id=user.id,
prompt_tokens=token_usage_data.get("prompt_tokens", 0), prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0), completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0), total_tokens=token_usage_data.get("total_tokens", 0),
model_breakdown=token_usage_data.get("usage"), model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"), call_details=token_usage_data.get("call_details"),
usage_type="chat",
thread_id=thread_id, thread_id=thread_id,
message_id=db_message.id, message_id=db_message.id,
search_space_id=thread.search_space_id,
user_id=user.id,
) )
session.add(token_usage_record)
await session.commit() await session.commit()

View file

@ -5,9 +5,11 @@ Uses a ContextVar-scoped accumulator to group all LLM calls within a single
async request/turn. The accumulated data is emitted via SSE and persisted async request/turn. The accumulated data is emitted via SSE and persisted
when the frontend calls appendMessage. when the frontend calls appendMessage.
Agent LLM calls are captured automatically via the async callback. The module also provides ``record_token_usage``, a thin async helper that
Title-generation usage is added explicitly from the LangChain response creates a ``TokenUsage`` row for *any* usage type (chat, indexing, image
metadata to avoid callback-timing issues. generation, podcasts, …). Call sites should prefer this helper over
constructing ``TokenUsage`` manually so that logging and error handling
stay consistent.
""" """
from __future__ import annotations from __future__ import annotations
@ -17,8 +19,12 @@ import logging
from contextvars import ContextVar from contextvars import ContextVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from uuid import UUID
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import TokenUsage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -138,3 +144,54 @@ class TokenTrackingCallback(CustomLogger):
token_tracker = TokenTrackingCallback() token_tracker = TokenTrackingCallback()
# ---------------------------------------------------------------------------
# Persistence helper
# ---------------------------------------------------------------------------
async def record_token_usage(
    session: AsyncSession,
    *,
    usage_type: str,
    search_space_id: int,
    user_id: UUID,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    total_tokens: int = 0,
    model_breakdown: dict[str, Any] | None = None,
    call_details: dict[str, Any] | None = None,
    thread_id: int | None = None,
    message_id: int | None = None,
) -> TokenUsage | None:
    """Persist a single ``TokenUsage`` row.

    The record is added to *session* but NOT committed here — the caller
    owns the transaction and is expected to ``commit()`` afterwards.

    Args:
        session: Active async SQLAlchemy session the row is added to.
        usage_type: Usage category (free-form string, e.g. ``"chat"``).
        search_space_id: Search space the usage is attributed to.
        user_id: Owner of the usage record.
        prompt_tokens: Prompt-side token count.
        completion_tokens: Completion-side token count.
        total_tokens: Combined token count.
        model_breakdown: Optional per-model usage mapping.
        call_details: Optional per-call metadata.
        thread_id: Optional chat thread the usage belongs to.
        message_id: Optional message the usage belongs to.

    Returns:
        The ``TokenUsage`` record on success, ``None`` if persistence failed
        (the failure is logged but never propagated so callers don't need to
        wrap this in try/except).
    """
    try:
        record = TokenUsage(
            usage_type=usage_type,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            model_breakdown=model_breakdown,
            call_details=call_details,
            thread_id=thread_id,
            message_id=message_id,
            search_space_id=search_space_id,
            user_id=user_id,
        )
        session.add(record)
        # NOTE(review): only construction and session.add() are guarded here;
        # session.add() performs no database I/O, so real persistence errors
        # (constraint violations, connection loss) will still surface later at
        # the caller's flush/commit — confirm callers tolerate that.
        logger.debug(
            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
            usage_type, prompt_tokens, completion_tokens, total_tokens,
        )
        return record
    except Exception:
        logger.warning(
            "[TokenTracking] failed to record %s token usage", usage_type, exc_info=True,
        )
        return None

View file

@ -51,7 +51,7 @@ from app.db import (
async_session_maker, async_session_maker,
shielded_async_session, shielded_async_session,
) )
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.prompts import TITLE_GENERATION_PROMPT
from app.services.chat_session_state_service import ( from app.services.chat_session_state_service import (
clear_ai_responding, clear_ai_responding,
set_ai_responding, set_ai_responding,
@ -1460,34 +1460,58 @@ async def stream_new_chat(
) )
is_first_response = (assistant_count_result.scalar() or 0) == 0 is_first_response = (assistant_count_result.scalar() or 0) == 0
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
if is_first_response: if is_first_response:
async def _generate_title() -> tuple[str | None, dict[str, int] | None]: async def _generate_title() -> tuple[str | None, dict | None]:
"""Return (title, usage_dict) where usage_dict has model/prompt/completion/total.""" """Generate a short title via litellm.acompletion.
Returns (title, usage_dict). Usage is extracted directly from
the response object because litellm fires its async callback
via fire-and-forget ``create_task``, so the
``TokenTrackingCallback`` would run too late. We also blank
the accumulator in this child-task context so the late callback
doesn't double-count.
"""
try: try:
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm from litellm import acompletion
title_result = await title_chain.ainvoke( from app.services.llm_router_service import LLMRouterService
{"user_query": user_query[:500]} from app.services.token_tracking_service import _turn_accumulator
)
usage_dict: dict[str, int] | None = None _turn_accumulator.set(None)
if title_result:
um = getattr(title_result, "usage_metadata", None) prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
if um: messages = [{"role": "user", "content": prompt}]
rm = getattr(title_result, "response_metadata", None) or {}
raw_model = rm.get("model_name", "unknown") if getattr(llm, "model", None) == "auto":
usage_dict = { router = LLMRouterService.get_router()
"model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model, response = await router.acompletion(model="auto", messages=messages)
"prompt_tokens": um.get("input_tokens", 0), else:
"completion_tokens": um.get("output_tokens", 0), response = await acompletion(
"total_tokens": um.get("total_tokens", 0), model=llm.model,
} messages=messages,
if hasattr(title_result, "content"): api_key=getattr(llm, "api_key", None),
raw_title = title_result.content.strip() api_base=getattr(llm, "api_base", None),
if raw_title and len(raw_title) <= 100: )
return raw_title.strip("\"'"), usage_dict
return None, usage_dict usage_info = None
usage = getattr(response, "usage", None)
if usage:
raw_model = getattr(llm, "model", "") or ""
model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown")
usage_info = {
"model": model_name,
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
}
raw_title = response.choices[0].message.content.strip()
if raw_title and len(raw_title) <= 100:
return raw_title.strip("\"'"), usage_info
return None, usage_info
except Exception: except Exception:
logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
return None, None return None, None
title_task = asyncio.create_task(_generate_title()) title_task = asyncio.create_task(_generate_title())
@ -1520,7 +1544,9 @@ async def stream_new_chat(
# Inject title update mid-stream as soon as the background task finishes # Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted: if title_task is not None and title_task.done() and not title_emitted:
generated_title, _title_usage = title_task.result() generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title: if generated_title:
async with shielded_async_session() as title_session: async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute( title_thread_result = await title_session.execute(
@ -1567,7 +1593,9 @@ async def stream_new_chat(
# If the title task didn't finish during streaming, await it now # If the title task didn't finish during streaming, await it now
if title_task is not None and not title_emitted: if title_task is not None and not title_emitted:
generated_title, _title_usage = await title_task generated_title, title_usage = await title_task
if title_usage:
accumulator.add(**title_usage)
if generated_title: if generated_title:
async with shielded_async_session() as title_session: async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute( title_thread_result = await title_session.execute(