mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-27 17:56:25 +02:00
feat: implement token usage recording in chat routes and enhance title generation handling
This commit is contained in:
parent
292fcb1a2c
commit
f01ddf3f0a
3 changed files with 121 additions and 36 deletions
|
|
@ -5,9 +5,11 @@ Uses a ContextVar-scoped accumulator to group all LLM calls within a single
|
|||
async request/turn. The accumulated data is emitted via SSE and persisted
|
||||
when the frontend calls appendMessage.
|
||||
|
||||
Agent LLM calls are captured automatically via the async callback.
|
||||
Title-generation usage is added explicitly from the LangChain response
|
||||
metadata to avoid callback-timing issues.
|
||||
The module also provides ``record_token_usage``, a thin async helper that
|
||||
creates a ``TokenUsage`` row for *any* usage type (chat, indexing, image
|
||||
generation, podcasts, …). Call sites should prefer this helper over
|
||||
constructing ``TokenUsage`` manually so that logging and error handling
|
||||
stay consistent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -17,8 +19,12 @@ import logging
|
|||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -138,3 +144,54 @@ class TokenTrackingCallback(CustomLogger):
|
|||
|
||||
|
||||
# Module-level singleton instance of the callback; presumably registered with
# LiteLLM's callback list elsewhere so all LLM calls flow through it — confirm
# at the registration site.
token_tracker = TokenTrackingCallback()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def record_token_usage(
    session: AsyncSession,
    *,
    usage_type: str,
    search_space_id: int,
    user_id: UUID,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    total_tokens: int = 0,
    model_breakdown: dict[str, Any] | None = None,
    call_details: dict[str, Any] | None = None,
    thread_id: int | None = None,
    message_id: int | None = None,
) -> TokenUsage | None:
    """Build a ``TokenUsage`` row and stage it on *session*.

    This is deliberately best-effort: any exception raised while
    constructing or staging the row is logged (with traceback) and
    swallowed, so call sites never need their own try/except around
    token accounting.

    Returns:
        The staged ``TokenUsage`` record on success, ``None`` on failure.
    """
    # Collect every column value up front so the ORM call below stays flat.
    columns: dict[str, Any] = {
        "usage_type": usage_type,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "model_breakdown": model_breakdown,
        "call_details": call_details,
        "thread_id": thread_id,
        "message_id": message_id,
        "search_space_id": search_space_id,
        "user_id": user_id,
    }
    try:
        usage_row = TokenUsage(**columns)
        # Stage only — the surrounding unit of work owns commit/rollback.
        session.add(usage_row)
        logger.debug(
            "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
            usage_type, prompt_tokens, completion_tokens, total_tokens,
        )
        return usage_row
    except Exception:
        # Token accounting must never break the caller's request path.
        logger.warning(
            "[TokenTracking] failed to record %s token usage", usage_type, exc_info=True,
        )
        return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue