diff --git a/surfsense_backend/alembic/versions/125_add_token_usage_table.py b/surfsense_backend/alembic/versions/125_add_token_usage_table.py
index c08280487..915561c8c 100644
--- a/surfsense_backend/alembic/versions/125_add_token_usage_table.py
+++ b/surfsense_backend/alembic/versions/125_add_token_usage_table.py
@@ -33,7 +33,9 @@ def upgrade() -> None:
         "token_usage",
         sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
         sa.Column("prompt_tokens", sa.Integer(), nullable=False, server_default="0"),
-        sa.Column("completion_tokens", sa.Integer(), nullable=False, server_default="0"),
+        sa.Column(
+            "completion_tokens", sa.Integer(), nullable=False, server_default="0"
+        ),
         sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
         sa.Column("model_breakdown", JSONB, nullable=True),
         sa.Column("call_details", JSONB, nullable=True),
@@ -72,7 +74,9 @@ def upgrade() -> None:

     op.create_index("ix_token_usage_thread_id", "token_usage", ["thread_id"])
     op.create_index("ix_token_usage_message_id", "token_usage", ["message_id"])
-    op.create_index("ix_token_usage_search_space_id", "token_usage", ["search_space_id"])
+    op.create_index(
+        "ix_token_usage_search_space_id", "token_usage", ["search_space_id"]
+    )
     op.create_index("ix_token_usage_user_id", "token_usage", ["user_id"])
     op.create_index("ix_token_usage_usage_type", "token_usage", ["usage_type"])

diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index 55302b873..b914b297e 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -498,7 +498,9 @@ async def get_thread_messages(
             author_id=msg.author_id,
             author_display_name=msg.author.display_name if msg.author else None,
             author_avatar_url=msg.author.avatar_url if msg.author else None,
-            token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
+            token_usage=TokenUsageSummary.model_validate(msg.token_usage)
+            if msg.token_usage
+            else None,
         )
         for msg in db_messages
     ]
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index c90bdfce3..d31e19ed3 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -15,6 +15,7 @@ from app.services.llm_router_service import (
     get_auto_mode_llm,
     is_auto_mode,
 )
+from app.services.token_tracking_service import token_tracker

 # Configure litellm to automatically drop unsupported parameters
 litellm.drop_params = True
@@ -25,8 +26,6 @@ litellm.cache = None
 litellm.failure_callback = []
 litellm.input_callback = []

-from app.services.token_tracking_service import token_tracker
-
 litellm.callbacks = [token_tracker]

 logger = logging.getLogger(__name__)
diff --git a/surfsense_backend/app/services/token_tracking_service.py b/surfsense_backend/app/services/token_tracking_service.py
index 5d69e6870..9aa8c6e70 100644
--- a/surfsense_backend/app/services/token_tracking_service.py
+++ b/surfsense_backend/app/services/token_tracking_service.py
@@ -117,12 +117,16 @@ class TokenTrackingCallback(CustomLogger):
     ) -> None:
         acc = _turn_accumulator.get()
         if acc is None:
-            logger.debug("[TokenTracking] async_log_success_event fired but no accumulator in context")
+            logger.debug(
+                "[TokenTracking] async_log_success_event fired but no accumulator in context"
+            )
             return

         usage = getattr(response_obj, "usage", None)
         if not usage:
-            logger.debug("[TokenTracking] async_log_success_event fired but response has no usage data")
+            logger.debug(
+                "[TokenTracking] async_log_success_event fired but response has no usage data"
+            )
             return

         prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
@@ -139,7 +143,11 @@ class TokenTrackingCallback(CustomLogger):
         )
         logger.info(
             "[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
-            model, prompt_tokens, completion_tokens, total_tokens, len(acc.calls),
+            model,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
+            len(acc.calls),
         )


@@ -187,11 +195,16 @@ async def record_token_usage(
         session.add(record)
         logger.debug(
             "[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
-            usage_type, prompt_tokens, completion_tokens, total_tokens,
+            usage_type,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
         )
         return record
     except Exception:
         logger.warning(
-            "[TokenTracking] failed to record %s token usage", usage_type, exc_info=True,
+            "[TokenTracking] failed to record %s token usage",
+            usage_type,
+            exc_info=True,
         )
         return None
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index e87a1b791..478aa3671 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -1475,17 +1475,22 @@ async def stream_new_chat(
         """
         try:
             from litellm import acompletion
+
             from app.services.llm_router_service import LLMRouterService
             from app.services.token_tracking_service import _turn_accumulator

             _turn_accumulator.set(None)

-            prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
+            prompt = TITLE_GENERATION_PROMPT.replace(
+                "{user_query}", user_query[:500]
+            )
             messages = [{"role": "user", "content": prompt}]

             if getattr(llm, "model", None) == "auto":
                 router = LLMRouterService.get_router()
-                response = await router.acompletion(model="auto", messages=messages)
+                response = await router.acompletion(
+                    model="auto", messages=messages
+                )
             else:
                 response = await acompletion(
                     model=llm.model,
@@ -1498,11 +1503,16 @@
             usage = getattr(response, "usage", None)
             if usage:
                 raw_model = getattr(llm, "model", "") or ""
-                model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown")
+                model_name = (
+                    raw_model.split("/", 1)[-1]
+                    if "/" in raw_model
+                    else (raw_model or response.model or "unknown")
+                )
                 usage_info = {
                     "model": model_name,
                     "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
-                    "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
+                    "completion_tokens": getattr(usage, "completion_tokens", 0)
+                    or 0,
                     "total_tokens": getattr(usage, "total_tokens", 0) or 0,
                 }

@@ -1511,7 +1521,9 @@
                 return raw_title.strip("\"'"), usage_info
             return None, usage_info
         except Exception:
-            logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
+            logging.getLogger(__name__).exception(
+                "[TitleGen] _generate_title failed"
+            )
             return None, None

     title_task = asyncio.create_task(_generate_title())
@@ -1575,16 +1587,21 @@
             usage_summary = accumulator.per_message_summary()
             _perf_log.info(
                 "[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
-                len(accumulator.calls), accumulator.grand_total, usage_summary,
+                len(accumulator.calls),
+                accumulator.grand_total,
+                usage_summary,
             )
             if usage_summary:
-                yield streaming_service.format_data("token-usage", {
-                    "usage": usage_summary,
-                    "prompt_tokens": accumulator.total_prompt_tokens,
-                    "completion_tokens": accumulator.total_completion_tokens,
-                    "total_tokens": accumulator.grand_total,
-                    "call_details": accumulator.serialized_calls(),
-                })
+                yield streaming_service.format_data(
+                    "token-usage",
+                    {
+                        "usage": usage_summary,
+                        "prompt_tokens": accumulator.total_prompt_tokens,
+                        "completion_tokens": accumulator.total_completion_tokens,
+                        "total_tokens": accumulator.grand_total,
+                        "call_details": accumulator.serialized_calls(),
+                    },
+                )

             yield streaming_service.format_finish_step()
             yield streaming_service.format_finish()
@@ -1612,16 +1629,21 @@
         usage_summary = accumulator.per_message_summary()
         _perf_log.info(
             "[token_usage] normal new_chat: calls=%d total=%d summary=%s",
-            len(accumulator.calls), accumulator.grand_total, usage_summary,
+            len(accumulator.calls),
+            accumulator.grand_total,
+            usage_summary,
         )
         if usage_summary:
-            yield streaming_service.format_data("token-usage", {
-                "usage": usage_summary,
-                "prompt_tokens": accumulator.total_prompt_tokens,
-                "completion_tokens": accumulator.total_completion_tokens,
-                "total_tokens": accumulator.grand_total,
-                "call_details": accumulator.serialized_calls(),
-            })
+            yield streaming_service.format_data(
+                "token-usage",
+                {
+                    "usage": usage_summary,
+                    "prompt_tokens": accumulator.total_prompt_tokens,
+                    "completion_tokens": accumulator.total_completion_tokens,
+                    "total_tokens": accumulator.grand_total,
+                    "call_details": accumulator.serialized_calls(),
+                },
+            )

         # Fire background memory extraction if the agent didn't handle it.
         # Shared threads write to team memory; private threads write to user memory.
@@ -1870,16 +1892,21 @@ async def stream_resume_chat(
             usage_summary = accumulator.per_message_summary()
             _perf_log.info(
                 "[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
-                len(accumulator.calls), accumulator.grand_total, usage_summary,
+                len(accumulator.calls),
+                accumulator.grand_total,
+                usage_summary,
             )
             if usage_summary:
-                yield streaming_service.format_data("token-usage", {
-                    "usage": usage_summary,
-                    "prompt_tokens": accumulator.total_prompt_tokens,
-                    "completion_tokens": accumulator.total_completion_tokens,
-                    "total_tokens": accumulator.grand_total,
-                    "call_details": accumulator.serialized_calls(),
-                })
+                yield streaming_service.format_data(
+                    "token-usage",
+                    {
+                        "usage": usage_summary,
+                        "prompt_tokens": accumulator.total_prompt_tokens,
+                        "completion_tokens": accumulator.total_completion_tokens,
+                        "total_tokens": accumulator.grand_total,
+                        "call_details": accumulator.serialized_calls(),
+                    },
+                )

             yield streaming_service.format_finish_step()
             yield streaming_service.format_finish()
@@ -1889,16 +1916,21 @@
         usage_summary = accumulator.per_message_summary()
         _perf_log.info(
             "[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
-            len(accumulator.calls), accumulator.grand_total, usage_summary,
+            len(accumulator.calls),
+            accumulator.grand_total,
+            usage_summary,
        )
        if usage_summary:
-            yield streaming_service.format_data("token-usage", {
-                "usage": usage_summary,
-                "prompt_tokens": accumulator.total_prompt_tokens,
-                "completion_tokens": accumulator.total_completion_tokens,
-                "total_tokens": accumulator.grand_total,
-                "call_details": accumulator.serialized_calls(),
-            })
+            yield streaming_service.format_data(
+                "token-usage",
+                {
+                    "usage": usage_summary,
+                    "prompt_tokens": accumulator.total_prompt_tokens,
+                    "completion_tokens": accumulator.total_completion_tokens,
+                    "total_tokens": accumulator.grand_total,
+                    "call_details": accumulator.serialized_calls(),
+                },
+            )

        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
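
For context, the files touched above all serve one mechanism: each chat turn parks a fresh accumulator in a `ContextVar`, litellm's success callback appends per-call usage to it, and the stream flushes the totals as a `token-usage` data part. Below is a minimal sketch of that pattern, assuming hypothetical `TokenCall`/`TurnAccumulator` shapes; only `_turn_accumulator`, `token_tracker`, `TokenTrackingCallback`, and litellm's `CustomLogger` hook appear in the diff itself, and the real definitions live in `app/services/token_tracking_service.py`.

```python
from contextvars import ContextVar
from dataclasses import dataclass, field

import litellm
from litellm.integrations.custom_logger import CustomLogger


@dataclass
class TokenCall:
    # Hypothetical per-call record; the real fields are defined in
    # token_tracking_service.py, not shown in this diff.
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class TurnAccumulator:
    calls: list[TokenCall] = field(default_factory=list)

    @property
    def grand_total(self) -> int:
        return sum(call.total_tokens for call in self.calls)


# One accumulator per chat turn; a ContextVar keeps concurrent
# request streams from seeing each other's usage.
_turn_accumulator: ContextVar[TurnAccumulator | None] = ContextVar(
    "_turn_accumulator", default=None
)


class TokenTrackingCallback(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # litellm fires this after every successful completion; calls made
        # outside a tracked turn, or responses without usage data, are skipped.
        acc = _turn_accumulator.get()
        usage = getattr(response_obj, "usage", None)
        if acc is None or not usage:
            return
        acc.calls.append(
            TokenCall(
                model=kwargs.get("model") or "unknown",
                prompt_tokens=getattr(usage, "prompt_tokens", 0) or 0,
                completion_tokens=getattr(usage, "completion_tokens", 0) or 0,
                total_tokens=getattr(usage, "total_tokens", 0) or 0,
            )
        )


token_tracker = TokenTrackingCallback()
litellm.callbacks = [token_tracker]
```

Aside from Black's reflow, the one structural change here is in `llm_service.py`: the `token_tracker` import is hoisted into the module's import block, so the `litellm.callbacks = [token_tracker]` registration now sits below all imports instead of straddling a mid-module import.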