mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
chore: ran linting
This commit is contained in:
parent
a74ed014cc
commit
9fc0976d5e
5 changed files with 97 additions and 47 deletions
|
|
@ -33,7 +33,9 @@ def upgrade() -> None:
|
|||
"token_usage",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("prompt_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("completion_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"completion_tokens", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("model_breakdown", JSONB, nullable=True),
|
||||
sa.Column("call_details", JSONB, nullable=True),
|
||||
|
|
@ -72,7 +74,9 @@ def upgrade() -> None:
|
|||
|
||||
op.create_index("ix_token_usage_thread_id", "token_usage", ["thread_id"])
|
||||
op.create_index("ix_token_usage_message_id", "token_usage", ["message_id"])
|
||||
op.create_index("ix_token_usage_search_space_id", "token_usage", ["search_space_id"])
|
||||
op.create_index(
|
||||
"ix_token_usage_search_space_id", "token_usage", ["search_space_id"]
|
||||
)
|
||||
op.create_index("ix_token_usage_user_id", "token_usage", ["user_id"])
|
||||
op.create_index("ix_token_usage_usage_type", "token_usage", ["usage_type"])
|
||||
|
||||
|
|
|
|||
|
|
@ -498,7 +498,9 @@ async def get_thread_messages(
|
|||
author_id=msg.author_id,
|
||||
author_display_name=msg.author.display_name if msg.author else None,
|
||||
author_avatar_url=msg.author.avatar_url if msg.author else None,
|
||||
token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None,
|
||||
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
|
||||
if msg.token_usage
|
||||
else None,
|
||||
)
|
||||
for msg in db_messages
|
||||
]
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from app.services.llm_router_service import (
|
|||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
litellm.drop_params = True
|
||||
|
|
@ -25,8 +26,6 @@ litellm.cache = None
|
|||
litellm.failure_callback = []
|
||||
litellm.input_callback = []
|
||||
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
litellm.callbacks = [token_tracker]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -117,12 +117,16 @@ class TokenTrackingCallback(CustomLogger):
|
|||
) -> None:
|
||||
acc = _turn_accumulator.get()
|
||||
if acc is None:
|
||||
logger.debug("[TokenTracking] async_log_success_event fired but no accumulator in context")
|
||||
logger.debug(
|
||||
"[TokenTracking] async_log_success_event fired but no accumulator in context"
|
||||
)
|
||||
return
|
||||
|
||||
usage = getattr(response_obj, "usage", None)
|
||||
if not usage:
|
||||
logger.debug("[TokenTracking] async_log_success_event fired but response has no usage data")
|
||||
logger.debug(
|
||||
"[TokenTracking] async_log_success_event fired but response has no usage data"
|
||||
)
|
||||
return
|
||||
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
|
|
@ -139,7 +143,11 @@ class TokenTrackingCallback(CustomLogger):
|
|||
)
|
||||
logger.info(
|
||||
"[TokenTracking] Captured: model=%s prompt=%d completion=%d total=%d (accumulator now has %d calls)",
|
||||
model, prompt_tokens, completion_tokens, total_tokens, len(acc.calls),
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
len(acc.calls),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -187,11 +195,16 @@ async def record_token_usage(
|
|||
session.add(record)
|
||||
logger.debug(
|
||||
"[TokenTracking] recorded %s usage: prompt=%d completion=%d total=%d",
|
||||
usage_type, prompt_tokens, completion_tokens, total_tokens,
|
||||
usage_type,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
)
|
||||
return record
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[TokenTracking] failed to record %s token usage", usage_type, exc_info=True,
|
||||
"[TokenTracking] failed to record %s token usage",
|
||||
usage_type,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1475,17 +1475,22 @@ async def stream_new_chat(
|
|||
"""
|
||||
try:
|
||||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
prompt = TITLE_GENERATION_PROMPT.replace("{user_query}", user_query[:500])
|
||||
prompt = TITLE_GENERATION_PROMPT.replace(
|
||||
"{user_query}", user_query[:500]
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if getattr(llm, "model", None) == "auto":
|
||||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
response = await router.acompletion(
|
||||
model="auto", messages=messages
|
||||
)
|
||||
else:
|
||||
response = await acompletion(
|
||||
model=llm.model,
|
||||
|
|
@ -1498,11 +1503,16 @@ async def stream_new_chat(
|
|||
usage = getattr(response, "usage", None)
|
||||
if usage:
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
model_name = raw_model.split("/", 1)[-1] if "/" in raw_model else (raw_model or response.model or "unknown")
|
||||
model_name = (
|
||||
raw_model.split("/", 1)[-1]
|
||||
if "/" in raw_model
|
||||
else (raw_model or response.model or "unknown")
|
||||
)
|
||||
usage_info = {
|
||||
"model": model_name,
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0)
|
||||
or 0,
|
||||
"total_tokens": getattr(usage, "total_tokens", 0) or 0,
|
||||
}
|
||||
|
||||
|
|
@ -1511,7 +1521,9 @@ async def stream_new_chat(
|
|||
return raw_title.strip("\"'"), usage_info
|
||||
return None, usage_info
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("[TitleGen] _generate_title failed")
|
||||
logging.getLogger(__name__).exception(
|
||||
"[TitleGen] _generate_title failed"
|
||||
)
|
||||
return None, None
|
||||
|
||||
title_task = asyncio.create_task(_generate_title())
|
||||
|
|
@ -1575,16 +1587,21 @@ async def stream_new_chat(
|
|||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted new_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
yield streaming_service.format_data(
|
||||
"token-usage",
|
||||
{
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1612,16 +1629,21 @@ async def stream_new_chat(
|
|||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal new_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
yield streaming_service.format_data(
|
||||
"token-usage",
|
||||
{
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
||||
# Fire background memory extraction if the agent didn't handle it.
|
||||
# Shared threads write to team memory; private threads write to user memory.
|
||||
|
|
@ -1870,16 +1892,21 @@ async def stream_resume_chat(
|
|||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] interrupted resume_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
yield streaming_service.format_data(
|
||||
"token-usage",
|
||||
{
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
@ -1889,16 +1916,21 @@ async def stream_resume_chat(
|
|||
usage_summary = accumulator.per_message_summary()
|
||||
_perf_log.info(
|
||||
"[token_usage] normal resume_chat: calls=%d total=%d summary=%s",
|
||||
len(accumulator.calls), accumulator.grand_total, usage_summary,
|
||||
len(accumulator.calls),
|
||||
accumulator.grand_total,
|
||||
usage_summary,
|
||||
)
|
||||
if usage_summary:
|
||||
yield streaming_service.format_data("token-usage", {
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
})
|
||||
yield streaming_service.format_data(
|
||||
"token-usage",
|
||||
{
|
||||
"usage": usage_summary,
|
||||
"prompt_tokens": accumulator.total_prompt_tokens,
|
||||
"completion_tokens": accumulator.total_completion_tokens,
|
||||
"total_tokens": accumulator.grand_total,
|
||||
"call_details": accumulator.serialized_calls(),
|
||||
},
|
||||
)
|
||||
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue