mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
fix(chat): enforce pinned model quota flow and reset stale pins
This commit is contained in:
parent
41849fe10f
commit
835bd9f65d
2 changed files with 88 additions and 44 deletions
|
|
@ -56,6 +56,7 @@ from app.db import (
|
|||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -1456,6 +1457,21 @@ async def stream_new_chat(
|
|||
agent_config: AgentConfig | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield streaming_service.format_error(str(pin_error))
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if llm_config_id >= 0:
|
||||
# Positive ID: Load from NewLLMConfig database table
|
||||
agent_config = await load_agent_config(
|
||||
|
|
@ -1491,12 +1507,11 @@ async def stream_new_chat(
|
|||
llm_config_id,
|
||||
)
|
||||
|
||||
# Premium quota reservation — applies to explicitly premium configs
|
||||
# AND Auto mode (which may route to premium models).
|
||||
# Premium quota reservation for pinned premium model only.
|
||||
_needs_premium_quota = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
and agent_config.is_premium
|
||||
)
|
||||
if _needs_premium_quota:
|
||||
import uuid as _uuid
|
||||
|
|
@ -1519,16 +1534,18 @@ async def stream_new_chat(
|
|||
)
|
||||
_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
# Auto mode: quota exhausted but we can still proceed
|
||||
# (the router may pick a free model). Reset reservation.
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
logging.getLogger(__name__).info(
|
||||
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
|
||||
chat_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -1961,28 +1978,20 @@ async def stream_new_chat(
|
|||
)
|
||||
|
||||
# Finalize premium quota with actual tokens.
|
||||
# For Auto mode, only count tokens from calls that used premium models.
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_premium_reserved,
|
||||
)
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
|
|
@ -2175,6 +2184,21 @@ async def stream_resume_chat(
|
|||
|
||||
agent_config: AgentConfig | None = None
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield streaming_service.format_error(str(pin_error))
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if llm_config_id >= 0:
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
|
|
@ -2208,7 +2232,7 @@ async def stream_resume_chat(
|
|||
_resume_needs_premium = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
and agent_config.is_premium
|
||||
)
|
||||
if _resume_needs_premium:
|
||||
import uuid as _uuid
|
||||
|
|
@ -2231,14 +2255,18 @@ async def stream_resume_chat(
|
|||
)
|
||||
_resume_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
logging.getLogger(__name__).info(
|
||||
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
|
||||
chat_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -2370,23 +2398,16 @@ async def stream_resume_chat(
|
|||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
)
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s (resume)",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue