# Patch summary: per-message LLM token-usage tracking (surfsense_backend).
#
# What the hunks below do, file by file:
#
#   app/routes/new_chat_routes.py
#     - Imports TokenUsage (app.db) and TokenUsageSummary (app.schemas.new_chat).
#     - get_thread_messages: eager-loads NewChatMessage.token_usage alongside
#       NewChatMessage.author and returns it per message as a TokenUsageSummary
#       (None when the message has no token_usage).
#     - get_thread_full: chains selectinload(NewChatMessage.token_usage) onto the
#       NewChatThread.messages load.
#     - append_message: after session.flush(), when raw_body["token_usage"] is
#       truthy and the role is NewChatMessageRole.ASSISTANT, adds a TokenUsage row
#       (usage_type="chat") linked to the thread, message, search space and user.
#       model_breakdown is read from the payload key "usage". The row is added
#       before the session.commit() that already follows.
#
#   app/schemas/new_chat.py
#     - Adds TokenUsageSummary (prompt_tokens, completion_tokens, total_tokens,
#       model_breakdown; from_attributes=True) and an optional token_usage field
#       on NewChatMessageRead.
#
#   app/services/llm_router_service.py
#     - The streaming call now passes stream_options={"include_usage": True}.
#
#   app/services/llm_service.py
#     - Removes `litellm.success_callback = []` and sets
#       `litellm.callbacks = [token_tracker]`.
#
#   app/services/token_tracking_service.py (new file)
#     - TokenCallRecord / TurnTokenAccumulator: one record per LLM call, with
#       per-model and overall sums.
#     - start_turn() stores a fresh accumulator in the `_turn_accumulator`
#       ContextVar; get_current_accumulator() reads it back.
#     - TokenTrackingCallback.async_log_success_event appends the usage found on
#       response_obj.usage to the current accumulator; it returns early when no
#       accumulator is set or usage is missing/falsy.
#
#   app/tasks/chat/stream_new_chat.py
#     - stream_new_chat and stream_resume_chat call start_turn() before creating
#       the session, and yield a "token-usage" data event (usage, prompt_tokens,
#       completion_tokens, total_tokens, call_details) on both the interrupted
#       and the normal path, only when per_message_summary() is non-empty.
#
# Review notes (assumptions to confirm before merge):
#   NOTE(review): append_message calls .get() on raw_body["token_usage"] without
#     checking that it is a dict; a truthy non-dict value (string, list, number)
#     would raise AttributeError. Confirm whether raw_body is validated upstream.
#   NOTE(review): the persisted token counts are taken from the request body
#     unchanged. Presumably the client echoes the "token-usage" SSE event; if
#     these rows feed billing or quotas, verify that trusting them is acceptable.
#   NOTE(review): kwargs.get("model", "unknown") only falls back when the key is
#     absent; an explicit model=None would be recorded as None.
#   NOTE(review): only async_log_success_event is overridden. Whether synchronous
#     LiteLLM calls reach this callback is not visible here - confirm.
#   NOTE(review): the "token-usage" payload is built in four places with identical
#     keys; a single helper would keep them from drifting apart.
#   NOTE(review): in llm_service.py the new import sits after module-level
#     statements, and litellm.callbacks is assigned rather than extended; any
#     callbacks registered earlier would be replaced - confirm none exist.
#   NOTE(review): call_details is persisted by append_message but is not a field
#     of TokenUsageSummary, so it is not returned by the read path - confirm
#     this is intentional.
#   NOTE(review): line breaks in this copy of the patch appear to have been
#     collapsed; the hunks are left exactly as received.
#
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 10a6951fa..a5245456e 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -30,6 +30,7 @@ from app.db import ( NewChatThread, Permission, SearchSpace, + TokenUsage, User, get_async_session, shielded_async_session, @@ -45,6 +46,7 @@ from app.schemas.new_chat import ( NewChatThreadWithMessages, PublicChatSnapshotCreateResponse, PublicChatSnapshotListResponse, + TokenUsageSummary, RegenerateRequest, ResumeRequest, ThreadHistoryLoadResponse, @@ -473,10 +475,13 @@ async def get_thread_messages( # Check thread-level access based on visibility await check_thread_access(session, thread, user) - # Get messages with their authors loaded + # Get messages with their authors and token usage loaded messages_result = await session.execute( select(NewChatMessage) - .options(selectinload(NewChatMessage.author)) + .options( + selectinload(NewChatMessage.author), + selectinload(NewChatMessage.token_usage), + ) .filter(NewChatMessage.thread_id == thread_id) .order_by(NewChatMessage.created_at) ) @@ -493,6 +498,7 @@ async def get_thread_messages( author_id=msg.author_id, author_display_name=msg.author.display_name if msg.author else None, author_avatar_url=msg.author.avatar_url if msg.author else None, + token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None, ) for msg in db_messages ] @@ -530,7 +536,11 @@ async def get_thread_full( try: result = await session.execute( select(NewChatThread) - .options(selectinload(NewChatThread.messages)) + .options( + selectinload(NewChatThread.messages).selectinload( + NewChatMessage.token_usage + ), + ) .filter(NewChatThread.id == thread_id) ) thread = result.scalars().first() @@ -935,6 +945,24 @@ async def append_message( # flush assigns the PK/defaults without a round-trip SELECT await session.flush() + + # Persist token usage if 
provided (for assistant messages) + token_usage_data = raw_body.get("token_usage") + if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: + token_usage_record = TokenUsage( + prompt_tokens=token_usage_data.get("prompt_tokens", 0), + completion_tokens=token_usage_data.get("completion_tokens", 0), + total_tokens=token_usage_data.get("total_tokens", 0), + model_breakdown=token_usage_data.get("usage"), + call_details=token_usage_data.get("call_details"), + usage_type="chat", + thread_id=thread_id, + message_id=db_message.id, + search_space_id=thread.search_space_id, + user_id=user.id, + ) + session.add(token_usage_record) + await session.commit() # Return the in-memory object (already has id from flush) instead of diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 5d8ae207e..e523657a4 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -34,6 +34,14 @@ class NewChatMessageCreate(NewChatMessageBase): thread_id: int +class TokenUsageSummary(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + model_breakdown: dict | None = None + model_config = ConfigDict(from_attributes=True) + + class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): """Schema for reading a message.""" @@ -41,6 +49,7 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): author_id: UUID | None = None author_display_name: str | None = None author_avatar_url: str | None = None + token_usage: TokenUsageSummary | None = None model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 63d8d10b9..d97665f7a 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -970,6 +970,7 @@ class ChatLiteLLMRouter(BaseChatModel): 
messages=formatted_messages, stop=stop, stream=True, + stream_options={"include_usage": True}, **call_kwargs, ) except ContextWindowExceededError as e: diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 723b17607..c90bdfce3 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -22,10 +22,13 @@ litellm.drop_params = True # Memory controls: prevent unbounded internal accumulation litellm.telemetry = False litellm.cache = None -litellm.success_callback = [] litellm.failure_callback = [] litellm.input_callback = [] +from app.services.token_tracking_service import token_tracker + +litellm.callbacks = [token_tracker] + logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py new file mode 100644 index 000000000..434a55ae0 --- /dev/null +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -0,0 +1,129 @@ +""" +Token usage tracking via LiteLLM custom callback. + +Uses a ContextVar-scoped accumulator to group all LLM calls within a single +async request/turn. The accumulated data is emitted via SSE and persisted +when the frontend calls appendMessage. 
+""" + +from __future__ import annotations + +import dataclasses +import logging +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import Any + +from litellm.integrations.custom_logger import CustomLogger + +logger = logging.getLogger(__name__) + + +@dataclass +class TokenCallRecord: + model: str + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +@dataclass +class TurnTokenAccumulator: + """Accumulates token usage across all LLM calls within a single user turn.""" + + calls: list[TokenCallRecord] = field(default_factory=list) + + def add( + self, + model: str, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + ) -> None: + self.calls.append( + TokenCallRecord( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + ) + + def per_message_summary(self) -> dict[str, dict[str, int]]: + """Return token counts grouped by model name.""" + by_model: dict[str, dict[str, int]] = {} + for c in self.calls: + entry = by_model.setdefault( + c.model, + {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + entry["prompt_tokens"] += c.prompt_tokens + entry["completion_tokens"] += c.completion_tokens + entry["total_tokens"] += c.total_tokens + return by_model + + @property + def grand_total(self) -> int: + return sum(c.total_tokens for c in self.calls) + + @property + def total_prompt_tokens(self) -> int: + return sum(c.prompt_tokens for c in self.calls) + + @property + def total_completion_tokens(self) -> int: + return sum(c.completion_tokens for c in self.calls) + + def serialized_calls(self) -> list[dict[str, Any]]: + return [dataclasses.asdict(c) for c in self.calls] + + +_turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar( + "_turn_accumulator", default=None +) + + +def start_turn() -> TurnTokenAccumulator: + """Create a fresh accumulator for the current async context and return it.""" + acc = 
TurnTokenAccumulator() + _turn_accumulator.set(acc) + return acc + + +def get_current_accumulator() -> TurnTokenAccumulator | None: + return _turn_accumulator.get() + + +class TokenTrackingCallback(CustomLogger): + """LiteLLM callback that captures token usage into the turn accumulator.""" + + async def async_log_success_event( + self, + kwargs: dict[str, Any], + response_obj: Any, + start_time: Any, + end_time: Any, + ) -> None: + acc = _turn_accumulator.get() + if acc is None: + return + + usage = getattr(response_obj, "usage", None) + if not usage: + return + + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + total_tokens = getattr(usage, "total_tokens", 0) or 0 + + model = kwargs.get("model", "unknown") + + acc.add( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + + +token_tracker = TokenTrackingCallback() diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index fd118528e..4459b9c06 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1170,6 +1170,10 @@ async def stream_new_chat( _t_total = time.perf_counter() log_system_snapshot("stream_new_chat_START") + from app.services.token_tracking_service import start_turn + + accumulator = start_turn() + session = async_session_maker() try: # Mark AI as responding to this user for live collaboration @@ -1527,6 +1531,17 @@ async def stream_new_chat( if stream_result.is_interrupted: if title_task is not None and not title_task.done(): title_task.cancel() + + usage_summary = accumulator.per_message_summary() + if usage_summary: + yield streaming_service.format_data("token-usage", { + "usage": usage_summary, + "prompt_tokens": accumulator.total_prompt_tokens, + "completion_tokens": accumulator.total_completion_tokens, + "total_tokens": 
accumulator.grand_total, + "call_details": accumulator.serialized_calls(), + }) + yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -1548,6 +1563,16 @@ async def stream_new_chat( chat_id, generated_title ) + usage_summary = accumulator.per_message_summary() + if usage_summary: + yield streaming_service.format_data("token-usage", { + "usage": usage_summary, + "prompt_tokens": accumulator.total_prompt_tokens, + "completion_tokens": accumulator.total_completion_tokens, + "total_tokens": accumulator.grand_total, + "call_details": accumulator.serialized_calls(), + }) + # Fire background memory extraction if the agent didn't handle it. # Shared threads write to team memory; private threads write to user memory. if not stream_result.agent_called_update_memory: @@ -1646,6 +1671,10 @@ async def stream_resume_chat( stream_result = StreamResult() _t_total = time.perf_counter() + from app.services.token_tracking_service import start_turn + + accumulator = start_turn() + session = async_session_maker() try: if user_id: @@ -1769,11 +1798,31 @@ async def stream_resume_chat( chat_id, ) if stream_result.is_interrupted: + usage_summary = accumulator.per_message_summary() + if usage_summary: + yield streaming_service.format_data("token-usage", { + "usage": usage_summary, + "prompt_tokens": accumulator.total_prompt_tokens, + "completion_tokens": accumulator.total_completion_tokens, + "total_tokens": accumulator.grand_total, + "call_details": accumulator.serialized_calls(), + }) + yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() return + usage_summary = accumulator.per_message_summary() + if usage_summary: + yield streaming_service.format_data("token-usage", { + "usage": usage_summary, + "prompt_tokens": accumulator.total_prompt_tokens, + "completion_tokens": accumulator.total_completion_tokens, + "total_tokens": accumulator.grand_total, 
+ "call_details": accumulator.serialized_calls(), + }) + yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done()