mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-02 20:32:39 +02:00
feat: implement token usage tracking for LLM calls with new accumulator and callback
This commit is contained in:
parent
917f35eb33
commit
3cfe53fb7f
6 changed files with 223 additions and 4 deletions
|
|
@ -30,6 +30,7 @@ from app.db import (
|
||||||
NewChatThread,
|
NewChatThread,
|
||||||
Permission,
|
Permission,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
|
TokenUsage,
|
||||||
User,
|
User,
|
||||||
get_async_session,
|
get_async_session,
|
||||||
shielded_async_session,
|
shielded_async_session,
|
||||||
|
|
@ -45,6 +46,7 @@ from app.schemas.new_chat import (
|
||||||
NewChatThreadWithMessages,
|
NewChatThreadWithMessages,
|
||||||
PublicChatSnapshotCreateResponse,
|
PublicChatSnapshotCreateResponse,
|
||||||
PublicChatSnapshotListResponse,
|
PublicChatSnapshotListResponse,
|
||||||
|
TokenUsageSummary,
|
||||||
RegenerateRequest,
|
RegenerateRequest,
|
||||||
ResumeRequest,
|
ResumeRequest,
|
||||||
ThreadHistoryLoadResponse,
|
ThreadHistoryLoadResponse,
|
||||||
|
|
@ -473,10 +475,13 @@ async def get_thread_messages(
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
# Get messages with their authors loaded
|
# Get messages with their authors and token usage loaded
|
||||||
messages_result = await session.execute(
|
messages_result = await session.execute(
|
||||||
select(NewChatMessage)
|
select(NewChatMessage)
|
||||||
.options(selectinload(NewChatMessage.author))
|
.options(
|
||||||
|
selectinload(NewChatMessage.author),
|
||||||
|
selectinload(NewChatMessage.token_usage),
|
||||||
|
)
|
||||||
.filter(NewChatMessage.thread_id == thread_id)
|
.filter(NewChatMessage.thread_id == thread_id)
|
||||||
.order_by(NewChatMessage.created_at)
|
.order_by(NewChatMessage.created_at)
|
||||||
)
|
)
|
||||||
|
|
@ -493,6 +498,7 @@ async def get_thread_messages(
|
||||||
author_id=msg.author_id,
|
author_id=msg.author_id,
|
||||||
author_display_name=msg.author.display_name if msg.author else None,
|
author_display_name=msg.author.display_name if msg.author else None,
|
||||||
author_avatar_url=msg.author.avatar_url if msg.author else None,
|
author_avatar_url=msg.author.avatar_url if msg.author else None,
|
||||||
|
token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
|
||||||
)
|
)
|
||||||
for msg in db_messages
|
for msg in db_messages
|
||||||
]
|
]
|
||||||
|
|
@ -530,7 +536,11 @@ async def get_thread_full(
|
||||||
try:
|
try:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(NewChatThread)
|
select(NewChatThread)
|
||||||
.options(selectinload(NewChatThread.messages))
|
.options(
|
||||||
|
selectinload(NewChatThread.messages).selectinload(
|
||||||
|
NewChatMessage.token_usage
|
||||||
|
),
|
||||||
|
)
|
||||||
.filter(NewChatThread.id == thread_id)
|
.filter(NewChatThread.id == thread_id)
|
||||||
)
|
)
|
||||||
thread = result.scalars().first()
|
thread = result.scalars().first()
|
||||||
|
|
@ -935,6 +945,24 @@ async def append_message(
|
||||||
|
|
||||||
# flush assigns the PK/defaults without a round-trip SELECT
|
# flush assigns the PK/defaults without a round-trip SELECT
|
||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
|
# Persist token usage if provided (for assistant messages)
|
||||||
|
token_usage_data = raw_body.get("token_usage")
|
||||||
|
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||||
|
token_usage_record = TokenUsage(
|
||||||
|
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||||
|
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||||
|
model_breakdown=token_usage_data.get("usage"),
|
||||||
|
call_details=token_usage_data.get("call_details"),
|
||||||
|
usage_type="chat",
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_id=db_message.id,
|
||||||
|
search_space_id=thread.search_space_id,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
session.add(token_usage_record)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Return the in-memory object (already has id from flush) instead of
|
# Return the in-memory object (already has id from flush) instead of
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,14 @@ class NewChatMessageCreate(NewChatMessageBase):
|
||||||
thread_id: int
|
thread_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class TokenUsageSummary(BaseModel):
    """Aggregated LLM token counts attached to a single chat message."""

    # Totals summed across every LLM call made while producing the message.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Per-model usage breakdown — presumably {model_name: {prompt_tokens,
    # completion_tokens, total_tokens}} as emitted in the "token-usage" SSE
    # payload; TODO confirm against TurnTokenAccumulator.per_message_summary().
    model_breakdown: dict | None = None
    # from_attributes lets this schema validate directly from the TokenUsage
    # ORM row (see msg.token_usage in get_thread_messages).
    model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||||
"""Schema for reading a message."""
|
"""Schema for reading a message."""
|
||||||
|
|
||||||
|
|
@ -41,6 +49,7 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||||
author_id: UUID | None = None
|
author_id: UUID | None = None
|
||||||
author_display_name: str | None = None
|
author_display_name: str | None = None
|
||||||
author_avatar_url: str | None = None
|
author_avatar_url: str | None = None
|
||||||
|
token_usage: TokenUsageSummary | None = None
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -970,6 +970,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
**call_kwargs,
|
**call_kwargs,
|
||||||
)
|
)
|
||||||
except ContextWindowExceededError as e:
|
except ContextWindowExceededError as e:
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,13 @@ litellm.drop_params = True
|
||||||
# Memory controls: prevent unbounded internal accumulation
|
# Memory controls: prevent unbounded internal accumulation
|
||||||
litellm.telemetry = False
|
litellm.telemetry = False
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
litellm.success_callback = []
|
|
||||||
litellm.failure_callback = []
|
litellm.failure_callback = []
|
||||||
litellm.input_callback = []
|
litellm.input_callback = []
|
||||||
|
|
||||||
|
from app.services.token_tracking_service import token_tracker
|
||||||
|
|
||||||
|
litellm.callbacks = [token_tracker]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
129
surfsense_backend/app/services/token_tracking_service.py
Normal file
129
surfsense_backend/app/services/token_tracking_service.py
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
"""
|
||||||
|
Token usage tracking via LiteLLM custom callback.
|
||||||
|
|
||||||
|
Uses a ContextVar-scoped accumulator to group all LLM calls within a single
|
||||||
|
async request/turn. The accumulated data is emitted via SSE and persisted
|
||||||
|
when the frontend calls appendMessage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenCallRecord:
|
||||||
|
model: str
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TurnTokenAccumulator:
    """Collects per-call token usage for every LLM invocation in one user turn."""

    # Chronological list of per-call usage records for this turn.
    calls: list[TokenCallRecord] = field(default_factory=list)

    def add(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
    ) -> None:
        """Record the usage numbers of a single completed LLM call."""
        record = TokenCallRecord(
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        self.calls.append(record)

    def per_message_summary(self) -> dict[str, dict[str, int]]:
        """Return token counts grouped by model name."""
        summary: dict[str, dict[str, int]] = {}
        for record in self.calls:
            if record.model not in summary:
                summary[record.model] = {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                }
            bucket = summary[record.model]
            bucket["prompt_tokens"] += record.prompt_tokens
            bucket["completion_tokens"] += record.completion_tokens
            bucket["total_tokens"] += record.total_tokens
        return summary

    @property
    def grand_total(self) -> int:
        """Combined total_tokens across every recorded call."""
        return sum(record.total_tokens for record in self.calls)

    @property
    def total_prompt_tokens(self) -> int:
        """Combined prompt-side tokens across every recorded call."""
        return sum(record.prompt_tokens for record in self.calls)

    @property
    def total_completion_tokens(self) -> int:
        """Combined completion-side tokens across every recorded call."""
        return sum(record.completion_tokens for record in self.calls)

    def serialized_calls(self) -> list[dict[str, Any]]:
        """Per-call records as plain dicts (JSON-friendly for SSE payloads)."""
        return [dataclasses.asdict(record) for record in self.calls]
|
||||||
|
|
||||||
|
|
||||||
|
# One accumulator per async request/turn. A ContextVar keeps concurrent
# requests isolated while still letting the globally-registered litellm
# callback find the accumulator for the call's own async context.
_turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
    "_turn_accumulator", default=None
)
|
||||||
|
|
||||||
|
|
||||||
|
def start_turn() -> TurnTokenAccumulator:
    """Bind a brand-new, empty accumulator to the current async context.

    Returns the accumulator so the caller can read totals after streaming.
    """
    fresh = TurnTokenAccumulator()
    _turn_accumulator.set(fresh)
    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_accumulator() -> TurnTokenAccumulator | None:
    """Return the accumulator bound to the current async context, or None outside a turn."""
    return _turn_accumulator.get()
|
||||||
|
|
||||||
|
|
||||||
|
class TokenTrackingCallback(CustomLogger):
    """LiteLLM callback that captures token usage into the turn accumulator.

    Registered globally, so it fires for every LLM call in the process;
    calls made outside an active turn (no accumulator bound in the current
    async context) are silently ignored.
    """

    async def async_log_success_event(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: Any,
        end_time: Any,
    ) -> None:
        """Record the completed call's token usage, if a turn is active."""
        acc = _turn_accumulator.get()
        if acc is None:
            # No turn in progress in this async context (e.g. background work).
            return

        usage = getattr(response_obj, "usage", None)
        if usage is None and isinstance(response_obj, dict):
            # NOTE(review): some litellm code paths appear to hand back plain
            # dicts rather than ModelResponse objects — TODO confirm.
            usage = response_obj.get("usage")
        if not usage:
            return

        def _count(name: str) -> int:
            # Usage may be attribute- or dict-shaped depending on provider
            # path; missing or None values count as zero.
            if isinstance(usage, dict):
                value = usage.get(name, 0)
            else:
                value = getattr(usage, name, 0)
            return int(value or 0)

        # kwargs can contain "model" with a None value; `.get(..., default)`
        # would not cover that, so normalize with `or`.
        model = kwargs.get("model") or "unknown"

        acc.add(
            model=model,
            prompt_tokens=_count("prompt_tokens"),
            completion_tokens=_count("completion_tokens"),
            total_tokens=_count("total_tokens"),
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Shared singleton instance; the LLM config module registers it via
# `litellm.callbacks = [token_tracker]` so every call is tracked.
token_tracker = TokenTrackingCallback()
|
||||||
|
|
@ -1170,6 +1170,10 @@ async def stream_new_chat(
|
||||||
_t_total = time.perf_counter()
|
_t_total = time.perf_counter()
|
||||||
log_system_snapshot("stream_new_chat_START")
|
log_system_snapshot("stream_new_chat_START")
|
||||||
|
|
||||||
|
from app.services.token_tracking_service import start_turn
|
||||||
|
|
||||||
|
accumulator = start_turn()
|
||||||
|
|
||||||
session = async_session_maker()
|
session = async_session_maker()
|
||||||
try:
|
try:
|
||||||
# Mark AI as responding to this user for live collaboration
|
# Mark AI as responding to this user for live collaboration
|
||||||
|
|
@ -1527,6 +1531,17 @@ async def stream_new_chat(
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
if title_task is not None and not title_task.done():
|
if title_task is not None and not title_task.done():
|
||||||
title_task.cancel()
|
title_task.cancel()
|
||||||
|
|
||||||
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
if usage_summary:
|
||||||
|
yield streaming_service.format_data("token-usage", {
|
||||||
|
"usage": usage_summary,
|
||||||
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"call_details": accumulator.serialized_calls(),
|
||||||
|
})
|
||||||
|
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -1548,6 +1563,16 @@ async def stream_new_chat(
|
||||||
chat_id, generated_title
|
chat_id, generated_title
|
||||||
)
|
)
|
||||||
|
|
||||||
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
if usage_summary:
|
||||||
|
yield streaming_service.format_data("token-usage", {
|
||||||
|
"usage": usage_summary,
|
||||||
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"call_details": accumulator.serialized_calls(),
|
||||||
|
})
|
||||||
|
|
||||||
# Fire background memory extraction if the agent didn't handle it.
|
# Fire background memory extraction if the agent didn't handle it.
|
||||||
# Shared threads write to team memory; private threads write to user memory.
|
# Shared threads write to team memory; private threads write to user memory.
|
||||||
if not stream_result.agent_called_update_memory:
|
if not stream_result.agent_called_update_memory:
|
||||||
|
|
@ -1646,6 +1671,10 @@ async def stream_resume_chat(
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
_t_total = time.perf_counter()
|
_t_total = time.perf_counter()
|
||||||
|
|
||||||
|
from app.services.token_tracking_service import start_turn
|
||||||
|
|
||||||
|
accumulator = start_turn()
|
||||||
|
|
||||||
session = async_session_maker()
|
session = async_session_maker()
|
||||||
try:
|
try:
|
||||||
if user_id:
|
if user_id:
|
||||||
|
|
@ -1769,11 +1798,31 @@ async def stream_resume_chat(
|
||||||
chat_id,
|
chat_id,
|
||||||
)
|
)
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
if usage_summary:
|
||||||
|
yield streaming_service.format_data("token-usage", {
|
||||||
|
"usage": usage_summary,
|
||||||
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"call_details": accumulator.serialized_calls(),
|
||||||
|
})
|
||||||
|
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
if usage_summary:
|
||||||
|
yield streaming_service.format_data("token-usage", {
|
||||||
|
"usage": usage_summary,
|
||||||
|
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||||
|
"completion_tokens": accumulator.total_completion_tokens,
|
||||||
|
"total_tokens": accumulator.grand_total,
|
||||||
|
"call_details": accumulator.serialized_calls(),
|
||||||
|
})
|
||||||
|
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue