mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 12:52:39 +02:00
fix(chat): enforce pinned model quota flow and reset stale pins
This commit is contained in:
parent
41849fe10f
commit
835bd9f65d
2 changed files with 88 additions and 44 deletions
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
|
|||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGenerationConfig,
|
||||
NewChatThread,
|
||||
NewLLMConfig,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
|
|
@ -790,9 +791,31 @@ async def update_llm_preferences(
|
|||
|
||||
# Update preferences
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
previous_agent_llm_id = search_space.agent_llm_id
|
||||
for key, value in update_data.items():
|
||||
setattr(search_space, key, value)
|
||||
|
||||
agent_llm_changed = (
|
||||
"agent_llm_id" in update_data
|
||||
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||
)
|
||||
if agent_llm_changed:
|
||||
await session.execute(
|
||||
update(NewChatThread)
|
||||
.where(NewChatThread.search_space_id == search_space_id)
|
||||
.values(
|
||||
pinned_llm_config_id=None,
|
||||
pinned_auto_mode=None,
|
||||
pinned_at=None,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||
search_space_id,
|
||||
previous_agent_llm_id,
|
||||
update_data["agent_llm_id"],
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ from app.db import (
|
|||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -1456,6 +1457,21 @@ async def stream_new_chat(
|
|||
agent_config: AgentConfig | None = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield streaming_service.format_error(str(pin_error))
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if llm_config_id >= 0:
|
||||
# Positive ID: Load from NewLLMConfig database table
|
||||
agent_config = await load_agent_config(
|
||||
|
|
@ -1491,12 +1507,11 @@ async def stream_new_chat(
|
|||
llm_config_id,
|
||||
)
|
||||
|
||||
# Premium quota reservation — applies to explicitly premium configs
|
||||
# AND Auto mode (which may route to premium models).
|
||||
# Premium quota reservation for pinned premium model only.
|
||||
_needs_premium_quota = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
and agent_config.is_premium
|
||||
)
|
||||
if _needs_premium_quota:
|
||||
import uuid as _uuid
|
||||
|
|
@ -1519,16 +1534,18 @@ async def stream_new_chat(
|
|||
)
|
||||
_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
# Auto mode: quota exhausted but we can still proceed
|
||||
# (the router may pick a free model). Reset reservation.
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
logging.getLogger(__name__).info(
|
||||
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
|
||||
chat_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -1961,28 +1978,20 @@ async def stream_new_chat(
|
|||
)
|
||||
|
||||
# Finalize premium quota with actual tokens.
|
||||
# For Auto mode, only count tokens from calls that used premium models.
|
||||
if _premium_request_id and user_id:
|
||||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_premium_reserved,
|
||||
)
|
||||
_premium_request_id = None
|
||||
_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s",
|
||||
|
|
@ -2175,6 +2184,21 @@ async def stream_resume_chat(
|
|||
|
||||
agent_config: AgentConfig | None = None
|
||||
_t0 = time.perf_counter()
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
yield streaming_service.format_error(str(pin_error))
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if llm_config_id >= 0:
|
||||
agent_config = await load_agent_config(
|
||||
session=session,
|
||||
|
|
@ -2208,7 +2232,7 @@ async def stream_resume_chat(
|
|||
_resume_needs_premium = (
|
||||
agent_config is not None
|
||||
and user_id
|
||||
and (agent_config.is_premium or agent_config.is_auto_mode)
|
||||
and agent_config.is_premium
|
||||
)
|
||||
if _resume_needs_premium:
|
||||
import uuid as _uuid
|
||||
|
|
@ -2231,14 +2255,18 @@ async def stream_resume_chat(
|
|||
)
|
||||
_resume_premium_reserved = reserve_amount
|
||||
if not quota_result.allowed:
|
||||
if agent_config.is_premium:
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
logging.getLogger(__name__).info(
|
||||
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
|
||||
chat_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
llm_config_id,
|
||||
)
|
||||
yield streaming_service.format_error(
|
||||
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
if not llm:
|
||||
yield streaming_service.format_error("Failed to create LLM instance")
|
||||
|
|
@ -2370,23 +2398,16 @@ async def stream_resume_chat(
|
|||
try:
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
if agent_config and agent_config.is_auto_mode:
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
|
||||
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
|
||||
accumulator.calls
|
||||
)
|
||||
else:
|
||||
actual_premium_tokens = accumulator.grand_total
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=UUID(user_id),
|
||||
request_id=_resume_premium_request_id,
|
||||
actual_tokens=actual_premium_tokens,
|
||||
actual_tokens=accumulator.grand_total,
|
||||
reserved_tokens=_resume_premium_reserved,
|
||||
)
|
||||
_resume_premium_request_id = None
|
||||
_resume_premium_reserved = 0
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to finalize premium quota for user %s (resume)",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue