feat: implement token usage recording in chat routes and enhance title generation handling

This commit is contained in:
Anish Sarkar 2026-04-14 20:56:07 +05:30
parent 292fcb1a2c
commit f01ddf3f0a
3 changed files with 121 additions and 36 deletions

View file

@@ -51,7 +51,7 @@ from app.db import (
async_session_maker,
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@@ -1460,34 +1460,58 @@ async def stream_new_chat(
)
is_first_response = (assistant_count_result.scalar() or 0) == 0
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
if is_first_response:
async def _generate_title() -> tuple[str | None, dict | None]:
    """Generate a short chat title via litellm.acompletion.

    Returns:
        (title, usage_dict): ``title`` is the cleaned title string (or
        ``None`` if generation failed or produced nothing usable), and
        ``usage_dict`` has keys ``model`` / ``prompt_tokens`` /
        ``completion_tokens`` / ``total_tokens`` (or ``None`` when no
        usage was reported).

    Usage is extracted directly from the response object because litellm
    fires its async callback via fire-and-forget ``create_task``, so the
    ``TokenTrackingCallback`` would run too late. We also blank the
    accumulator in this child-task context so the late callback doesn't
    double-count.
    """
    try:
        from litellm import acompletion
        from app.services.llm_router_service import LLMRouterService
        from app.services.token_tracking_service import _turn_accumulator

        # Detach this task's contextvar copy of the accumulator so the
        # late-firing litellm callback can't add this call a second time.
        _turn_accumulator.set(None)

        prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
        messages = [{"role": "user", "content": prompt}]
        if getattr(llm, "model", None) == "auto":
            # "auto" delegates model selection to the project router.
            router = LLMRouterService.get_router()
            response = await router.acompletion(model="auto", messages=messages)
        else:
            response = await acompletion(
                model=llm.model,
                messages=messages,
                api_key=getattr(llm, "api_key", None),
                api_base=getattr(llm, "api_base", None),
            )

        usage_info: dict | None = None
        usage = getattr(response, "usage", None)
        if usage:
            raw_model = getattr(llm, "model", "") or ""
            # FIX: when routing via "auto" (or when llm.model is unset) the
            # configured name doesn't identify the provider model actually
            # used — previously the truthy "auto" string made the
            # response.model fallback dead code, so usage was recorded
            # under the literal name "auto". Prefer the model reported on
            # the response in that case.
            if not raw_model or raw_model == "auto":
                raw_model = getattr(response, "model", None) or raw_model or "unknown"
            # Strip a "provider/" prefix (e.g. "openai/gpt-4o" -> "gpt-4o").
            model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model
            usage_info = {
                "model": model_name,
                "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
                "total_tokens": getattr(usage, "total_tokens", 0) or 0,
            }

        # FIX: message.content may be None (empty completion); calling
        # .strip() on it raised AttributeError, which the blanket except
        # below turned into (None, None) — silently dropping the
        # usage_info already computed above.
        raw_title = (response.choices[0].message.content or "").strip()
        if raw_title and len(raw_title) <= 100:
            return raw_title.strip("\"'"), usage_info
        return None, usage_info
    except Exception:
        logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
        return None, None
title_task = asyncio.create_task(_generate_title())
@@ -1520,7 +1544,9 @@ async def stream_new_chat(
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
generated_title, _title_usage = title_task.result()
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
@@ -1567,7 +1593,9 @@ async def stream_new_chat(
# If the title task didn't finish during streaming, await it now
if title_task is not None and not title_emitted:
generated_title, _title_usage = await title_task
generated_title, title_usage = await title_task
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(