diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index d97665f7a..1bf9e2386 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -820,7 +820,9 @@ class ChatLiteLLMRouter(BaseChatModel):
         )
 
         # Convert response to ChatResult with potential tool calls
-        message = self._convert_response_to_message(response.choices[0].message)
+        message = self._convert_response_to_message(
+            response.choices[0].message, response=response
+        )
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
 
@@ -886,7 +888,9 @@ class ChatLiteLLMRouter(BaseChatModel):
         )
 
         # Convert response to ChatResult with potential tool calls
-        message = self._convert_response_to_message(response.choices[0].message)
+        message = self._convert_response_to_message(
+            response.choices[0].message, response=response
+        )
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
 
@@ -1076,7 +1080,9 @@ class ChatLiteLLMRouter(BaseChatModel):
 
         return result
 
-    def _convert_response_to_message(self, response_message: Any) -> AIMessage:
+    def _convert_response_to_message(
+        self, response_message: Any, response: Any = None
+    ) -> AIMessage:
         """Convert a LiteLLM response message to a LangChain AIMessage."""
         import json
 
@@ -1099,9 +1105,22 @@ class ChatLiteLLMRouter(BaseChatModel):
                     tool_call["args"] = tc.function.arguments
                 tool_calls.append(tool_call)
 
+        extra_kwargs: dict[str, Any] = {}
+        if response:
+            usage = getattr(response, "usage", None)
+            if usage:
+                extra_kwargs["usage_metadata"] = {
+                    "input_tokens": getattr(usage, "prompt_tokens", 0) or 0,
+                    "output_tokens": getattr(usage, "completion_tokens", 0) or 0,
+                    "total_tokens": getattr(usage, "total_tokens", 0) or 0,
+                }
+            extra_kwargs["response_metadata"] = {
+                "model_name": getattr(response, "model", "unknown"),
+            }
+
         if tool_calls:
-            return AIMessage(content=content, tool_calls=tool_calls)
-        return AIMessage(content=content)
+            return AIMessage(content=content, tool_calls=tool_calls, **extra_kwargs)
+        return AIMessage(content=content, **extra_kwargs)
 
     def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
         """Convert a streaming delta to an AIMessageChunk."""
diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py
index 98cb13bb8..6a5b3793f 100644
--- a/surfsense_backend/app/services/token_tracking_service.py
+++ b/surfsense_backend/app/services/token_tracking_service.py
@@ -4,6 +4,10 @@ Token usage tracking via LiteLLM custom callback.
 Uses a ContextVar-scoped accumulator to group all LLM calls within a
 single async request/turn. The accumulated data is emitted via SSE and
 persisted when the frontend calls appendMessage.
+
+Agent LLM calls are captured automatically via the async callback.
+Title-generation usage is added explicitly from the LangChain response
+metadata to avoid callback-timing issues.
""" from __future__ import annotations diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2002e1585..364a14bad 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1459,22 +1459,35 @@ async def stream_new_chat( ) is_first_response = (assistant_count_result.scalar() or 0) == 0 - title_task: asyncio.Task[str | None] | None = None + title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None if is_first_response: - async def _generate_title() -> str | None: + async def _generate_title() -> tuple[str | None, dict[str, int] | None]: + """Return (title, usage_dict) where usage_dict has model/prompt/completion/total.""" try: title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm title_result = await title_chain.ainvoke( {"user_query": user_query[:500]} ) - if title_result and hasattr(title_result, "content"): - raw_title = title_result.content.strip() - if raw_title and len(raw_title) <= 100: - return raw_title.strip("\"'") + usage_dict: dict[str, int] | None = None + if title_result: + um = getattr(title_result, "usage_metadata", None) + if um: + rm = getattr(title_result, "response_metadata", None) or {} + raw_model = rm.get("model_name", "unknown") + usage_dict = { + "model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model, + "prompt_tokens": um.get("input_tokens", 0), + "completion_tokens": um.get("output_tokens", 0), + "total_tokens": um.get("total_tokens", 0), + } + if hasattr(title_result, "content"): + raw_title = title_result.content.strip() + if raw_title and len(raw_title) <= 100: + return raw_title.strip("\"'"), usage_dict + return None, usage_dict except Exception: - pass - return None + return None, None title_task = asyncio.create_task(_generate_title()) @@ -1506,7 +1519,7 @@ async def stream_new_chat( # Inject title update mid-stream as soon as the background task finishes if title_task is not None and title_task.done() and not title_emitted: - generated_title = title_task.result() + generated_title, _title_usage = title_task.result() if generated_title: async with shielded_async_session() as title_session: title_thread_result = await title_session.execute( @@ -1532,7 +1545,6 @@ async def stream_new_chat( if title_task is not None and not title_task.done(): title_task.cancel() - await asyncio.sleep(0.2) usage_summary = accumulator.per_message_summary() _perf_log.info( "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", @@ -1554,7 +1566,7 @@ async def stream_new_chat( # If the title task didn't finish during streaming, await it now if title_task is not None and not title_emitted: - generated_title = await title_task + generated_title, _title_usage = await title_task if generated_title: async with shielded_async_session() as title_session: title_thread_result = await title_session.execute( @@ -1568,7 +1580,6 @@ async def stream_new_chat( chat_id, generated_title ) - await asyncio.sleep(0.2) usage_summary = accumulator.per_message_summary() _perf_log.info( "[token_usage] normal new_chat: calls=%d total=%d summary=%s", @@ -1807,7 +1818,6 @@ async def stream_resume_chat( time.perf_counter() - _t_stream_start, chat_id, ) - await asyncio.sleep(0.2) if stream_result.is_interrupted: usage_summary = accumulator.per_message_summary() _perf_log.info( diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 
--- a/surfsense_web/components/assistant-ui/assistant-message.tsx
+++ b/surfsense_web/components/assistant-ui/assistant-message.tsx
@@ -1,4 +1,5 @@
 import {
+	ActionBarMorePrimitive,
 	ActionBarPrimitive,
 	AuiIf,
 	ErrorPrimitive,
@@ -40,14 +41,7 @@ import {
 	DrawerHeader,
 	DrawerTitle,
 } from "@/components/ui/drawer";
-import {
-	DropdownMenu,
-	DropdownMenuContent,
-	DropdownMenuItem,
-	DropdownMenuLabel,
-	DropdownMenuSeparator,
-	DropdownMenuTrigger,
-} from "@/components/ui/dropdown-menu";
+import { DropdownMenuLabel } from "@/components/ui/dropdown-menu";
 import { Button } from "@/components/ui/button";
 import { useComments } from "@/hooks/use-comments";
 import { useMediaQuery } from "@/hooks/use-media-query";
@@ -397,14 +391,17 @@ const MessageInfoDropdown: FC = () => {
 	const hasUsage = usage && usage.total_tokens > 0;
 
 	return (
-		<DropdownMenu>
-			<DropdownMenuTrigger asChild>
+		<ActionBarMorePrimitive.Root>
+			<ActionBarMorePrimitive.Trigger asChild>
 				<TooltipIconButton tooltip="Message info">
 					<InfoIcon />
 				</TooltipIconButton>
-			</DropdownMenuTrigger>
-			<DropdownMenuContent>
+			</ActionBarMorePrimitive.Trigger>
+			<ActionBarMorePrimitive.Content>
 				{createdAt && (
 					<DropdownMenuLabel>
 						{formatMessageDate(createdAt)}
 					</DropdownMenuLabel>
@@ -412,27 +409,27 @@ const MessageInfoDropdown: FC = () => {
 				)}
 				{hasUsage && (
 					<>
-						<DropdownMenuSeparator />
+						<ActionBarMorePrimitive.Separator />
 						{models.length > 0 ? (
 							models.map(([model, counts]) => (
-								<DropdownMenuItem key={model} onSelect={(e) => e.preventDefault()}>
+								<ActionBarMorePrimitive.Item key={model} onSelect={(e) => e.preventDefault()}>
 									<span>{model}</span>
 									<span>{counts.total_tokens.toLocaleString()} tokens</span>
-								</DropdownMenuItem>
+								</ActionBarMorePrimitive.Item>
 							))
 						) : (
-							<DropdownMenuItem onSelect={(e) => e.preventDefault()}>
+							<ActionBarMorePrimitive.Item onSelect={(e) => e.preventDefault()}>
 								<span>{usage.total_tokens.toLocaleString()} tokens</span>
-							</DropdownMenuItem>
+							</ActionBarMorePrimitive.Item>
 						)}
 					</>
 				)}
-			</DropdownMenuContent>
-		</DropdownMenu>
+			</ActionBarMorePrimitive.Content>
+		</ActionBarMorePrimitive.Root>
 	);
 };
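
Reviewer note: a minimal sketch, not part of the patch, of the metadata contract the two backend changes share. The router side attaches LangChain's standard `usage_metadata` and `response_metadata` fields to the `AIMessage` it builds; the title task reads them back off the chain result. It assumes `langchain-core`'s `AIMessage`; the model string and token counts below are illustrative, not taken from real traffic.

```python
from langchain_core.messages import AIMessage

# What _convert_response_to_message now attaches when the raw LiteLLM
# response carries a usage block (values here are made up):
message = AIMessage(
    content="Generated title",
    usage_metadata={
        "input_tokens": 42,
        "output_tokens": 7,
        "total_tokens": 49,
    },
    response_metadata={"model_name": "openai/gpt-4o"},
)

# What _generate_title reads back from the chain result:
um = getattr(message, "usage_metadata", None)
if um:
    rm = getattr(message, "response_metadata", None) or {}
    raw_model = rm.get("model_name", "unknown")
    usage_dict = {
        # Strip the provider prefix: "openai/gpt-4o" -> "gpt-4o".
        "model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model,
        "prompt_tokens": um.get("input_tokens", 0),
        "completion_tokens": um.get("output_tokens", 0),
        "total_tokens": um.get("total_tokens", 0),
    }
    print(usage_dict)
    # {'model': 'gpt-4o', 'prompt_tokens': 42, 'completion_tokens': 7,
    #  'total_tokens': 49}
```

Reading the usage off the returned message rather than through the LiteLLM callback is what lets the title task avoid the callback-timing issue the token_tracking_service docstring mentions: the data is already on the object the chain returns, so no ContextVar accumulation is needed for this one call.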