fix(chat): enforce pinned model quota flow and reset stale pins

This commit is contained in:
Anish Sarkar 2026-04-29 19:15:36 +05:30
parent 41849fe10f
commit 835bd9f65d
2 changed files with 88 additions and 44 deletions

View file

@@ -3,7 +3,7 @@ import logging
 from fastapi import APIRouter, Depends, HTTPException
 from langchain_core.messages import HumanMessage
 from pydantic import BaseModel as PydanticBaseModel
-from sqlalchemy import func
+from sqlalchemy import func, update
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
@@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
 from app.config import config
 from app.db import (
     ImageGenerationConfig,
+    NewChatThread,
     NewLLMConfig,
     Permission,
     SearchSpace,
@@ -790,9 +791,31 @@ async def update_llm_preferences(
     # Update preferences
     update_data = preferences.model_dump(exclude_unset=True)
+    previous_agent_llm_id = search_space.agent_llm_id
     for key, value in update_data.items():
         setattr(search_space, key, value)
+
+    agent_llm_changed = (
+        "agent_llm_id" in update_data
+        and update_data["agent_llm_id"] != previous_agent_llm_id
+    )
+    if agent_llm_changed:
+        await session.execute(
+            update(NewChatThread)
+            .where(NewChatThread.search_space_id == search_space_id)
+            .values(
+                pinned_llm_config_id=None,
+                pinned_auto_mode=None,
+                pinned_at=None,
+            )
+        )
+        logger.info(
+            "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
+            search_space_id,
+            previous_agent_llm_id,
+            update_data["agent_llm_id"],
+        )
+
     await session.commit()
     await session.refresh(search_space)

View file

@@ -56,6 +56,7 @@ from app.db import (
     shielded_async_session,
 )
 from app.prompts import TITLE_GENERATION_PROMPT
+from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
 from app.services.chat_session_state_service import (
     clear_ai_responding,
     set_ai_responding,
@@ -1456,6 +1457,21 @@ async def stream_new_chat(
         agent_config: AgentConfig | None = None
         _t0 = time.perf_counter()

+        try:
+            llm_config_id = (
+                await resolve_or_get_pinned_llm_config_id(
+                    session,
+                    thread_id=chat_id,
+                    search_space_id=search_space_id,
+                    user_id=user_id,
+                    selected_llm_config_id=llm_config_id,
+                )
+            ).resolved_llm_config_id
+        except ValueError as pin_error:
+            yield streaming_service.format_error(str(pin_error))
+            yield streaming_service.format_done()
+            return
+
         if llm_config_id >= 0:
             # Positive ID: Load from NewLLMConfig database table
             agent_config = await load_agent_config(
@@ -1491,12 +1507,11 @@ async def stream_new_chat(
                 llm_config_id,
             )

-        # Premium quota reservation — applies to explicitly premium configs
-        # AND Auto mode (which may route to premium models).
+        # Premium quota reservation for pinned premium model only.
         _needs_premium_quota = (
             agent_config is not None
             and user_id
-            and (agent_config.is_premium or agent_config.is_auto_mode)
+            and agent_config.is_premium
         )
         if _needs_premium_quota:
             import uuid as _uuid
@@ -1519,16 +1534,18 @@ async def stream_new_chat(
             )
             _premium_reserved = reserve_amount
             if not quota_result.allowed:
-                if agent_config.is_premium:
-                    yield streaming_service.format_error(
-                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
-                    )
-                    yield streaming_service.format_done()
-                    return
-                # Auto mode: quota exhausted but we can still proceed
-                # (the router may pick a free model). Reset reservation.
-                _premium_request_id = None
-                _premium_reserved = 0
+                logging.getLogger(__name__).info(
+                    "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
+                    chat_id,
+                    search_space_id,
+                    user_id,
+                    llm_config_id,
+                )
+                yield streaming_service.format_error(
+                    "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
+                )
+                yield streaming_service.format_done()
+                return

         if not llm:
             yield streaming_service.format_error("Failed to create LLM instance")
@@ -1961,28 +1978,20 @@ async def stream_new_chat(
             )

             # Finalize premium quota with actual tokens.
-            # For Auto mode, only count tokens from calls that used premium models.
             if _premium_request_id and user_id:
                 try:
                     from app.services.token_quota_service import TokenQuotaService

-                    if agent_config and agent_config.is_auto_mode:
-                        from app.services.llm_router_service import LLMRouterService
-
-                        actual_premium_tokens = LLMRouterService.compute_premium_tokens(
-                            accumulator.calls
-                        )
-                    else:
-                        actual_premium_tokens = accumulator.grand_total
-
                     async with shielded_async_session() as quota_session:
                         await TokenQuotaService.premium_finalize(
                             db_session=quota_session,
                             user_id=UUID(user_id),
                             request_id=_premium_request_id,
-                            actual_tokens=actual_premium_tokens,
+                            actual_tokens=accumulator.grand_total,
                             reserved_tokens=_premium_reserved,
                         )
+                    _premium_request_id = None
+                    _premium_reserved = 0
                 except Exception:
                     logging.getLogger(__name__).warning(
                         "Failed to finalize premium quota for user %s",
@@ -2175,6 +2184,21 @@ async def stream_resume_chat(
         agent_config: AgentConfig | None = None
         _t0 = time.perf_counter()

+        try:
+            llm_config_id = (
+                await resolve_or_get_pinned_llm_config_id(
+                    session,
+                    thread_id=chat_id,
+                    search_space_id=search_space_id,
+                    user_id=user_id,
+                    selected_llm_config_id=llm_config_id,
+                )
+            ).resolved_llm_config_id
+        except ValueError as pin_error:
+            yield streaming_service.format_error(str(pin_error))
+            yield streaming_service.format_done()
+            return
+
         if llm_config_id >= 0:
             agent_config = await load_agent_config(
                 session=session,
@@ -2208,7 +2232,7 @@ async def stream_resume_chat(
         _resume_needs_premium = (
             agent_config is not None
             and user_id
-            and (agent_config.is_premium or agent_config.is_auto_mode)
+            and agent_config.is_premium
         )
         if _resume_needs_premium:
             import uuid as _uuid
@@ -2231,14 +2255,18 @@ async def stream_resume_chat(
             )
             _resume_premium_reserved = reserve_amount
             if not quota_result.allowed:
-                if agent_config.is_premium:
-                    yield streaming_service.format_error(
-                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
-                    )
-                    yield streaming_service.format_done()
-                    return
-                _resume_premium_request_id = None
-                _resume_premium_reserved = 0
+                logging.getLogger(__name__).info(
+                    "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
+                    chat_id,
+                    search_space_id,
+                    user_id,
+                    llm_config_id,
+                )
+                yield streaming_service.format_error(
+                    "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
+                )
+                yield streaming_service.format_done()
+                return

         if not llm:
             yield streaming_service.format_error("Failed to create LLM instance")
@@ -2370,23 +2398,16 @@ async def stream_resume_chat(
                 try:
                     from app.services.token_quota_service import TokenQuotaService

-                    if agent_config and agent_config.is_auto_mode:
-                        from app.services.llm_router_service import LLMRouterService
-
-                        actual_premium_tokens = LLMRouterService.compute_premium_tokens(
-                            accumulator.calls
-                        )
-                    else:
-                        actual_premium_tokens = accumulator.grand_total
-
                     async with shielded_async_session() as quota_session:
                         await TokenQuotaService.premium_finalize(
                             db_session=quota_session,
                             user_id=UUID(user_id),
                             request_id=_resume_premium_request_id,
-                            actual_tokens=actual_premium_tokens,
+                            actual_tokens=accumulator.grand_total,
                             reserved_tokens=_resume_premium_reserved,
                         )
+                    _resume_premium_request_id = None
+                    _resume_premium_reserved = 0
                 except Exception:
                     logging.getLogger(__name__).warning(
                         "Failed to finalize premium quota for user %s (resume)",