feat: implement token usage tracking for LLM calls with new accumulator and callback

This commit is contained in:
Anish Sarkar 2026-04-14 13:40:32 +05:30
parent 917f35eb33
commit 3cfe53fb7f
6 changed files with 223 additions and 4 deletions

View file

@@ -30,6 +30,7 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
TokenUsage,
User,
get_async_session,
shielded_async_session,
@@ -45,6 +46,7 @@ from app.schemas.new_chat import (
NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
TokenUsageSummary,
RegenerateRequest,
ResumeRequest,
ThreadHistoryLoadResponse,
@@ -473,10 +475,13 @@ async def get_thread_messages(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
# Get messages with their authors loaded
# Get messages with their authors and token usage loaded
messages_result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.author))
.options(
selectinload(NewChatMessage.author),
selectinload(NewChatMessage.token_usage),
)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
)
@@ -493,6 +498,7 @@ async def get_thread_messages(
author_id=msg.author_id,
author_display_name=msg.author.display_name if msg.author else None,
author_avatar_url=msg.author.avatar_url if msg.author else None,
token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
)
for msg in db_messages
]
@@ -530,7 +536,11 @@ async def get_thread_full(
try:
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.options(
selectinload(NewChatThread.messages).selectinload(
NewChatMessage.token_usage
),
)
.filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
@@ -935,6 +945,24 @@ async def append_message(
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Persist token usage if provided (for assistant messages)
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
token_usage_record = TokenUsage(
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
usage_type="chat",
thread_id=thread_id,
message_id=db_message.id,
search_space_id=thread.search_space_id,
user_id=user.id,
)
session.add(token_usage_record)
await session.commit()
# Return the in-memory object (already has id from flush) instead of