mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat: implement token usage recording in chat routes and enhance title generation handling
This commit is contained in:
parent
292fcb1a2c
commit
f01ddf3f0a
3 changed files with 121 additions and 36 deletions
|
|
@ -51,7 +51,7 @@ from app.db import (
|
|||
async_session_maker,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -1460,34 +1460,58 @@ async def stream_new_chat(
|
|||
)
|
||||
is_first_response = (assistant_count_result.scalar() or 0) == 0
|
||||
|
||||
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
|
||||
if is_first_response:
|
||||
|
||||
async def _generate_title() -> tuple[str | None, dict[str, int] | None]:
|
||||
"""Return (title, usage_dict) where usage_dict has model/prompt/completion/total."""
|
||||
async def _generate_title() -> tuple[str | None, dict | None]:
|
||||
"""Generate a short title via litellm.acompletion.
|
||||
|
||||
Returns (title, usage_dict). Usage is extracted directly from
|
||||
the response object because litellm fires its async callback
|
||||
via fire-and-forget ``create_task``, so the
|
||||
``TokenTrackingCallback`` would run too late. We also blank
|
||||
the accumulator in this child-task context so the late callback
|
||||
doesn't double-count.
|
||||
"""
|
||||
try:
|
||||
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
|
||||
title_result = await title_chain.ainvoke(
|
||||
{"user_query": user_query[:500]}
|
||||
)
|
||||
usage_dict: dict[str, int] | None = None
|
||||
if title_result:
|
||||
um = getattr(title_result, "usage_metadata", None)
|
||||
if um:
|
||||
rm = getattr(title_result, "response_metadata", None) or {}
|
||||
raw_model = rm.get("model_name", "unknown")
|
||||
usage_dict = {
|
||||
"model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model,
|
||||
"prompt_tokens": um.get("input_tokens", 0),
|
||||
"completion_tokens": um.get("output_tokens", 0),
|
||||
"total_tokens": um.get("total_tokens", 0),
|
||||
}
|
||||
if hasattr(title_result, "content"):
|
||||
raw_title = title_result.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_dict
|
||||
return None, usage_dict
|
||||
from litellm import acompletion
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if getattr(llm, "model", None) == "auto":
|
||||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
response = await acompletion(
|
||||
model=llm.model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown")
|
||||
usage_info = {
|
||||
"model": model_name,
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
|
||||
raw_title = response.choices[0].message.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_info
|
||||
return None, usage_info
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
|
||||
return None, None
|
||||
|
||||
title_task = asyncio.create_task(_generate_title())
|
||||
|
|
@ -1520,7 +1544,9 @@ async def stream_new_chat(
|
|||
|
||||
# Inject title update mid-stream as soon as the background task finishes
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
generated_title, _title_usage = title_task.result()
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
|
|
@ -1567,7 +1593,9 @@ async def stream_new_chat(
|
|||
|
||||
# If the title task didn't finish during streaming, await it now
|
||||
if title_task is not None and not title_emitted:
|
||||
generated_title, _title_usage = await title_task
|
||||
generated_title, title_usage = await title_task
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue