mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: implement token usage recording in chat routes and enhance title generation handling
This commit is contained in:
parent
292fcb1a2c
commit
f01ddf3f0a
3 changed files with 121 additions and 36 deletions
|
|
@ -30,7 +30,6 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
TokenUsage,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
|
|
@ -53,6 +52,7 @@ from app.schemas.new_chat import (
|
|||
ThreadListResponse,
|
||||
TokenUsageSummary,
|
||||
)
|
||||
from app.services.token_tracking_service import record_token_usage
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -949,19 +949,19 @@ async def append_message(
|
|||
# Persist token usage if provided (for assistant messages)
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
token_usage_record = TokenUsage(
|
||||
await record_token_usage(
|
||||
session,
|
||||
usage_type="chat",
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
usage_type="chat",
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(token_usage_record)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,11 @@ Uses a ContextVar-scoped accumulator to group all LLM calls within a single
|
|||
async request/turn. The accumulated data is emitted via SSE and persisted
|
||||
when the frontend calls appendMessage.
|
||||
|
||||
Agent LLM calls are captured automatically via the async callback.
|
||||
Title-generation usage is added explicitly from the LangChain response
|
||||
metadata to avoid callback-timing issues.
|
||||
The module also provides ``record_token_usage``, a thin async helper that
|
||||
creates a ``TokenUsage`` row for *any* usage type (chat, indexing, image
|
||||
generation, podcasts, …). Call sites should prefer this helper over
|
||||
constructing ``TokenUsage`` manually so that logging and error handling
|
||||
stay consistent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -17,8 +19,12 @@ import logging
|
|||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -138,3 +144,54 @@ class TokenTrackingCallback(CustomLogger):
|
|||
|
||||
|
||||
token_tracker = TokenTrackingCallback()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
usage_type: str,
|
||||
search_space_id: int,
|
||||
user_id: UUID,
|
||||
prompt_tokens: int = 0,
|
||||
completion_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
model_breakdown: dict[str, Any] | None = None,
|
||||
call_details: dict[str, Any] | None = None,
|
||||
thread_id: int | None = None,
|
||||
message_id: int | None = None,
|
||||
) -> TokenUsage | None:
|
||||
"""Persist a single ``TokenUsage`` row.
|
||||
|
||||
Returns the record on success, ``None`` if persistence failed (the
|
||||
failure is logged but never propagated so callers don't need to
|
||||
wrap this in try/except).
|
||||
"""
|
||||
try:
|
||||
record = TokenUsage(
|
||||
usage_type=usage_type,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model_breakdown=model_breakdown,
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(record)
|
||||
logger.debug(
|
||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
|
||||
usage_type, prompt_tokens, completion_tokens, total_tokens,
|
||||
)
|
||||
return record
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[TokenTracking] failed to record %s token usage", usage_type, exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ from app.db import (
|
|||
async_session_maker,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -1460,34 +1460,58 @@ async def stream_new_chat(
|
|||
)
|
||||
is_first_response = (assistant_count_result.scalar() or 0) == 0
|
||||
|
||||
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
|
||||
if is_first_response:
|
||||
|
||||
async def _generate_title() -> tuple[str | None, dict[str, int] | None]:
|
||||
"""Return (title, usage_dict) where usage_dict has model/prompt/completion/total."""
|
||||
async def _generate_title() -> tuple[str | None, dict | None]:
|
||||
"""Generate a short title via litellm.acompletion.
|
||||
|
||||
Returns (title, usage_dict). Usage is extracted directly from
|
||||
the response object because litellm fires its async callback
|
||||
via fire-and-forget ``create_task``, so the
|
||||
``TokenTrackingCallback`` would run too late. We also blank
|
||||
the accumulator in this child-task context so the late callback
|
||||
doesn't double-count.
|
||||
"""
|
||||
try:
|
||||
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
|
||||
title_result = await title_chain.ainvoke(
|
||||
{"user_query": user_query[:500]}
|
||||
from litellm import acompletion
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if getattr(llm, "model", None) == "auto":
|
||||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
response = await acompletion(
|
||||
model=llm.model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
usage_dict: dict[str, int] | None = None
|
||||
if title_result:
|
||||
um = getattr(title_result, "usage_metadata", None)
|
||||
if um:
|
||||
rm = getattr(title_result, "response_metadata", None) or {}
|
||||
raw_model = rm.get("model_name", "unknown")
|
||||
usage_dict = {
|
||||
"model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model,
|
||||
"prompt_tokens": um.get("input_tokens", 0),
|
||||
"completion_tokens": um.get("output_tokens", 0),
|
||||
"total_tokens": um.get("total_tokens", 0),
|
||||
|
||||
usage_info = None
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown")
|
||||
usage_info = {
|
||||
"model": model_name,
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
if hasattr(title_result, "content"):
|
||||
raw_title = title_result.content.strip()
|
||||
|
||||
raw_title = response.choices[0].message.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_dict
|
||||
return None, usage_dict
|
||||
return raw_title.strip("\"'"), usage_info
|
||||
return None, usage_info
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
|
||||
return None, None
|
||||
|
||||
title_task = asyncio.create_task(_generate_title())
|
||||
|
|
@ -1520,7 +1544,9 @@ async def stream_new_chat(
|
|||
|
||||
# Inject title update mid-stream as soon as the background task finishes
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
generated_title, _title_usage = title_task.result()
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
|
|
@ -1567,7 +1593,9 @@ async def stream_new_chat(
|
|||
|
||||
# If the title task didn't finish during streaming, await it now
|
||||
if title_task is not None and not title_emitted:
|
||||
generated_title, _title_usage = await title_task
|
||||
generated_title, title_usage = await title_task
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue