mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-10 20:35:17 +02:00
feat: enhance LLM response handling and token usage tracking in chat services and UI components
This commit is contained in:
parent
5510c1de03
commit
f21bdc0668
4 changed files with 67 additions and 37 deletions
|
|
@ -820,7 +820,9 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
message = self._convert_response_to_message(
|
||||
response.choices[0].message, response=response
|
||||
)
|
||||
generation = ChatGeneration(message=message)
|
||||
|
||||
return ChatResult(generations=[generation])
|
||||
|
|
@ -886,7 +888,9 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
message = self._convert_response_to_message(
|
||||
response.choices[0].message, response=response
|
||||
)
|
||||
generation = ChatGeneration(message=message)
|
||||
|
||||
return ChatResult(generations=[generation])
|
||||
|
|
@ -1076,7 +1080,9 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
|
||||
return result
|
||||
|
||||
def _convert_response_to_message(self, response_message: Any) -> AIMessage:
|
||||
def _convert_response_to_message(
|
||||
self, response_message: Any, response: Any = None
|
||||
) -> AIMessage:
|
||||
"""Convert a LiteLLM response message to a LangChain AIMessage."""
|
||||
import json
|
||||
|
||||
|
|
@ -1099,9 +1105,22 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
tool_call["args"] = tc.function.arguments
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if response:
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
extra_kwargs["usage_metadata"] = {
|
||||
"input_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"output_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
extra_kwargs["response_metadata"] = {
|
||||
"model_name": getattr(response, "model", "unknown"),
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
return AIMessage(content=content, tool_calls=tool_calls)
|
||||
return AIMessage(content=content)
|
||||
return AIMessage(content=content, tool_calls=tool_calls, **extra_kwargs)
|
||||
return AIMessage(content=content, **extra_kwargs)
|
||||
|
||||
def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
|
||||
"""Convert a streaming delta to an AIMessageChunk."""
|
||||
|
|
|
|||
|
|
@ -4,6 +4,10 @@ Token usage tracking via LiteLLM custom callback.
|
|||
Uses a ContextVar-scoped accumulator to group all LLM calls within a single
|
||||
async request/turn. The accumulated data is emitted via SSE and persisted
|
||||
when the frontend calls appendMessage.
|
||||
|
||||
Agent LLM calls are captured automatically via the async callback.
|
||||
Title-generation usage is added explicitly from the LangChain response
|
||||
metadata to avoid callback-timing issues.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
|
|||
|
|
@ -1459,22 +1459,35 @@ async def stream_new_chat(
|
|||
)
|
||||
is_first_response = (assistant_count_result.scalar() or 0) == 0
|
||||
|
||||
title_task: asyncio.Task[str | None] | None = None
|
||||
title_task: asyncio.Task[tuple[str | None, dict[str, int] | None]] | None = None
|
||||
if is_first_response:
|
||||
|
||||
async def _generate_title() -> str | None:
|
||||
async def _generate_title() -> tuple[str | None, dict[str, int] | None]:
|
||||
"""Return (title, usage_dict); usage_dict has keys model, prompt_tokens, completion_tokens and total_tokens, or is None when no usage metadata is available."""
|
||||
try:
|
||||
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
|
||||
title_result = await title_chain.ainvoke(
|
||||
{"user_query": user_query[:500]}
|
||||
)
|
||||
if title_result and hasattr(title_result, "content"):
|
||||
raw_title = title_result.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'")
|
||||
usage_dict: dict[str, int] | None = None
|
||||
if title_result:
|
||||
um = getattr(title_result, "usage_metadata", None)
|
||||
if um:
|
||||
rm = getattr(title_result, "response_metadata", None) or {}
|
||||
raw_model = rm.get("model_name", "unknown")
|
||||
usage_dict = {
|
||||
"model": raw_model.split("/", 1)[-1] if "/" in raw_model else raw_model,
|
||||
"prompt_tokens": um.get("input_tokens", 0),
|
||||
"completion_tokens": um.get("output_tokens", 0),
|
||||
"total_tokens": um.get("total_tokens", 0),
|
||||
}
|
||||
if hasattr(title_result, "content"):
|
||||
raw_title = title_result.content.strip()
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
return raw_title.strip("\"'"), usage_dict
|
||||
return None, usage_dict
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
return None, None
|
||||
|
||||
title_task = asyncio.create_task(_generate_title())
|
||||
|
||||
|
|
@ -1506,7 +1519,7 @@ async def stream_new_chat(
|
|||
|
||||
# Inject title update mid-stream as soon as the background task finishes
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
generated_title = title_task.result()
|
||||
generated_title, _title_usage = title_task.result()
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
|
|
@ -1532,7 +1545,6 @@ async def stream_new_chat(
|
|||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
||||
|
|
@ -1554,7 +1566,7 @@ async def stream_new_chat(
|
|||
|
||||
# If the title task didn't finish during streaming, await it now
|
||||
if title_task is not None and not title_emitted:
|
||||
generated_title = await title_task
|
||||
generated_title, _title_usage = await title_task
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
|
|
@ -1568,7 +1580,6 @@ async def stream_new_chat(
|
|||
chat_id, generated_title
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||
|
|
@ -1807,7 +1818,6 @@ async def stream_resume_chat(
|
|||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
if stream_result.is_interrupted:
|
||||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import {
|
||||
ActionBarMorePrimitive,
|
||||
ActionBarPrimitive,
|
||||
AuiIf,
|
||||
ErrorPrimitive,
|
||||
|
|
@ -40,14 +41,7 @@ import {
|
|||
DrawerHeader,
|
||||
DrawerTitle,
|
||||
} from "@/components/ui/drawer";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { DropdownMenuLabel } from "@/components/ui/dropdown-menu";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useComments } from "@/hooks/use-comments";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
|
|
@ -397,14 +391,17 @@ const MessageInfoDropdown: FC = () => {
|
|||
const hasUsage = usage && usage.total_tokens > 0;
|
||||
|
||||
return (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<ActionBarMorePrimitive.Root>
|
||||
<ActionBarMorePrimitive.Trigger asChild>
|
||||
<Button variant="ghost" size="icon" className="aui-button-icon size-6 p-1">
|
||||
<MoreHorizontalIcon className="size-4" />
|
||||
<span className="sr-only">More</span>
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="start" className="min-w-[180px]">
|
||||
</ActionBarMorePrimitive.Trigger>
|
||||
<ActionBarMorePrimitive.Content
|
||||
align="start"
|
||||
className="bg-muted text-popover-foreground z-50 max-h-(--radix-dropdown-menu-content-available-height) min-w-[180px] origin-(--radix-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border dark:border-neutral-700 p-1 shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2"
|
||||
>
|
||||
{createdAt && (
|
||||
<DropdownMenuLabel className="text-xs text-muted-foreground font-normal select-none">
|
||||
{formatMessageDate(createdAt)}
|
||||
|
|
@ -412,27 +409,27 @@ const MessageInfoDropdown: FC = () => {
|
|||
)}
|
||||
{hasUsage && (
|
||||
<>
|
||||
<DropdownMenuSeparator />
|
||||
<ActionBarMorePrimitive.Separator className="bg-border mx-2 my-1 h-px" />
|
||||
{models.length > 0 ? (
|
||||
models.map(([model, counts]) => (
|
||||
<DropdownMenuItem key={model} className="flex-col items-start gap-0.5 cursor-default" onSelect={(e) => e.preventDefault()}>
|
||||
<ActionBarMorePrimitive.Item key={model} className="focus:bg-neutral-200 dark:focus:bg-neutral-700 relative flex cursor-default flex-col items-start gap-0.5 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none" onSelect={(e) => e.preventDefault()}>
|
||||
<span className="text-xs font-medium">{model}</span>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{counts.total_tokens.toLocaleString()} tokens
|
||||
</span>
|
||||
</DropdownMenuItem>
|
||||
</ActionBarMorePrimitive.Item>
|
||||
))
|
||||
) : (
|
||||
<DropdownMenuItem className="flex-col items-start gap-0.5 cursor-default" onSelect={(e) => e.preventDefault()}>
|
||||
<ActionBarMorePrimitive.Item className="focus:bg-neutral-200 dark:focus:bg-neutral-700 relative flex cursor-default flex-col items-start gap-0.5 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none" onSelect={(e) => e.preventDefault()}>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{usage.total_tokens.toLocaleString()} tokens
|
||||
</span>
|
||||
</DropdownMenuItem>
|
||||
</ActionBarMorePrimitive.Item>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</ActionBarMorePrimitive.Content>
|
||||
</ActionBarMorePrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue