diff --git a/surfsense_backend/alembic/versions/125_add_token_usage_table.py b/surfsense_backend/alembic/versions/125_add_token_usage_table.py
new file mode 100644
index 000000000..915561c8c
--- /dev/null
+++ b/surfsense_backend/alembic/versions/125_add_token_usage_table.py
@@ -0,0 +1,85 @@
+"""125_add_token_usage_table
+
+Revision ID: 125
+Revises: 124
+Create Date: 2026-04-14
+
+Adds token_usage table for tracking LLM token consumption per message.
+Supports future extension via usage_type for indexing, image gen, etc.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB, UUID
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "125"
+down_revision: str | None = "124"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ if sa.inspect(conn).has_table("token_usage"):
+ return
+
+ op.create_table(
+ "token_usage",
+ sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
+ sa.Column("prompt_tokens", sa.Integer(), nullable=False, server_default="0"),
+ sa.Column(
+ "completion_tokens", sa.Integer(), nullable=False, server_default="0"
+ ),
+ sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
+ sa.Column("model_breakdown", JSONB, nullable=True),
+ sa.Column("call_details", JSONB, nullable=True),
+ sa.Column("usage_type", sa.String(50), nullable=False, server_default="chat"),
+ sa.Column(
+ "thread_id",
+ sa.Integer(),
+ sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
+ nullable=True,
+ ),
+ sa.Column(
+ "message_id",
+ sa.Integer(),
+ sa.ForeignKey("new_chat_messages.id", ondelete="SET NULL"),
+ nullable=True,
+ ),
+ sa.Column(
+ "search_space_id",
+ sa.Integer(),
+ sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
+ nullable=False,
+ ),
+ sa.Column(
+ "user_id",
+ UUID(as_uuid=True),
+ sa.ForeignKey("user.id", ondelete="CASCADE"),
+ nullable=False,
+ ),
+ sa.Column(
+ "created_at",
+ sa.TIMESTAMP(timezone=True),
+ nullable=False,
+ server_default=sa.func.now(),
+ ),
+ )
+
+ op.create_index("ix_token_usage_thread_id", "token_usage", ["thread_id"])
+ op.create_index("ix_token_usage_message_id", "token_usage", ["message_id"])
+ op.create_index(
+ "ix_token_usage_search_space_id", "token_usage", ["search_space_id"]
+ )
+ op.create_index("ix_token_usage_user_id", "token_usage", ["user_id"])
+ op.create_index("ix_token_usage_usage_type", "token_usage", ["usage_type"])
+
+
+def downgrade() -> None:
+ op.drop_table("token_usage")
diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py
index 82d77f847..61bdd65cb 100644
--- a/surfsense_backend/app/db.py
+++ b/surfsense_backend/app/db.py
@@ -647,6 +647,11 @@ class NewChatThread(BaseModel, TimestampMixin):
cascade="all, delete-orphan",
foreign_keys="[PublicChatSnapshot.thread_id]",
)
+ token_usages = relationship(
+ "TokenUsage",
+ back_populates="thread",
+ cascade="all, delete-orphan",
+ )
class NewChatMessage(BaseModel, TimestampMixin):
@@ -685,6 +690,63 @@ class NewChatMessage(BaseModel, TimestampMixin):
back_populates="message",
cascade="all, delete-orphan",
)
+ token_usage = relationship(
+ "TokenUsage",
+ back_populates="message",
+ uselist=False,
+ cascade="all, delete-orphan",
+ )
+
+
+class TokenUsage(BaseModel, TimestampMixin):
+ """
+ Tracks LLM token consumption per assistant turn.
+
+ One row per usage event. For chat, linked to a specific message via message_id.
+ The usage_type column enables future extension to track non-chat usage
+ (indexing, image generation, podcasts, etc.) without schema changes.
+ """
+
+ __tablename__ = "token_usage"
+
+ prompt_tokens = Column(Integer, nullable=False, default=0)
+ completion_tokens = Column(Integer, nullable=False, default=0)
+ total_tokens = Column(Integer, nullable=False, default=0)
+ model_breakdown = Column(JSONB, nullable=True)
+ call_details = Column(JSONB, nullable=True)
+
+ usage_type = Column(String(50), nullable=False, default="chat", index=True)
+
+ thread_id = Column(
+ Integer,
+ ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
+ nullable=True,
+ index=True,
+ )
+ message_id = Column(
+ Integer,
+ ForeignKey("new_chat_messages.id", ondelete="SET NULL"),
+ nullable=True,
+ index=True,
+ )
+ search_space_id = Column(
+ Integer,
+ ForeignKey("searchspaces.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+ user_id = Column(
+ UUID(as_uuid=True),
+ ForeignKey("user.id", ondelete="CASCADE"),
+ nullable=False,
+ index=True,
+ )
+
+ # Relationships
+ thread = relationship("NewChatThread", back_populates="token_usages")
+ message = relationship("NewChatMessage", back_populates="token_usage")
+ search_space = relationship("SearchSpace")
+ user = relationship("User")
class PublicChatSnapshot(BaseModel, TimestampMixin):
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index 10a6951fa..b914b297e 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -50,7 +50,9 @@ from app.schemas.new_chat import (
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
+ TokenUsageSummary,
)
+from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
@@ -473,10 +475,13 @@ async def get_thread_messages(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
- # Get messages with their authors loaded
+ # Get messages with their authors and token usage loaded
messages_result = await session.execute(
select(NewChatMessage)
- .options(selectinload(NewChatMessage.author))
+ .options(
+ selectinload(NewChatMessage.author),
+ selectinload(NewChatMessage.token_usage),
+ )
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
)
@@ -493,6 +498,9 @@ async def get_thread_messages(
author_id=msg.author_id,
author_display_name=msg.author.display_name if msg.author else None,
author_avatar_url=msg.author.avatar_url if msg.author else None,
+ token_usage=TokenUsageSummary.model_validate(msg.token_usage)
+ if msg.token_usage
+ else None,
)
for msg in db_messages
]
@@ -530,7 +538,11 @@ async def get_thread_full(
try:
result = await session.execute(
select(NewChatThread)
- .options(selectinload(NewChatThread.messages))
+ .options(
+ selectinload(NewChatThread.messages).selectinload(
+ NewChatMessage.token_usage
+ ),
+ )
.filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
@@ -935,11 +947,37 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
+
+ # Persist token usage if provided (for assistant messages)
+ token_usage_data = raw_body.get("token_usage")
+ if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
+ await record_token_usage(
+ session,
+ usage_type="chat",
+ search_space_id=thread.search_space_id,
+ user_id=user.id,
+ prompt_tokens=token_usage_data.get("prompt_tokens", 0),
+ completion_tokens=token_usage_data.get("completion_tokens", 0),
+ total_tokens=token_usage_data.get("total_tokens", 0),
+ model_breakdown=token_usage_data.get("usage"),
+ call_details=token_usage_data.get("call_details"),
+ thread_id=thread_id,
+ message_id=db_message.id,
+ )
+
await session.commit()
- # Return the in-memory object (already has id from flush) instead of
- # doing an extra refresh() SELECT.
- return db_message
+ # Build response manually to avoid lazy-loading the token_usage
+ # relationship after commit (which would trigger MissingGreenlet).
+ return NewChatMessageRead(
+ id=db_message.id,
+ thread_id=db_message.thread_id,
+ role=db_message.role,
+ content=db_message.content,
+ created_at=db_message.created_at,
+ author_id=db_message.author_id,
+ token_usage=None,
+ )
except HTTPException:
raise
@@ -1003,6 +1041,7 @@ async def list_messages(
# Get messages
query = (
select(NewChatMessage)
+ .options(selectinload(NewChatMessage.token_usage))
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
.offset(skip)
diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py
index 5d8ae207e..e523657a4 100644
--- a/surfsense_backend/app/schemas/new_chat.py
+++ b/surfsense_backend/app/schemas/new_chat.py
@@ -34,6 +34,14 @@ class NewChatMessageCreate(NewChatMessageBase):
thread_id: int
+class TokenUsageSummary(BaseModel):
+ prompt_tokens: int = 0
+ completion_tokens: int = 0
+ total_tokens: int = 0
+ model_breakdown: dict | None = None
+ model_config = ConfigDict(from_attributes=True)
+
+
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
"""Schema for reading a message."""
@@ -41,6 +49,7 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
author_id: UUID | None = None
author_display_name: str | None = None
author_avatar_url: str | None = None
+ token_usage: TokenUsageSummary | None = None
model_config = ConfigDict(from_attributes=True)
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 63d8d10b9..1bf9e2386 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -820,7 +820,9 @@ class ChatLiteLLMRouter(BaseChatModel):
)
# Convert response to ChatResult with potential tool calls
- message = self._convert_response_to_message(response.choices[0].message)
+ message = self._convert_response_to_message(
+ response.choices[0].message, response=response
+ )
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@@ -886,7 +888,9 @@ class ChatLiteLLMRouter(BaseChatModel):
)
# Convert response to ChatResult with potential tool calls
- message = self._convert_response_to_message(response.choices[0].message)
+ message = self._convert_response_to_message(
+ response.choices[0].message, response=response
+ )
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@@ -970,6 +974,7 @@ class ChatLiteLLMRouter(BaseChatModel):
messages=formatted_messages,
stop=stop,
stream=True,
+ stream_options={"include_usage": True},
**call_kwargs,
)
except ContextWindowExceededError as e:
@@ -1075,7 +1080,9 @@ class ChatLiteLLMRouter(BaseChatModel):
return result
- def _convert_response_to_message(self, response_message: Any) -> AIMessage:
+ def _convert_response_to_message(
+ self, response_message: Any, response: Any = None
+ ) -> AIMessage:
"""Convert a LiteLLM response message to a LangChain AIMessage."""
import json
@@ -1098,9 +1105,22 @@ class ChatLiteLLMRouter(BaseChatModel):
tool_call["args"] = tc.function.arguments
tool_calls.append(tool_call)
+ extra_kwargs: dict[str, Any] = {}
+ if response:
+ usage = getattr(response, "usage", None)
+ if usage:
+ extra_kwargs["usage_metadata"] = {
+ "input_tokens": getattr(usage, "prompt_tokens", 0) or 0,
+ "output_tokens": getattr(usage, "completion_tokens", 0) or 0,
+ "total_tokens": getattr(usage, "total_tokens", 0) or 0,
+ }
+ extra_kwargs["response_metadata"] = {
+ "model_name": getattr(response, "model", "unknown"),
+ }
+
if tool_calls:
- return AIMessage(content=content, tool_calls=tool_calls)
- return AIMessage(content=content)
+ return AIMessage(content=content, tool_calls=tool_calls, **extra_kwargs)
+ return AIMessage(content=content, **extra_kwargs)
def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
"""Convert a streaming delta to an AIMessageChunk."""
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index 723b17607..d31e19ed3 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -15,6 +15,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm,
is_auto_mode,
)
+from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
@@ -22,10 +23,11 @@ litellm.drop_params = True
# Memory controls: prevent unbounded internal accumulation
litellm.telemetry = False
litellm.cache = None
-litellm.success_callback = []
litellm.failure_callback = []
litellm.input_callback = []
+litellm.callbacks = [token_tracker]
+
logger = logging.getLogger(__name__)
diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py
new file mode 100644
index 000000000..9aa8c6e70
--- /dev/null
+++ b/surfsense_backend/app/services/token_tracking_service.py
@@ -0,0 +1,210 @@
+"""
+Token usage tracking via LiteLLM custom callback.
+
+Uses a ContextVar-scoped accumulator to group all LLM calls within a single
+async request/turn. The accumulated data is emitted via SSE and persisted
+when the frontend calls appendMessage.
+
+The module also provides ``record_token_usage``, a thin async helper that
+creates a ``TokenUsage`` row for *any* usage type (chat, indexing, image
+generation, podcasts, …). Call sites should prefer this helper over
+constructing ``TokenUsage`` manually so that logging and error handling
+stay consistent.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import logging
+from contextvars import ContextVar
+from dataclasses import dataclass, field
+from typing import Any
+from uuid import UUID
+
+from litellm.integrations.custom_logger import CustomLogger
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import TokenUsage
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TokenCallRecord:
+ model: str
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+
+
+@dataclass
+class TurnTokenAccumulator:
+ """Accumulates token usage across all LLM calls within a single user turn."""
+
+ calls: list[TokenCallRecord] = field(default_factory=list)
+
+ def add(
+ self,
+ model: str,
+ prompt_tokens: int,
+ completion_tokens: int,
+ total_tokens: int,
+ ) -> None:
+ self.calls.append(
+ TokenCallRecord(
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ )
+ )
+
+ def per_message_summary(self) -> dict[str, dict[str, int]]:
+ """Return token counts grouped by model name."""
+ by_model: dict[str, dict[str, int]] = {}
+ for c in self.calls:
+ entry = by_model.setdefault(
+ c.model,
+ {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+ )
+ entry["prompt_tokens"] += c.prompt_tokens
+ entry["completion_tokens"] += c.completion_tokens
+ entry["total_tokens"] += c.total_tokens
+ return by_model
+
+ @property
+ def grand_total(self) -> int:
+ return sum(c.total_tokens for c in self.calls)
+
+ @property
+ def total_prompt_tokens(self) -> int:
+ return sum(c.prompt_tokens for c in self.calls)
+
+ @property
+ def total_completion_tokens(self) -> int:
+ return sum(c.completion_tokens for c in self.calls)
+
+ def serialized_calls(self) -> list[dict[str, Any]]:
+ return [dataclasses.asdict(c) for c in self.calls]
+
+
+_turn_accumulator: ContextVar[TurnTokenAccumulator | None] = ContextVar(
+ "_turn_accumulator", default=None
+)
+
+
+def start_turn() -> TurnTokenAccumulator:
+ """Create a fresh accumulator for the current async context and return it."""
+ acc = TurnTokenAccumulator()
+ _turn_accumulator.set(acc)
+ logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
+ return acc
+
+
+def get_current_accumulator() -> TurnTokenAccumulator | None:
+ return _turn_accumulator.get()
+
+
+class TokenTrackingCallback(CustomLogger):
+ """LiteLLM callback that captures token usage into the turn accumulator."""
+
+ async def async_log_success_event(
+ self,
+ kwargs: dict[str, Any],
+ response_obj: Any,
+ start_time: Any,
+ end_time: Any,
+ ) -> None:
+ acc = _turn_accumulator.get()
+ if acc is None:
+ logger.debug(
+ "[TokenTracking] async_log_success_event fired but no accumulator in context"
+ )
+ return
+
+ usage = getattr(response_obj, "usage", None)
+ if not usage:
+ logger.debug(
+ "[TokenTracking] async_log_success_event fired but response has no usage data"
+ )
+ return
+
+ prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+ completion_tokens = getattr(usage, "completion_tokens", 0) or 0
+ total_tokens = getattr(usage, "total_tokens", 0) or 0
+
+ model = kwargs.get("model", "unknown")
+
+ acc.add(
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ )
+ logger.info(
+ "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
+ model,
+ prompt_tokens,
+ completion_tokens,
+ total_tokens,
+ len(acc.calls),
+ )
+
+
+token_tracker = TokenTrackingCallback()
+
+
+# ---------------------------------------------------------------------------
+# Persistence helper
+# ---------------------------------------------------------------------------
+
+
+async def record_token_usage(
+ session: AsyncSession,
+ *,
+ usage_type: str,
+ search_space_id: int,
+ user_id: UUID,
+ prompt_tokens: int = 0,
+ completion_tokens: int = 0,
+ total_tokens: int = 0,
+ model_breakdown: dict[str, Any] | None = None,
+ call_details: dict[str, Any] | None = None,
+ thread_id: int | None = None,
+ message_id: int | None = None,
+) -> TokenUsage | None:
+ """Persist a single ``TokenUsage`` row.
+
+ Returns the record on success, ``None`` if persistence failed (the
+ failure is logged but never propagated so callers don't need to
+ wrap this in try/except).
+ """
+ try:
+ record = TokenUsage(
+ usage_type=usage_type,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ model_breakdown=model_breakdown,
+ call_details=call_details,
+ thread_id=thread_id,
+ message_id=message_id,
+ search_space_id=search_space_id,
+ user_id=user_id,
+ )
+ session.add(record)
+ logger.debug(
+ "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
+ usage_type,
+ prompt_tokens,
+ completion_tokens,
+ total_tokens,
+ )
+ return record
+ except Exception:
+ logger.warning(
+ "[TokenTracking] failed to record %s token usage",
+ usage_type,
+ exc_info=True,
+ )
+ return None
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 47a270568..478aa3671 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -51,7 +51,7 @@ from app.db import (
async_session_maker,
shielded_async_session,
)
-from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
+from app.prompts import TITLE_GENERATION_PROMPT
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@@ -1171,6 +1171,10 @@ async def stream_new_chat(
_t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START")
+ from app.services.token_tracking_service import start_turn
+
+ accumulator = start_turn()
+
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
@@ -1456,22 +1460,71 @@ async def stream_new_chat(
)
is_first_response = (assistant_count_result.scalar() or 0) == 0
- title_task: asyncio.Task[str | None] | None = None
+ title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
if is_first_response:
- async def _generate_title() -> str | None:
+ async def _generate_title() -> tuple[str | None, dict | None]:
+ """Generate a short title via litellm.acompletion.
+
+ Returns (title, usage_dict). Usage is extracted directly from
+ the response object because litellm fires its async callback
+ via fire-and-forget ``create_task``, so the
+ ``TokenTrackingCallback`` would run too late. We also blank
+ the accumulator in this child-task context so the late callback
+ doesn't double-count.
+ """
try:
- title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
- title_result = await title_chain.ainvoke(
- {"user_query": user_query[:500]}
+ from litellm import acompletion
+
+ from app.services.llm_router_service import LLMRouterService
+ from app.services.token_tracking_service import _turn_accumulator
+
+ _turn_accumulator.set(None)
+
+ prompt = TITLE_GENERATION_PROMPT.replace(
+ "{user_query}", user_query[:500]
)
- if title_result and hasattr(title_result, "content"):
- raw_title = title_result.content.strip()
- if raw_title and len(raw_title) <= 100:
- return raw_title.strip("\"'")
+ messages = [{"role": "user", "content": prompt}]
+
+ if getattr(llm, "model", None) == "auto":
+ router = LLMRouterService.get_router()
+ response = await router.acompletion(
+ model="auto", messages=messages
+ )
+ else:
+ response = await acompletion(
+ model=llm.model,
+ messages=messages,
+ api_key=getattr(llm, "api_key", None),
+ api_base=getattr(llm, "api_base", None),
+ )
+
+ usage_info = None
+ usage = getattr(response, "usage", None)
+ if usage:
+ raw_model = getattr(llm, "model", "") or ""
+ model_name = (
+ raw_model.split("/", 1)[-1]
+ if "/" in raw_model
+ else (raw_model or response.model or "unknown")
+ )
+ usage_info = {
+ "model": model_name,
+ "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
+ "completion_tokens": getattr(usage, "completion_tokens", 0)
+ or 0,
+ "total_tokens": getattr(usage, "total_tokens", 0) or 0,
+ }
+
+ raw_title = response.choices[0].message.content.strip()
+ if raw_title and len(raw_title) <= 100:
+ return raw_title.strip("\"'"), usage_info
+ return None, usage_info
except Exception:
- pass
- return None
+ logging.getLogger(__name__).exception(
+ "[TitleGen] _generate_title failed"
+ )
+ return None, None
title_task = asyncio.create_task(_generate_title())
@@ -1503,7 +1556,9 @@ async def stream_new_chat(
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
- generated_title = title_task.result()
+ generated_title, title_usage = title_task.result()
+ if title_usage:
+ accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
@@ -1528,6 +1583,26 @@ async def stream_new_chat(
if stream_result.is_interrupted:
if title_task is not None and not title_task.done():
title_task.cancel()
+
+ usage_summary = accumulator.per_message_summary()
+ _perf_log.info(
+ "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
+ len(accumulator.calls),
+ accumulator.grand_total,
+ usage_summary,
+ )
+ if usage_summary:
+ yield streaming_service.format_data(
+ "token-usage",
+ {
+ "usage": usage_summary,
+ "prompt_tokens": accumulator.total_prompt_tokens,
+ "completion_tokens": accumulator.total_completion_tokens,
+ "total_tokens": accumulator.grand_total,
+ "call_details": accumulator.serialized_calls(),
+ },
+ )
+
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@@ -1535,7 +1610,9 @@ async def stream_new_chat(
# If the title task didn't finish during streaming, await it now
if title_task is not None and not title_emitted:
- generated_title = await title_task
+ generated_title, title_usage = await title_task
+ if title_usage:
+ accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
@@ -1549,6 +1626,25 @@ async def stream_new_chat(
chat_id, generated_title
)
+ usage_summary = accumulator.per_message_summary()
+ _perf_log.info(
+ "[token_usage] normal new_chat: calls=%d total=%d summary=%s",
+ len(accumulator.calls),
+ accumulator.grand_total,
+ usage_summary,
+ )
+ if usage_summary:
+ yield streaming_service.format_data(
+ "token-usage",
+ {
+ "usage": usage_summary,
+ "prompt_tokens": accumulator.total_prompt_tokens,
+ "completion_tokens": accumulator.total_completion_tokens,
+ "total_tokens": accumulator.grand_total,
+ "call_details": accumulator.serialized_calls(),
+ },
+ )
+
# Fire background memory extraction if the agent didn't handle it.
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
@@ -1666,6 +1762,10 @@ async def stream_resume_chat(
stream_result = StreamResult()
_t_total = time.perf_counter()
+ from app.services.token_tracking_service import start_turn
+
+ accumulator = start_turn()
+
session = async_session_maker()
try:
if user_id:
@@ -1789,11 +1889,49 @@ async def stream_resume_chat(
chat_id,
)
if stream_result.is_interrupted:
+ usage_summary = accumulator.per_message_summary()
+ _perf_log.info(
+ "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
+ len(accumulator.calls),
+ accumulator.grand_total,
+ usage_summary,
+ )
+ if usage_summary:
+ yield streaming_service.format_data(
+ "token-usage",
+ {
+ "usage": usage_summary,
+ "prompt_tokens": accumulator.total_prompt_tokens,
+ "completion_tokens": accumulator.total_completion_tokens,
+ "total_tokens": accumulator.grand_total,
+ "call_details": accumulator.serialized_calls(),
+ },
+ )
+
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
+ usage_summary = accumulator.per_message_summary()
+ _perf_log.info(
+ "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
+ len(accumulator.calls),
+ accumulator.grand_total,
+ usage_summary,
+ )
+ if usage_summary:
+ yield streaming_service.format_data(
+ "token-usage",
+ {
+ "usage": usage_summary,
+ "prompt_tokens": accumulator.total_prompt_tokens,
+ "completion_tokens": accumulator.total_completion_tokens,
+ "total_tokens": accumulator.grand_total,
+ "call_details": accumulator.serialized_calls(),
+ },
+ )
+
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
diff --git a/surfsense_web/app/(home)/login/LocalLoginForm.tsx b/surfsense_web/app/(home)/login/LocalLoginForm.tsx
index 07a4db4d3..e3c34306f 100644
--- a/surfsense_web/app/(home)/login/LocalLoginForm.tsx
+++ b/surfsense_web/app/(home)/login/LocalLoginForm.tsx
@@ -174,31 +174,31 @@ export function LocalLoginForm() {
-
- setPassword(e.target.value)}
- className={`mt-1 block w-full rounded-md border pr-10 px-3 py-1.5 md:py-2 shadow-sm focus:outline-none focus:ring-1 bg-background text-foreground transition-all ${
- error.title
- ? "border-destructive focus:border-destructive focus:ring-destructive/40"
- : "border-border focus:border-primary focus:ring-primary/40"
- }`}
- disabled={isLoggingIn}
- />
-
-
+
+ setPassword(e.target.value)}
+ className={`block w-full rounded-md border pr-10 px-3 py-1.5 md:py-2 shadow-sm focus:outline-none focus:ring-1 bg-background text-foreground transition-all ${
+ error.title
+ ? "border-destructive focus:border-destructive focus:ring-destructive/40"
+ : "border-border focus:border-primary focus:ring-primary/40"
+ }`}
+ disabled={isLoggingIn}
+ />
+
+