mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 20:03:30 +02:00
feat: implement token usage tracking for LLM calls with new accumulator and callback
This commit is contained in:
parent
917f35eb33
commit
3cfe53fb7f
6 changed files with 223 additions and 4 deletions
|
|
@ -30,6 +30,7 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
TokenUsage,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
|
|
@ -45,6 +46,7 @@ from app.schemas.new_chat import (
|
|||
NewChatThreadWithMessages,
|
||||
PublicChatSnapshotCreateResponse,
|
||||
PublicChatSnapshotListResponse,
|
||||
TokenUsageSummary,
|
||||
RegenerateRequest,
|
||||
ResumeRequest,
|
||||
ThreadHistoryLoadResponse,
|
||||
|
|
@ -473,10 +475,13 @@ async def get_thread_messages(
|
|||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
# Get messages with their authors loaded
|
||||
# Get messages with their authors and token usage loaded
|
||||
messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.author))
|
||||
.options(
|
||||
selectinload(NewChatMessage.author),
|
||||
selectinload(NewChatMessage.token_usage),
|
||||
)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
)
|
||||
|
|
@ -493,6 +498,7 @@ async def get_thread_messages(
|
|||
author_id=msg.author_id,
|
||||
author_display_name=msg.author.display_name if msg.author else None,
|
||||
author_avatar_url=msg.author.avatar_url if msg.author else None,
|
||||
token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
|
||||
)
|
||||
for msg in db_messages
|
||||
]
|
||||
|
|
@ -530,7 +536,11 @@ async def get_thread_full(
|
|||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.options(
|
||||
selectinload(NewChatThread.messages).selectinload(
|
||||
NewChatMessage.token_usage
|
||||
),
|
||||
)
|
||||
.filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
|
@ -935,6 +945,24 @@ async def append_message(
|
|||
|
||||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
|
||||
# Persist token usage if provided (for assistant messages)
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
token_usage_record = TokenUsage(
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
usage_type="chat",
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(token_usage_record)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Return the in-memory object (already has id from flush) instead of
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue