mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
feat: improve token usage tracking and response handling in chat routes and services
This commit is contained in:
parent
55099a20ac
commit
5af6005163
3 changed files with 39 additions and 4 deletions
|
|
@ -46,12 +46,12 @@ from app.schemas.new_chat import (
|
|||
NewChatThreadWithMessages,
|
||||
PublicChatSnapshotCreateResponse,
|
||||
PublicChatSnapshotListResponse,
|
||||
TokenUsageSummary,
|
||||
RegenerateRequest,
|
||||
ResumeRequest,
|
||||
ThreadHistoryLoadResponse,
|
||||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
TokenUsageSummary,
|
||||
)
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
|
|
@ -965,9 +965,17 @@ async def append_message(
|
|||
|
||||
await session.commit()
|
||||
|
||||
# Return the in-memory object (already has id from flush) instead of
|
||||
# doing an extra refresh() SELECT.
|
||||
return db_message
|
||||
# Build response manually to avoid lazy-loading the token_usage
|
||||
# relationship after commit (which would trigger MissingGreenlet).
|
||||
return NewChatMessageRead(
|
||||
id=db_message.id,
|
||||
thread_id=db_message.thread_id,
|
||||
role=db_message.role,
|
||||
content=db_message.content,
|
||||
created_at=db_message.created_at,
|
||||
author_id=db_message.author_id,
|
||||
token_usage=None,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -1031,6 +1039,7 @@ async def list_messages(
|
|||
# Get messages
|
||||
query = (
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.token_usage))
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
.offset(skip)
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ def start_turn() -> TurnTokenAccumulator:
|
|||
"""Create a fresh accumulator for the current async context and return it."""
|
||||
acc = TurnTokenAccumulator()
|
||||
_turn_accumulator.set(acc)
|
||||
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
|
||||
return acc
|
||||
|
||||
|
||||
|
|
@ -106,10 +107,12 @@ class TokenTrackingCallback(CustomLogger):
|
|||
) -> None:
|
||||
acc = _turn_accumulator.get()
|
||||
if acc is None:
|
||||
logger.debug("[TokenTracking] async_log_success_event fired but no accumulator in context")
|
||||
return
|
||||
|
||||
usage = getattr(response_obj, "usage", None)
|
||||
if not usage:
|
||||
logger.debug("[TokenTracking] async_log_success_event fired but response has no usage data")
|
||||
return
|
||||
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
|
|
@ -124,6 +127,10 @@ class TokenTrackingCallback(CustomLogger):
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
logger.info(
|
||||
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
|
||||
model, prompt_tokens, completion_tokens, total_tokens, len(acc.calls),
|
||||
)
|
||||
|
||||
|
||||
token_tracker = TokenTrackingCallback()
|
||||
|
|
|
|||
|
|
@ -1532,7 +1532,12 @@ async def stream_new_chat(
|
|||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
|
|
@ -1563,7 +1568,12 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
|
|
@ -1797,8 +1807,13 @@ async def stream_resume_chat(
|
|||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
if stream_result.is_interrupted:
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
|
|
@ -1814,6 +1829,10 @@ async def stream_resume_chat(
|
|||
return
|
||||
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue