mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-04 20:05:16 +02:00
feat: improve token usage tracking and response handling in chat routes and services
This commit is contained in:
parent
55099a20ac
commit
5af6005163
3 changed files with 39 additions and 4 deletions
|
|
@ -46,12 +46,12 @@ from app.schemas.new_chat import (
|
||||||
NewChatThreadWithMessages,
|
NewChatThreadWithMessages,
|
||||||
PublicChatSnapshotCreateResponse,
|
PublicChatSnapshotCreateResponse,
|
||||||
PublicChatSnapshotListResponse,
|
PublicChatSnapshotListResponse,
|
||||||
TokenUsageSummary,
|
|
||||||
RegenerateRequest,
|
RegenerateRequest,
|
||||||
ResumeRequest,
|
ResumeRequest,
|
||||||
ThreadHistoryLoadResponse,
|
ThreadHistoryLoadResponse,
|
||||||
ThreadListItem,
|
ThreadListItem,
|
||||||
ThreadListResponse,
|
ThreadListResponse,
|
||||||
|
TokenUsageSummary,
|
||||||
)
|
)
|
||||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
|
@ -965,9 +965,17 @@ async def append_message(
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Return the in-memory object (already has id from flush) instead of
|
# Build response manually to avoid lazy-loading the token_usage
|
||||||
# doing an extra refresh() SELECT.
|
# relationship after commit (which would trigger MissingGreenlet).
|
||||||
return db_message
|
return NewChatMessageRead(
|
||||||
|
id=db_message.id,
|
||||||
|
thread_id=db_message.thread_id,
|
||||||
|
role=db_message.role,
|
||||||
|
content=db_message.content,
|
||||||
|
created_at=db_message.created_at,
|
||||||
|
author_id=db_message.author_id,
|
||||||
|
token_usage=None,
|
||||||
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -1031,6 +1039,7 @@ async def list_messages(
|
||||||
# Get messages
|
# Get messages
|
||||||
query = (
|
query = (
|
||||||
select(NewChatMessage)
|
select(NewChatMessage)
|
||||||
|
.options(selectinload(NewChatMessage.token_usage))
|
||||||
.filter(NewChatMessage.thread_id == thread_id)
|
.filter(NewChatMessage.thread_id == thread_id)
|
||||||
.order_by(NewChatMessage.created_at)
|
.order_by(NewChatMessage.created_at)
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,7 @@ def start_turn() -> TurnTokenAccumulator:
|
||||||
"""Create a fresh accumulator for the current async context and return it."""
|
"""Create a fresh accumulator for the current async context and return it."""
|
||||||
acc = TurnTokenAccumulator()
|
acc = TurnTokenAccumulator()
|
||||||
_turn_accumulator.set(acc)
|
_turn_accumulator.set(acc)
|
||||||
|
logger.info("[TokenTracking] start_turn: new accumulator created (id=%s)", id(acc))
|
||||||
return acc
|
return acc
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -106,10 +107,12 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
) -> None:
|
) -> None:
|
||||||
acc = _turn_accumulator.get()
|
acc = _turn_accumulator.get()
|
||||||
if acc is None:
|
if acc is None:
|
||||||
|
logger.debug("[TokenTracking] async_log_success_event fired but no accumulator in context")
|
||||||
return
|
return
|
||||||
|
|
||||||
usage = getattr(response_obj, "usage", None)
|
usage = getattr(response_obj, "usage", None)
|
||||||
if not usage:
|
if not usage:
|
||||||
|
logger.debug("[TokenTracking] async_log_success_event fired but response has no usage data")
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||||
|
|
@ -124,6 +127,10 @@ class TokenTrackingCallback(CustomLogger):
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
|
||||||
|
model, prompt_tokens, completion_tokens, total_tokens, len(acc.calls),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
token_tracker = TokenTrackingCallback()
|
token_tracker = TokenTrackingCallback()
|
||||||
|
|
|
||||||
|
|
@ -1532,7 +1532,12 @@ async def stream_new_chat(
|
||||||
if title_task is not None and not title_task.done():
|
if title_task is not None and not title_task.done():
|
||||||
title_task.cancel()
|
title_task.cancel()
|
||||||
|
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
_perf_log.info(
|
||||||
|
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
||||||
|
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||||
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
yield streaming_service.format_data("token-usage", {
|
yield streaming_service.format_data("token-usage", {
|
||||||
"usage": usage_summary,
|
"usage": usage_summary,
|
||||||
|
|
@ -1563,7 +1568,12 @@ async def stream_new_chat(
|
||||||
chat_id, generated_title
|
chat_id, generated_title
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
_perf_log.info(
|
||||||
|
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||||
|
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||||
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
yield streaming_service.format_data("token-usage", {
|
yield streaming_service.format_data("token-usage", {
|
||||||
"usage": usage_summary,
|
"usage": usage_summary,
|
||||||
|
|
@ -1797,8 +1807,13 @@ async def stream_resume_chat(
|
||||||
time.perf_counter() - _t_stream_start,
|
time.perf_counter() - _t_stream_start,
|
||||||
chat_id,
|
chat_id,
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
_perf_log.info(
|
||||||
|
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
|
||||||
|
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||||
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
yield streaming_service.format_data("token-usage", {
|
yield streaming_service.format_data("token-usage", {
|
||||||
"usage": usage_summary,
|
"usage": usage_summary,
|
||||||
|
|
@ -1814,6 +1829,10 @@ async def stream_resume_chat(
|
||||||
return
|
return
|
||||||
|
|
||||||
usage_summary = accumulator.per_message_summary()
|
usage_summary = accumulator.per_message_summary()
|
||||||
|
_perf_log.info(
|
||||||
|
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
||||||
|
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||||
|
)
|
||||||
if usage_summary:
|
if usage_summary:
|
||||||
yield streaming_service.format_data("token-usage", {
|
yield streaming_service.format_data("token-usage", {
|
||||||
"usage": usage_summary,
|
"usage": usage_summary,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue