From f01ddf3f0a153cdc5685e1739366a58737193add Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 14 Apr 2026 20:56:07 +0530 Subject: [PATCH] feat: implement token usage recording in chat routes and enhance title generation handling --- .../app/routes/new_chat_routes.py | 12 +-- .../app/services/token_tracking_service.py | 63 +++++++++++++- .../app/tasks/chat/stream_new_chat.py | 82 +++++++++++++------ 3 files changed, 121 insertions(+), 36 deletions(-) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index fe79c7c06..55302b873 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -30,7 +30,6 @@ from app.db import ( NewChatThread, Permission, SearchSpace, - TokenUsage, User, get_async_session, shielded_async_session, @@ -53,6 +52,7 @@ from app.schemas.new_chat import ( ThreadListResponse, TokenUsageSummary, ) +from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.users import current_active_user from app.utils.rbac import check_permission @@ -949,19 +949,19 @@ async def append_message( # Persist token usage if provided (for assistant messages) token_usage_data = raw_body.get("token_usage") if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: - token_usage_record = TokenUsage( + await record_token_usage( + session, + usage_type="chat", + search_space_id=thread.search_space_id, + user_id=user.id, prompt_tokens=token_usage_data.get("prompt_tokens", 0), completion_tokens=token_usage_data.get("completion_tokens", 0), total_tokens=token_usage_data.get("total_tokens", 0), model_breakdown=token_usage_data.get("usage"), call_details=token_usage_data.get("call_details"), - usage_type="chat", thread_id=thread_id, message_id=db_message.id, - search_space_id=thread.search_space_id, - user_id=user.id, ) - session.add(token_usage_record) await session.commit() diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 6a5b3793f..5d69e6870 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -5,9 +5,11 @@ Uses a ContextVar-scoped accumulator to group all LLM calls within a single async request/turn. The accumulated data is emitted via SSE and persisted when the frontend calls appendMessage. -Agent LLM calls are captured automatically via the async callback. -Title-generation usage is added explicitly from the LangChain response -metadata to avoid callback-timing issues. +The module also provides ``record_token_usage``, a thin async helper that +creates a ``TokenUsage`` row for *any* usage type (chat, indexing, image +generation, podcasts, …). Call sites should prefer this helper over +constructing ``TokenUsage`` manually so that logging and error handling +stay consistent. """ from __future__ import annotations @@ -17,8 +19,12 @@ import logging from contextvars import ContextVar from dataclasses import dataclass, field from typing import Any +from uuid import UUID from litellm.integrations.custom_logger import CustomLogger +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import TokenUsage logger = logging.getLogger(__name__) @@ -138,3 +144,54 @@ class TokenTrackingCallback(CustomLogger): token_tracker = TokenTrackingCallback() + + +# --------------------------------------------------------------------------- +# Persistence helper +# --------------------------------------------------------------------------- + + +async def record_token_usage( + session: AsyncSession, + *, + usage_type: str, + search_space_id: int, + user_id: UUID, + prompt_tokens: int = 0, + completion_tokens: int = 0, + total_tokens: int = 0, + model_breakdown: dict[str, Any] | None = None, + call_details: dict[str, Any] | None = None, + thread_id: int | None = None, + message_id: int | None = None, +) -> TokenUsage | None: + """Persist a single ``TokenUsage`` row. + + Returns the record on success, ``None`` if persistence failed (the + failure is logged but never propagated so callers don't need to + wrap this in try/except). + """ + try: + record = TokenUsage( + usage_type=usage_type, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_breakdown=model_breakdown, + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + search_space_id=search_space_id, + user_id=user_id, + ) + session.add(record) + logger.debug( + "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d", + usage_type, prompt_tokens, completion_tokens, total_tokens, + ) + return record + except Exception: + logger.warning( + "[TokenTracking] failed to record %s token usage", usage_type, exc_info=True, + ) + return None diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 4530f5046..e87a1b791 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -51,7 +51,7 @@ from app.db import ( async_session_maker, shielded_async_session, ) -from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE +from app.prompts import TITLE_GENERATION_PROMPT from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -1460,34 +1460,58 @@ async def stream_new_chat( ) is_first_response = (assistant_count_result.scalar() or 0) == 0 - title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None + title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None if is_first_response: - async def _generate_title() -> tuple[str | None, dict[str, int] | None]: - """Return (title, usage_dict) where usage_dict has model/prompt/completion/total.""" + async def _generate_title() -> tuple[str | None, dict | None]: + """Generate a short title via litellm.acompletion. + + Returns (title, usage_dict). Usage is extracted directly from + the response object because litellm fires its async callback + via fire-and-forget ``create_task``, so the + ``TokenTrackingCallback`` would run too late. We also blank + the accumulator in this child-task context so the late callback + doesn't double-count. + """ try: - title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm - title_result = await title_chain.ainvoke( - {"user_query": user_query[:500]} - ) - usage_dict: dict[str, int] | None = None - if title_result: - um = getattr(title_result, "usage_metadata", None) - if um: - rm = getattr(title_result, "response_metadata", None) or {} - raw_model = rm.get("model_name", "unknown") - usage_dict = { - "model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model, - "prompt_tokens": um.get("input_tokens", 0), - "completion_tokens": um.get("output_tokens", 0), - "total_tokens": um.get("total_tokens", 0), - } - if hasattr(title_result, "content"): - raw_title = title_result.content.strip() - if raw_title and len(raw_title) <= 100: - return raw_title.strip("\"'"), usage_dict - return None, usage_dict + from litellm import acompletion + from app.services.llm_router_service import LLMRouterService + from app.services.token_tracking_service import _turn_accumulator + + _turn_accumulator.set(None) + + prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500]) + messages = [{"role": "user", "content": prompt}] + + if getattr(llm, "model", None) == "auto": + router = LLMRouterService.get_router() + response = await router.acompletion(model="auto", messages=messages) + else: + response = await acompletion( + model=llm.model, + messages=messages, + api_key=getattr(llm, "api_key", None), + api_base=getattr(llm, "api_base", None), + ) + + usage_info = None + usage = getattr(response, "usage", None) + if usage: + raw_model = getattr(llm, "model", "") or "" + model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown") + usage_info = { + "model": model_name, + "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage, "total_tokens", 0) or 0, + } + + raw_title = response.choices[0].message.content.strip() + if raw_title and len(raw_title) <= 100: + return raw_title.strip("\"'"), usage_info + return None, usage_info except Exception: + logging.getLogger(__name__).exception("[TitleGen] _generate_title failed") return None, None title_task = asyncio.create_task(_generate_title()) @@ -1520,7 +1544,9 @@ async def stream_new_chat( # Inject title update mid-stream as soon as the background task finishes if title_task is not None and title_task.done() and not title_emitted: - generated_title, _title_usage = title_task.result() + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) if generated_title: async with shielded_async_session() as title_session: title_thread_result = await title_session.execute( @@ -1567,7 +1593,9 @@ async def stream_new_chat( # If the title task didn't finish during streaming, await it now if title_task is not None and not title_emitted: - generated_title, _title_usage = await title_task + generated_title, title_usage = await title_task + if title_usage: + accumulator.add(**title_usage) if generated_title: async with shielded_async_session() as title_session: title_thread_result = await title_session.execute(