fix(chat): enforce pinned model quota flow and reset stale pins

This commit is contained in:
Anish Sarkar 2026-04-29 19:15:36 +05:30
parent 41849fe10f
commit 835bd9f65d
2 changed files with 88 additions and 44 deletions

View file

@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy import func
from sqlalchemy import func, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
from app.config import config
from app.db import (
ImageGenerationConfig,
NewChatThread,
NewLLMConfig,
Permission,
SearchSpace,
@ -790,9 +791,31 @@ async def update_llm_preferences(
# Update preferences
update_data = preferences.model_dump(exclude_unset=True)
previous_agent_llm_id = search_space.agent_llm_id
for key, value in update_data.items():
setattr(search_space, key, value)
agent_llm_changed = (
"agent_llm_id" in update_data
and update_data["agent_llm_id"] != previous_agent_llm_id
)
if agent_llm_changed:
await session.execute(
update(NewChatThread)
.where(NewChatThread.search_space_id == search_space_id)
.values(
pinned_llm_config_id=None,
pinned_auto_mode=None,
pinned_at=None,
)
)
logger.info(
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
search_space_id,
previous_agent_llm_id,
update_data["agent_llm_id"],
)
await session.commit()
await session.refresh(search_space)

View file

@ -56,6 +56,7 @@ from app.db import (
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@ -1456,6 +1457,21 @@ async def stream_new_chat(
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield streaming_service.format_error(str(pin_error))
yield streaming_service.format_done()
return
if llm_config_id >= 0:
# Positive ID: Load from NewLLMConfig database table
agent_config = await load_agent_config(
@ -1491,12 +1507,11 @@ async def stream_new_chat(
llm_config_id,
)
# Premium quota reservation — applies to explicitly premium configs
# AND Auto mode (which may route to premium models).
# Premium quota reservation for pinned premium model only.
_needs_premium_quota = (
agent_config is not None
and user_id
and (agent_config.is_premium or agent_config.is_auto_mode)
and agent_config.is_premium
)
if _needs_premium_quota:
import uuid as _uuid
@ -1519,16 +1534,18 @@ async def stream_new_chat(
)
_premium_reserved = reserve_amount
if not quota_result.allowed:
if agent_config.is_premium:
yield streaming_service.format_error(
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
)
yield streaming_service.format_done()
return
# Auto mode: quota exhausted but we can still proceed
# (the router may pick a free model). Reset reservation.
_premium_request_id = None
_premium_reserved = 0
logging.getLogger(__name__).info(
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
chat_id,
search_space_id,
user_id,
llm_config_id,
)
yield streaming_service.format_error(
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
)
yield streaming_service.format_done()
return
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
@ -1961,28 +1978,20 @@ async def stream_new_chat(
)
# Finalize premium quota with actual tokens.
# For Auto mode, only count tokens from calls that used premium models.
if _premium_request_id and user_id:
try:
from app.services.token_quota_service import TokenQuotaService
if agent_config and agent_config.is_auto_mode:
from app.services.llm_router_service import LLMRouterService
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
accumulator.calls
)
else:
actual_premium_tokens = accumulator.grand_total
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_premium_request_id,
actual_tokens=actual_premium_tokens,
actual_tokens=accumulator.grand_total,
reserved_tokens=_premium_reserved,
)
_premium_request_id = None
_premium_reserved = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s",
@ -2175,6 +2184,21 @@ async def stream_resume_chat(
agent_config: AgentConfig | None = None
_t0 = time.perf_counter()
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
)
).resolved_llm_config_id
except ValueError as pin_error:
yield streaming_service.format_error(str(pin_error))
yield streaming_service.format_done()
return
if llm_config_id >= 0:
agent_config = await load_agent_config(
session=session,
@ -2208,7 +2232,7 @@ async def stream_resume_chat(
_resume_needs_premium = (
agent_config is not None
and user_id
and (agent_config.is_premium or agent_config.is_auto_mode)
and agent_config.is_premium
)
if _resume_needs_premium:
import uuid as _uuid
@ -2231,14 +2255,18 @@ async def stream_resume_chat(
)
_resume_premium_reserved = reserve_amount
if not quota_result.allowed:
if agent_config.is_premium:
yield streaming_service.format_error(
"Premium token quota exceeded. Please purchase more tokens to continue using premium models."
)
yield streaming_service.format_done()
return
_resume_premium_request_id = None
_resume_premium_reserved = 0
logging.getLogger(__name__).info(
"premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
chat_id,
search_space_id,
user_id,
llm_config_id,
)
yield streaming_service.format_error(
"Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
)
yield streaming_service.format_done()
return
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
@ -2370,23 +2398,16 @@ async def stream_resume_chat(
try:
from app.services.token_quota_service import TokenQuotaService
if agent_config and agent_config.is_auto_mode:
from app.services.llm_router_service import LLMRouterService
actual_premium_tokens = LLMRouterService.compute_premium_tokens(
accumulator.calls
)
else:
actual_premium_tokens = accumulator.grand_total
async with shielded_async_session() as quota_session:
await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=UUID(user_id),
request_id=_resume_premium_request_id,
actual_tokens=actual_premium_tokens,
actual_tokens=accumulator.grand_total,
reserved_tokens=_resume_premium_reserved,
)
_resume_premium_request_id = None
_resume_premium_reserved = 0
except Exception:
logging.getLogger(__name__).warning(
"Failed to finalize premium quota for user %s (resume)",