mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: implement token usage tracking for LLM calls with new accumulator and callback
This commit is contained in:
parent
917f35eb33
commit
3cfe53fb7f
6 changed files with 223 additions and 4 deletions
|
|
@ -30,6 +30,7 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
TokenUsage,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
|
|
@ -45,6 +46,7 @@ from app.schemas.new_chat import (
|
|||
NewChatThreadWithMessages,
|
||||
PublicChatSnapshotCreateResponse,
|
||||
PublicChatSnapshotListResponse,
|
||||
TokenUsageSummary,
|
||||
RegenerateRequest,
|
||||
ResumeRequest,
|
||||
ThreadHistoryLoadResponse,
|
||||
|
|
@ -473,10 +475,13 @@ async def get_thread_messages(
|
|||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
# Get messages with their authors loaded
|
||||
# Get messages with their authors and token usage loaded
|
||||
messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.author))
|
||||
.options(
|
||||
selectinload(NewChatMessage.author),
|
||||
selectinload(NewChatMessage.token_usage),
|
||||
)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
)
|
||||
|
|
@ -493,6 +498,7 @@ async def get_thread_messages(
|
|||
author_id=msg.author_id,
|
||||
author_display_name=msg.author.display_name if msg.author else None,
|
||||
author_avatar_url=msg.author.avatar_url if msg.author else None,
|
||||
token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
|
||||
)
|
||||
for msg in db_messages
|
||||
]
|
||||
|
|
@ -530,7 +536,11 @@ async def get_thread_full(
|
|||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.options(
|
||||
selectinload(NewChatThread.messages).selectinload(
|
||||
NewChatMessage.token_usage
|
||||
),
|
||||
)
|
||||
.filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
|
@ -935,6 +945,24 @@ async def append_message(
|
|||
|
||||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
token_usage_record = TokenUsage(
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
usage_type="chat",
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(token_usage_record)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Return the in-memory object (already has id from flush) instead of
|
||||
|
|
|
|||
|
|
@ -34,6 +34,14 @@ class NewChatMessageCreate(NewChatMessageBase):
|
|||
thread_id: int
|
||||
|
||||
|
||||
class TokenUsageSummary(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
model_breakdown: dict | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||
"""Schema for reading a message."""
|
||||
|
||||
|
|
@ -41,6 +49,7 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
|||
author_id: UUID | None = None
|
||||
author_display_name: str | None = None
|
||||
author_avatar_url: str | None = None
|
||||
token_usage: TokenUsageSummary | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -970,6 +970,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
|
|
|
|||
|
|
@ -22,10 +22,13 @@ litellm.drop_params = True
|
|||
# Memory controls: prevent unbounded internal accumulation
|
||||
litellm.telemetry = False
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm.input_callback = []
|
||||
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
litellm.callbacks = [token_tracker]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
129
surfsense_backend/app/services/token_tracking_service.py
Normal file
129
surfsense_backend/app/services/token_tracking_service.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
Token usage tracking via LiteLLM custom callback.
|
||||
|
||||
Uses a ContextVar-scoped accumulator to group all LLM calls within a single
|
||||
async request/turn. The accumulated data is emitted via SSE and persisted
|
||||
when the frontend calls appendMessage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenCallRecord:
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class TurnTokenAccumulator:
|
||||
"""Accumulates token usage across all LLM calls within a single user turn."""
|
||||
|
||||
calls: list[TokenCallRecord] = field(default_factory=list)
|
||||
|
||||
def add(
|
||||
self,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
) -> None:
|
||||
self.calls.append(
|
||||
TokenCallRecord(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||
"""Return token counts grouped by model name."""
|
||||
by_model: dict[str, dict[str, int]] = {}
|
||||
for c in self.calls:
|
||||
entry = by_model.setdefault(
|
||||
c.model,
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
entry["prompt_tokens"] += c.prompt_tokens
|
||||
entry["completion_tokens"] += c.completion_tokens
|
||||
entry["total_tokens"] += c.total_tokens
|
||||
return by_model
|
||||
|
||||
@property
|
||||
def grand_total(self) -> int:
|
||||
return sum(c.total_tokens for c in self.calls)
|
||||
|
||||
@property
|
||||
def total_prompt_tokens(self) -> int:
|
||||
return sum(c.prompt_tokens for c in self.calls)
|
||||
|
||||
@property
|
||||
def total_completion_tokens(self) -> int:
|
||||
return sum(c.completion_tokens for c in self.calls)
|
||||
|
||||
def serialized_calls(self) -> list[dict[str, Any]]:
|
||||
return [dataclasses.asdict(c) for c in self.calls]
|
||||
|
||||
|
||||
_turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
|
||||
"_turn_accumulator", default=None
|
||||
)
|
||||
|
||||
|
||||
def start_turn() -> TurnTokenAccumulator:
|
||||
"""Create a fresh accumulator for the current async context and return it."""
|
||||
acc = TurnTokenAccumulator()
|
||||
_turn_accumulator.set(acc)
|
||||
return acc
|
||||
|
||||
|
||||
def get_current_accumulator() -> TurnTokenAccumulator | None:
|
||||
return _turn_accumulator.get()
|
||||
|
||||
|
||||
class TokenTrackingCallback(CustomLogger):
|
||||
"""LiteLLM callback that captures token usage into the turn accumulator."""
|
||||
|
||||
async def async_log_success_event(
|
||||
self,
|
||||
kwargs: dict[str, Any],
|
||||
response_obj: Any,
|
||||
start_time: Any,
|
||||
end_time: Any,
|
||||
) -> None:
|
||||
acc = _turn_accumulator.get()
|
||||
if acc is None:
|
||||
return
|
||||
|
||||
usage = getattr(response_obj, "usage", None)
|
||||
if not usage:
|
||||
return
|
||||
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
total_tokens = getattr(usage, "total_tokens", 0) or 0
|
||||
|
||||
model = kwargs.get("model", "unknown")
|
||||
|
||||
acc.add(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
|
||||
token_tracker = TokenTrackingCallback()
|
||||
|
|
@ -1170,6 +1170,10 @@ async def stream_new_chat(
|
|||
_t_total = time.perf_counter()
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
accumulator = start_turn()
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
|
|
@ -1527,6 +1531,17 @@ async def stream_new_chat(
|
|||
if stream_result.is_interrupted:
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -1548,6 +1563,16 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
|
||||
# Fire background memory extraction if the agent didn't handle it.
|
||||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
|
|
@ -1646,6 +1671,10 @@ async def stream_resume_chat(
|
|||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
|
||||
from app.services.token_tracking_service import start_turn
|
||||
|
||||
accumulator = start_turn()
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
if user_id:
|
||||
|
|
@ -1769,11 +1798,31 @@ async def stream_resume_chat(
|
|||
chat_id,
|
||||
)
|
||||
if stream_result.is_interrupted:
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue