From 5af6005163d1b9b749ae8e32f25c87a0234768e6 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:28:31 +0530 Subject: [PATCH] feat: improve token usage tracking and response handling in chat routes and services --- .../app/routes/new_chat_routes.py | 17 +++++++++++++---- .../app/services/token_tracking_service.py | 7 +++++++ .../app/tasks/chat/stream_new_chat.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index a5245456e..fe79c7c06 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -46,12 +46,12 @@ from app.schemas.new_chat import ( NewChatThreadWithMessages, PublicChatSnapshotCreateResponse, PublicChatSnapshotListResponse, - TokenUsageSummary, RegenerateRequest, ResumeRequest, ThreadHistoryLoadResponse, ThreadListItem, ThreadListResponse, + TokenUsageSummary, ) from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.users import current_active_user @@ -965,9 +965,17 @@ async def append_message( await session.commit() - # Return the in-memory object (already has id from flush) instead of - # doing an extra refresh() SELECT. - return db_message + # Build response manually to avoid lazy-loading the token_usage + # relationship after commit (which would trigger MissingGreenlet). + return NewChatMessageRead( + id=db_message.id, + thread_id=db_message.thread_id, + role=db_message.role, + content=db_message.content, + created_at=db_message.created_at, + author_id=db_message.author_id, + token_usage=None, + ) except HTTPException: raise @@ -1031,6 +1039,7 @@ async def list_messages( # Get messages query = ( select(NewChatMessage) + .options(selectinload(NewChatMessage.token_usage)) .filter(NewChatMessage.thread_id == thread_id) .order_by(NewChatMessage.created_at) .offset(skip) diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py index 434a55ae0..98cb13bb8 100644 --- a/surfsense_backend/app/services/token_tracking_service.py +++ b/surfsense_backend/app/services/token_tracking_service.py @@ -87,6 +87,7 @@ def start_turn() -> TurnTokenAccumulator: """Create a fresh accumulator for the current async context and return it.""" acc = TurnTokenAccumulator() _turn_accumulator.set(acc) + logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc)) return acc @@ -106,10 +107,12 @@ class TokenTrackingCallback(CustomLogger): ) -> None: acc = _turn_accumulator.get() if acc is None: + logger.debug("[TokenTracking] async_log_success_event fired but no accumulator in context") return usage = getattr(response_obj, "usage", None) if not usage: + logger.debug("[TokenTracking] async_log_success_event fired but response has no usage data") return prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 @@ -124,6 +127,10 @@ class TokenTrackingCallback(CustomLogger): completion_tokens=completion_tokens, total_tokens=total_tokens, ) + logger.info( + "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)", + model, prompt_tokens, completion_tokens, total_tokens, len(acc.calls), + ) token_tracker = TokenTrackingCallback() diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 4459b9c06..2002e1585 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1532,7 +1532,12 @@ async def stream_new_chat( if title_task is not None and not title_task.done(): title_task.cancel() + await asyncio.sleep(0.2) usage_summary = accumulator.per_message_summary() + _perf_log.info( + "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s", + len(accumulator.calls), accumulator.grand_total, usage_summary, + ) if usage_summary: yield streaming_service.format_data("token-usage", { "usage": usage_summary, @@ -1563,7 +1568,12 @@ async def stream_new_chat( chat_id, generated_title ) + await asyncio.sleep(0.2) usage_summary = accumulator.per_message_summary() + _perf_log.info( + "[token_usage] normal new_chat: calls=%d total=%d summary=%s", + len(accumulator.calls), accumulator.grand_total, usage_summary, + ) if usage_summary: yield streaming_service.format_data("token-usage", { "usage": usage_summary, @@ -1797,8 +1807,13 @@ async def stream_resume_chat( time.perf_counter() - _t_stream_start, chat_id, ) + await asyncio.sleep(0.2) if stream_result.is_interrupted: usage_summary = accumulator.per_message_summary() + _perf_log.info( + "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s", + len(accumulator.calls), accumulator.grand_total, usage_summary, + ) if usage_summary: yield streaming_service.format_data("token-usage", { "usage": usage_summary, @@ -1814,6 +1829,10 @@ async def stream_resume_chat( return usage_summary = accumulator.per_message_summary() + _perf_log.info( + "[token_usage] normal resume_chat: calls=%d total=%d summary=%s", + len(accumulator.calls), accumulator.grand_total, usage_summary, + ) if usage_summary: yield streaming_service.format_data("token-usage", { "usage": usage_summary,