diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py
index 828137518..7944e7d66 100644
--- a/surfsense_backend/app/routes/search_spaces_routes.py
+++ b/surfsense_backend/app/routes/search_spaces_routes.py
@@ -3,7 +3,7 @@ import logging
 from fastapi import APIRouter, Depends, HTTPException
 from langchain_core.messages import HumanMessage
 from pydantic import BaseModel as PydanticBaseModel
-from sqlalchemy import func
+from sqlalchemy import func, update
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 
@@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
 from app.config import config
 from app.db import (
     ImageGenerationConfig,
+    NewChatThread,
     NewLLMConfig,
     Permission,
     SearchSpace,
@@ -790,9 +791,31 @@ async def update_llm_preferences(
     # Update preferences
     update_data = preferences.model_dump(exclude_unset=True)
 
+    previous_agent_llm_id = search_space.agent_llm_id
     for key, value in update_data.items():
         setattr(search_space, key, value)
 
+    agent_llm_changed = (
+        "agent_llm_id" in update_data
+        and update_data["agent_llm_id"] != previous_agent_llm_id
+    )
+    if agent_llm_changed:
+        await session.execute(
+            update(NewChatThread)
+            .where(NewChatThread.search_space_id == search_space_id)
+            .values(
+                pinned_llm_config_id=None,
+                pinned_auto_mode=None,
+                pinned_at=None,
+            )
+        )
+        logger.info(
+            "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
+            search_space_id,
+            previous_agent_llm_id,
+            update_data["agent_llm_id"],
+        )
+
     await session.commit()
     await session.refresh(search_space)
 
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index c254e66e2..1a56547ca 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -56,6 +56,7 @@ from app.db import (
     shielded_async_session,
 )
 from app.prompts import TITLE_GENERATION_PROMPT
+from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
 from app.services.chat_session_state_service import (
     clear_ai_responding,
     set_ai_responding,
@@ -1456,6 +1457,21 @@ async def stream_new_chat(
     agent_config: AgentConfig | None = None
     _t0 = time.perf_counter()
 
+    try:
+        llm_config_id = (
+            await resolve_or_get_pinned_llm_config_id(
+                session,
+                thread_id=chat_id,
+                search_space_id=search_space_id,
+                user_id=user_id,
+                selected_llm_config_id=llm_config_id,
+            )
+        ).resolved_llm_config_id
+    except ValueError as pin_error:
+        yield streaming_service.format_error(str(pin_error))
+        yield streaming_service.format_done()
+        return
+
     if llm_config_id >= 0:
         # Positive ID: Load from NewLLMConfig database table
         agent_config = await load_agent_config(
@@ -1491,12 +1507,11 @@
             llm_config_id,
         )
 
-    # Premium quota reservation — applies to explicitly premium configs
-    # AND Auto mode (which may route to premium models).
+    # Premium quota reservation for pinned premium model only.
     _needs_premium_quota = (
         agent_config is not None
         and user_id
-        and (agent_config.is_premium or agent_config.is_auto_mode)
+        and agent_config.is_premium
     )
     if _needs_premium_quota:
         import uuid as _uuid
@@ -1519,16 +1534,18 @@
             )
             _premium_reserved = reserve_amount
             if not quota_result.allowed:
-                if agent_config.is_premium:
-                    yield streaming_service.format_error(
-                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
-                    )
-                    yield streaming_service.format_done()
-                    return
-                # Auto mode: quota exhausted but we can still proceed
-                # (the router may pick a free model). Reset reservation.
-                _premium_request_id = None
-                _premium_reserved = 0
+                logging.getLogger(__name__).info(
+                    "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
+                    chat_id,
+                    search_space_id,
+                    user_id,
+                    llm_config_id,
+                )
+                yield streaming_service.format_error(
+                    "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
+                )
+                yield streaming_service.format_done()
+                return
 
     if not llm:
         yield streaming_service.format_error("Failed to create LLM instance")
@@ -1961,28 +1978,20 @@
         )
 
         # Finalize premium quota with actual tokens.
-        # For Auto mode, only count tokens from calls that used premium models.
         if _premium_request_id and user_id:
             try:
                 from app.services.token_quota_service import TokenQuotaService
 
-                if agent_config and agent_config.is_auto_mode:
-                    from app.services.llm_router_service import LLMRouterService
-
-                    actual_premium_tokens = LLMRouterService.compute_premium_tokens(
-                        accumulator.calls
-                    )
-                else:
-                    actual_premium_tokens = accumulator.grand_total
-
                 async with shielded_async_session() as quota_session:
                     await TokenQuotaService.premium_finalize(
                         db_session=quota_session,
                         user_id=UUID(user_id),
                         request_id=_premium_request_id,
-                        actual_tokens=actual_premium_tokens,
+                        actual_tokens=accumulator.grand_total,
                         reserved_tokens=_premium_reserved,
                     )
+                _premium_request_id = None
+                _premium_reserved = 0
             except Exception:
                 logging.getLogger(__name__).warning(
                     "Failed to finalize premium quota for user %s",
@@ -2175,6 +2184,21 @@ async def stream_resume_chat(
     agent_config: AgentConfig | None = None
     _t0 = time.perf_counter()
 
+    try:
+        llm_config_id = (
+            await resolve_or_get_pinned_llm_config_id(
+                session,
+                thread_id=chat_id,
+                search_space_id=search_space_id,
+                user_id=user_id,
+                selected_llm_config_id=llm_config_id,
+            )
+        ).resolved_llm_config_id
+    except ValueError as pin_error:
+        yield streaming_service.format_error(str(pin_error))
+        yield streaming_service.format_done()
+        return
+
     if llm_config_id >= 0:
         agent_config = await load_agent_config(
             session=session,
@@ -2208,7 +2232,7 @@
     _resume_needs_premium = (
         agent_config is not None
         and user_id
-        and (agent_config.is_premium or agent_config.is_auto_mode)
+        and agent_config.is_premium
     )
     if _resume_needs_premium:
         import uuid as _uuid
@@ -2231,14 +2255,18 @@
             )
             _resume_premium_reserved = reserve_amount
             if not quota_result.allowed:
-                if agent_config.is_premium:
-                    yield streaming_service.format_error(
-                        "Premium token quota exceeded. Please purchase more tokens to continue using premium models."
-                    )
-                    yield streaming_service.format_done()
-                    return
-                _resume_premium_request_id = None
-                _resume_premium_reserved = 0
+                logging.getLogger(__name__).info(
+                    "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s",
+                    chat_id,
+                    search_space_id,
+                    user_id,
+                    llm_config_id,
+                )
+                yield streaming_service.format_error(
+                    "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin."
+                )
+                yield streaming_service.format_done()
+                return
 
     if not llm:
         yield streaming_service.format_error("Failed to create LLM instance")
@@ -2370,23 +2398,16 @@
             try:
                 from app.services.token_quota_service import TokenQuotaService
 
-                if agent_config and agent_config.is_auto_mode:
-                    from app.services.llm_router_service import LLMRouterService
-
-                    actual_premium_tokens = LLMRouterService.compute_premium_tokens(
-                        accumulator.calls
-                    )
-                else:
-                    actual_premium_tokens = accumulator.grand_total
-
                 async with shielded_async_session() as quota_session:
                     await TokenQuotaService.premium_finalize(
                         db_session=quota_session,
                         user_id=UUID(user_id),
                         request_id=_resume_premium_request_id,
-                        actual_tokens=actual_premium_tokens,
+                        actual_tokens=accumulator.grand_total,
                         reserved_tokens=_resume_premium_reserved,
                     )
+                _resume_premium_request_id = None
+                _resume_premium_reserved = 0
             except Exception:
                 logging.getLogger(__name__).warning(
                     "Failed to finalize premium quota for user %s (resume)",
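Note: app/services/auto_model_pin_service.py is imported above but its implementation is not part of this diff. The sketch below is a hypothetical reconstruction of the contract the call sites imply, nothing more: PinResolution, _route_auto_mode, and the parameter types are all assumptions. What the diff does establish is that the function is called as resolve_or_get_pinned_llm_config_id(session, thread_id=..., search_space_id=..., user_id=..., selected_llm_config_id=...), exposes .resolved_llm_config_id on its result, raises ValueError on failure, and that the pin state lives in NewChatThread.pinned_llm_config_id / pinned_auto_mode / pinned_at, which update_llm_preferences clears when agent_llm_id changes.

# Hypothetical sketch of app/services/auto_model_pin_service.py, inferred from
# the call sites in this diff. Names, types, and the routing step are assumed.
from dataclasses import dataclass
from datetime import datetime, timezone

from sqlalchemy.ext.asyncio import AsyncSession

from app.db import NewChatThread


@dataclass
class PinResolution:
    resolved_llm_config_id: int  # the only field the stream code reads


async def resolve_or_get_pinned_llm_config_id(
    session: AsyncSession,
    *,
    thread_id,
    search_space_id,
    user_id,
    selected_llm_config_id: int,
) -> PinResolution:
    """Resolve Auto mode to a concrete NewLLMConfig id, pinned per thread.

    Raises ValueError when resolution fails; both callers surface the
    message with streaming_service.format_error and end the stream.
    """
    # Positive ids already name a concrete NewLLMConfig row (the callers
    # check `llm_config_id >= 0` right after this call); pass them through.
    if selected_llm_config_id >= 0:
        return PinResolution(selected_llm_config_id)

    thread = await session.get(NewChatThread, thread_id)
    if thread is None or thread.search_space_id != search_space_id:
        raise ValueError("Chat thread not found in this search space")

    # An existing pin keeps the whole thread on one model until
    # update_llm_preferences clears it on an agent_llm_id change.
    if thread.pinned_llm_config_id is not None:
        return PinResolution(thread.pinned_llm_config_id)

    # First Auto turn on this thread: route once, then record the pin.
    resolved_id = await _route_auto_mode(session, selected_llm_config_id)
    thread.pinned_llm_config_id = resolved_id
    thread.pinned_auto_mode = selected_llm_config_id
    thread.pinned_at = datetime.now(timezone.utc)
    await session.commit()
    return PinResolution(resolved_id)


async def _route_auto_mode(session: AsyncSession, auto_mode_id: int) -> int:
    # Placeholder: the actual model-selection logic is not shown in the diff.
    raise NotImplementedError

Under that reading, the quota simplifications above are consistent: Auto is resolved to a concrete pinned config before the premium reservation runs, so the old is_auto_mode branches (reserve-anyway on denial, router-aware finalize via LLMRouterService.compute_premium_tokens) no longer have a case to handle, and accumulator.grand_total is the correct finalize amount.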