From d66fa1559b3913648e195c379e60b03ff1f00baf Mon Sep 17 00:00:00 2001
From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com>
Date: Wed, 29 Apr 2026 20:29:41 +0530
Subject: [PATCH] feat(chat): implement forced repin to free tier for pinned LLM configurations

---
 .../app/services/auto_model_pin_service.py   |  17 +-
 .../app/tasks/chat/stream_new_chat.py        | 209 ++++++++++++------
 .../services/test_auto_model_pin_service.py  |  38 ++++
 3 files changed, 200 insertions(+), 64 deletions(-)

diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
index ce417a26d..6bdb60f57 100644
--- a/surfsense_backend/app/services/auto_model_pin_service.py
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -84,6 +84,7 @@ async def resolve_or_get_pinned_llm_config_id(
     search_space_id: int,
     user_id: str | UUID | None,
     selected_llm_config_id: int,
+    force_repin_free: bool = False,
 ) -> AutoPinResolution:
     """Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
 
@@ -130,9 +131,12 @@ async def resolve_or_get_pinned_llm_config_id(
         raise ValueError("No usable global LLM configs are available for Auto mode")
 
     candidate_by_id = {int(c["id"]): c for c in candidates}
-    # Reuse existing valid pin without re-checking current quota (no silent tier switch).
+    # Reuse existing valid pin without re-checking current quota (no silent tier switch),
+    # unless the caller explicitly requests a forced repin to free.
     pinned_id = thread.pinned_llm_config_id
     if (
+        not force_repin_free
+
         and thread.pinned_auto_mode == AUTO_FASTEST_MODE
         and pinned_id is not None
         and int(pinned_id) in candidate_by_id
@@ -159,7 +163,7 @@ async def resolve_or_get_pinned_llm_config_id(
             thread.pinned_auto_mode,
         )
 
-    premium_eligible = await _is_premium_eligible(session, user_id)
+    premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id)
     if premium_eligible:
         eligible = candidates
     else:
@@ -179,6 +183,15 @@ async def resolve_or_get_pinned_llm_config_id(
     thread.pinned_at = datetime.now(UTC)
     await session.commit()
 
+    if force_repin_free:
+        logger.info(
+            "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
+            thread_id,
+            search_space_id,
+            pinned_id,
+            selected_id,
+        )
+
     if pinned_id is None:
         logger.info(
             "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 233b45396..edc5aa763 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -1455,6 +1455,37 @@ async def stream_new_chat(
         await set_ai_responding(session, chat_id, UUID(user_id))
         # Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
         agent_config: AgentConfig | None = None
+        requested_llm_config_id = llm_config_id
+
+        async def _load_llm_bundle(
+            config_id: int,
+        ) -> tuple[Any, AgentConfig | None, str | None]:
+            if config_id >= 0:
+                loaded_agent_config = await load_agent_config(
+                    session=session,
+                    config_id=config_id,
+                    search_space_id=search_space_id,
+                )
+                if not loaded_agent_config:
+                    return (
+                        None,
+                        None,
+                        f"Failed to load NewLLMConfig with id {config_id}",
+                    )
+                return (
+                    create_chat_litellm_from_agent_config(loaded_agent_config),
+                    loaded_agent_config,
+                    None,
+                )
+
+            loaded_llm_config = load_global_llm_config_by_id(config_id)
+            if not loaded_llm_config:
+                return None, None, f"Failed to load LLM config with id {config_id}"
+            return (
+                create_chat_litellm_from_config(loaded_llm_config),
+                AgentConfig.from_yaml_config(loaded_llm_config),
+                None,
+            )
 
         _t0 = time.perf_counter()
         try:
@@ -1472,35 +1503,11 @@ async def stream_new_chat(
                 yield streaming_service.format_done()
                 return
 
-            if llm_config_id >= 0:
-                # Positive ID: Load from NewLLMConfig database table
-                agent_config = await load_agent_config(
-                    session=session,
-                    config_id=llm_config_id,
-                    search_space_id=search_space_id,
-                )
-                if not agent_config:
-                    yield streaming_service.format_error(
-                        f"Failed to load NewLLMConfig with id {llm_config_id}"
-                    )
-                    yield streaming_service.format_done()
-                    return
-
-                # Create ChatLiteLLM from AgentConfig
-                llm = create_chat_litellm_from_agent_config(agent_config)
-            else:
-                # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models)
-                llm_config = load_global_llm_config_by_id(llm_config_id)
-                if not llm_config:
-                    yield streaming_service.format_error(
-                        f"Failed to load LLM config with id {llm_config_id}"
-                    )
-                    yield streaming_service.format_done()
-                    return
-
-                # Create ChatLiteLLM from global config dict
-                llm = create_chat_litellm_from_config(llm_config)
-                agent_config = AgentConfig.from_yaml_config(llm_config)
+            llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+            if llm_load_error:
+                yield streaming_service.format_error(llm_load_error)
+                yield streaming_service.format_done()
+                return
             _perf_log.info(
                 "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
                 time.perf_counter() - _t0,
@@ -1541,11 +1548,43 @@ async def stream_new_chat(
                     user_id,
                     llm_config_id,
                 )
-                yield streaming_service.format_error(
-                    "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
-                )
-                yield streaming_service.format_done()
-                return
+                if requested_llm_config_id == 0:
+                    try:
+                        llm_config_id = (
+                            await resolve_or_get_pinned_llm_config_id(
+                                session,
+                                thread_id=chat_id,
+                                search_space_id=search_space_id,
+                                user_id=user_id,
+                                selected_llm_config_id=0,
+                                force_repin_free=True,
+                            )
+                        ).resolved_llm_config_id
+                    except ValueError as pin_error:
+                        yield streaming_service.format_error(str(pin_error))
+                        yield streaming_service.format_done()
+                        return
+
+                    llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+                    if llm_load_error:
+                        yield streaming_service.format_error(llm_load_error)
+                        yield streaming_service.format_done()
+                        return
+                    _premium_request_id = None
+                    _premium_reserved = 0
+                    logging.getLogger(__name__).info(
+                        "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s",
+                        chat_id,
+                        search_space_id,
+                        user_id,
+                        llm_config_id,
+                    )
+                else:
+                    yield streaming_service.format_error(
+                        "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
+                    )
+                    yield streaming_service.format_done()
+                    return
 
             if not llm:
                 yield streaming_service.format_error("Failed to create LLM instance")
@@ -2183,6 +2222,38 @@ async def stream_resume_chat(
         await set_ai_responding(session, chat_id, UUID(user_id))
 
         agent_config: AgentConfig | None = None
+        requested_llm_config_id = llm_config_id
+
+        async def _load_llm_bundle(
+            config_id: int,
+        ) -> tuple[Any, AgentConfig | None, str | None]:
+            if config_id >= 0:
+                loaded_agent_config = await load_agent_config(
+                    session=session,
+                    config_id=config_id,
+                    search_space_id=search_space_id,
+                )
+                if not loaded_agent_config:
+                    return (
+                        None,
+                        None,
+                        f"Failed to load NewLLMConfig with id {config_id}",
+                    )
+                return (
+                    create_chat_litellm_from_agent_config(loaded_agent_config),
+                    loaded_agent_config,
+                    None,
+                )
+
+            loaded_llm_config = load_global_llm_config_by_id(config_id)
+            if not loaded_llm_config:
+                return None, None, f"Failed to load LLM config with id {config_id}"
+            return (
+                create_chat_litellm_from_config(loaded_llm_config),
+                AgentConfig.from_yaml_config(loaded_llm_config),
+                None,
+            )
+
         _t0 = time.perf_counter()
         try:
             llm_config_id = (
@@ -2199,29 +2270,11 @@ async def stream_resume_chat(
                 yield streaming_service.format_done()
                 return
 
-            if llm_config_id >= 0:
-                agent_config = await load_agent_config(
-                    session=session,
-                    config_id=llm_config_id,
-                    search_space_id=search_space_id,
-                )
-                if not agent_config:
-                    yield streaming_service.format_error(
-                        f"Failed to load NewLLMConfig with id {llm_config_id}"
-                    )
-                    yield streaming_service.format_done()
-                    return
-                llm = create_chat_litellm_from_agent_config(agent_config)
-            else:
-                llm_config = load_global_llm_config_by_id(llm_config_id)
-                if not llm_config:
-                    yield streaming_service.format_error(
-                        f"Failed to load LLM config with id {llm_config_id}"
-                    )
-                    yield streaming_service.format_done()
-                    return
-                llm = create_chat_litellm_from_config(llm_config)
-                agent_config = AgentConfig.from_yaml_config(llm_config)
+            llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+            if llm_load_error:
+                yield streaming_service.format_error(llm_load_error)
+                yield streaming_service.format_done()
+                return
             _perf_log.info(
                 "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
             )
@@ -2262,11 +2315,43 @@ async def stream_resume_chat(
                     user_id,
                     llm_config_id,
                 )
-                yield streaming_service.format_error(
-                    "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
-                )
-                yield streaming_service.format_done()
-                return
+                if requested_llm_config_id == 0:
+                    try:
+                        llm_config_id = (
+                            await resolve_or_get_pinned_llm_config_id(
+                                session,
+                                thread_id=chat_id,
+                                search_space_id=search_space_id,
+                                user_id=user_id,
+                                selected_llm_config_id=0,
+                                force_repin_free=True,
+                            )
+                        ).resolved_llm_config_id
+                    except ValueError as pin_error:
+                        yield streaming_service.format_error(str(pin_error))
+                        yield streaming_service.format_done()
+                        return
+
+                    llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+                    if llm_load_error:
+                        yield streaming_service.format_error(llm_load_error)
+                        yield streaming_service.format_done()
+                        return
+                    _resume_premium_request_id = None
+                    _resume_premium_reserved = 0
+                    logging.getLogger(__name__).info(
+                        "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s",
+                        chat_id,
+                        search_space_id,
+                        user_id,
+                        llm_config_id,
+                    )
+                else:
+                    yield streaming_service.format_error(
+                        "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model."
+                    )
+                    yield streaming_service.format_done()
+                    return
 
             if not llm:
                 yield streaming_service.format_error("Failed to create LLM instance")
diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
index a9853c980..f08e50ba2 100644
--- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
+++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
@@ -227,6 +227,44 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
     assert result.from_existing_pin is True
 
 
+@pytest.mark.asyncio
+async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(
+        _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
+    )
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
+            {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
+        ],
+    )
+
+    async def _blocked(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=False)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _blocked,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+        force_repin_free=True,
+    )
+    assert result.resolved_llm_config_id == -2
+    assert result.resolved_tier == "free"
+    assert result.from_existing_pin is False
+    assert session.thread.pinned_llm_config_id == -2
+
+
 @pytest.mark.asyncio
 async def test_explicit_user_model_change_clears_pin(monkeypatch):
     from app.config import config