From 41849fe10f5fbe9a4792ad665308f5fea4c37721 Mon Sep 17 00:00:00 2001
From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com>
Date: Wed, 29 Apr 2026 19:15:15 +0530
Subject: [PATCH] feat(chat): add auto model pin resolution service

---
 .../app/services/auto_model_pin_service.py    | 205 ++++++++++++
 .../services/test_auto_model_pin_service.py   | 291 ++++++++++++++++++
 2 files changed, 496 insertions(+)
 create mode 100644 surfsense_backend/app/services/auto_model_pin_service.py
 create mode 100644 surfsense_backend/tests/unit/services/test_auto_model_pin_service.py

diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
new file mode 100644
index 000000000..ce417a26d
--- /dev/null
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -0,0 +1,205 @@
+"""Resolve and persist Auto (Fastest) model pins per chat thread.
+
+Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
+resolve that virtual mode to one concrete global LLM config exactly once and
+persist the chosen config id on ``new_chat_threads`` so subsequent turns are
+stable.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from uuid import UUID
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import config
+from app.db import NewChatThread
+from app.services.token_quota_service import TokenQuotaService
+
+logger = logging.getLogger(__name__)
+
+AUTO_FASTEST_ID = 0
+AUTO_FASTEST_MODE = "auto_fastest"
+
+
+@dataclass
+class AutoPinResolution:
+    resolved_llm_config_id: int
+    resolved_tier: str
+    from_existing_pin: bool
+
+
+def _is_usable_global_config(cfg: dict) -> bool:
+    return bool(
+        cfg.get("id") is not None
+        and cfg.get("model_name")
+        and cfg.get("provider")
+        and cfg.get("api_key")
+    )
+
+
+def _global_candidates() -> list[dict]:
+    candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)]
+    return sorted(candidates, key=lambda c: int(c.get("id", 0)))
+
+
+def _tier_of(cfg: dict) -> str:
+    return str(cfg.get("billing_tier", "free")).lower()
+
+
+def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
+    digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
+    idx = int.from_bytes(digest[:8], "big") % len(candidates)
+    return candidates[idx]
+
+
+def _to_uuid(user_id: str | UUID | None) -> UUID | None:
+    if user_id is None:
+        return None
+    if isinstance(user_id, UUID):
+        return user_id
+    try:
+        return UUID(str(user_id))
+    except Exception:
+        return None
+
+
+async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool:
+    parsed = _to_uuid(user_id)
+    if parsed is None:
+        return False
+    usage = await TokenQuotaService.premium_get_usage(session, parsed)
+    return bool(usage.allowed)
+
+
+async def resolve_or_get_pinned_llm_config_id(
+    session: AsyncSession,
+    *,
+    thread_id: int,
+    search_space_id: int,
+    user_id: str | UUID | None,
+    selected_llm_config_id: int,
+) -> AutoPinResolution:
+    """Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
+
+    For non-auto selections, this function clears existing auto pin metadata and
+    returns the selected id as-is.
+    """
+    thread = (
+        (
+            await session.execute(
+                select(NewChatThread)
+                .where(NewChatThread.id == thread_id)
+                .with_for_update(of=NewChatThread)
+            )
+        )
+        .unique()
+        .scalar_one_or_none()
+    )
+    if thread is None:
+        raise ValueError(f"Thread {thread_id} not found")
+    if thread.search_space_id != search_space_id:
+        raise ValueError(
+            f"Thread {thread_id} does not belong to search space {search_space_id}"
+        )
+
+    # Explicit model selected: clear stale auto pin metadata.
+    if selected_llm_config_id != AUTO_FASTEST_ID:
+        if (
+            thread.pinned_llm_config_id is not None
+            or thread.pinned_auto_mode is not None
+            or thread.pinned_at is not None
+        ):
+            thread.pinned_llm_config_id = None
+            thread.pinned_auto_mode = None
+            thread.pinned_at = None
+            await session.commit()
+        return AutoPinResolution(
+            resolved_llm_config_id=selected_llm_config_id,
+            resolved_tier="explicit",
+            from_existing_pin=False,
+        )
+
+    candidates = _global_candidates()
+    if not candidates:
+        raise ValueError("No usable global LLM configs are available for Auto mode")
+    candidate_by_id = {int(c["id"]): c for c in candidates}
+
+    # Reuse existing valid pin without re-checking current quota (no silent tier switch).
+    pinned_id = thread.pinned_llm_config_id
+    if (
+        thread.pinned_auto_mode == AUTO_FASTEST_MODE
+        and pinned_id is not None
+        and int(pinned_id) in candidate_by_id
+    ):
+        pinned_cfg = candidate_by_id[int(pinned_id)]
+        logger.info(
+            "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
+            thread_id,
+            search_space_id,
+            pinned_id,
+            _tier_of(pinned_cfg),
+        )
+        return AutoPinResolution(
+            resolved_llm_config_id=int(pinned_id),
+            resolved_tier=_tier_of(pinned_cfg),
+            from_existing_pin=True,
+        )
+    if pinned_id is not None:
+        logger.info(
+            "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
+            thread_id,
+            search_space_id,
+            pinned_id,
+            thread.pinned_auto_mode,
+        )
+
+    premium_eligible = await _is_premium_eligible(session, user_id)
+    if premium_eligible:
+        eligible = candidates
+    else:
+        eligible = [c for c in candidates if _tier_of(c) != "premium"]
+
+    if not eligible:
+        raise ValueError(
+            "Auto mode could not find an eligible LLM config for this user and quota state"
+        )
+
+    selected_cfg = _deterministic_pick(eligible, thread_id)
+    selected_id = int(selected_cfg["id"])
+    selected_tier = _tier_of(selected_cfg)
+
+    thread.pinned_llm_config_id = selected_id
+    thread.pinned_auto_mode = AUTO_FASTEST_MODE
+    thread.pinned_at = datetime.now(UTC)
+    await session.commit()
+
+    if pinned_id is None:
+        logger.info(
+            "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
+            thread_id,
+            search_space_id,
+            selected_id,
+            selected_tier,
+            premium_eligible,
+        )
+    else:
+        logger.info(
+            "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
+            thread_id,
+            search_space_id,
+            pinned_id,
+            selected_id,
+            selected_tier,
+            premium_eligible,
+        )
+    return AutoPinResolution(
+        resolved_llm_config_id=selected_id,
+        resolved_tier=selected_tier,
+        from_existing_pin=False,
+    )
diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
new file mode 100644
index 000000000..a9853c980
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
@@ -0,0 +1,291 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+import pytest
+
+from app.services.auto_model_pin_service import (
+    AUTO_FASTEST_MODE,
+    resolve_or_get_pinned_llm_config_id,
+)
+
+pytestmark = pytest.mark.unit
+
+
+@dataclass
+class _FakeQuotaResult:
+    allowed: bool
+
+
+class _FakeExecResult:
+    def __init__(self, thread):
+        self._thread = thread
+
+    def unique(self):
+        return self
+
+    def scalar_one_or_none(self):
+        return self._thread
+
+
+class _FakeSession:
+    def __init__(self, thread):
+        self.thread = thread
+        self.commit_count = 0
+
+    async def execute(self, _stmt):
+        return _FakeExecResult(self.thread)
+
+    async def commit(self):
+        self.commit_count += 1
+
+
+def _thread(
+    *,
+    search_space_id: int = 10,
+    pinned_llm_config_id: int | None = None,
+    pinned_auto_mode: str | None = None,
+):
+    return SimpleNamespace(
+        id=1,
+        search_space_id=search_space_id,
+        pinned_llm_config_id=pinned_llm_config_id,
+        pinned_auto_mode=pinned_auto_mode,
+        pinned_at=None,
+    )
+
+
+@pytest.mark.asyncio
+async def test_auto_first_turn_pins_one_model(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(_thread())
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
+            {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
+        ],
+    )
+
+    async def _allowed(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=True)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _allowed,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id in {-1, -2}
+    assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
+    assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE
+    assert session.thread.pinned_at is not None
+    assert session.commit_count == 1
+
+
+@pytest.mark.asyncio
+async def test_next_turn_reuses_existing_pin(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(
+        _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
+    )
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
+        ],
+    )
+
+    async def _must_not_call(*_args, **_kwargs):
+        raise AssertionError("premium_get_usage should not be called for valid pin reuse")
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _must_not_call,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -1
+    assert result.from_existing_pin is True
+    assert session.commit_count == 0
+
+
+@pytest.mark.asyncio
+async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(_thread())
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
+        ],
+    )
+
+    async def _allowed(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=True)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _allowed,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -1
+    assert result.resolved_tier == "premium"
+
+
+@pytest.mark.asyncio
+async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(_thread())
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
+            {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
+        ],
+    )
+
+    async def _blocked(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=False)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _blocked,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -2
+    assert result.resolved_tier == "free"
+
+
+@pytest.mark.asyncio
+async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(
+        _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
+    )
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
+            {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
+        ],
+    )
+
+    async def _blocked(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=False)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _blocked,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -1
+    assert result.from_existing_pin is True
+
+
+@pytest.mark.asyncio
+async def test_explicit_user_model_change_clears_pin(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(
+        _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE)
+    )
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
+        ],
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=7,
+    )
+    assert result.resolved_llm_config_id == 7
+    assert session.thread.pinned_llm_config_id is None
+    assert session.thread.pinned_auto_mode is None
+    assert session.thread.pinned_at is None
+    assert session.commit_count == 1
+
+
+@pytest.mark.asyncio
+async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(
+        _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE)
+    )
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
+        ],
+    )
+
+    async def _allowed(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=True)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _allowed,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -2
+    assert session.thread.pinned_llm_config_id == -2
+    assert session.commit_count == 1
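
---

Reviewer note (illustration only, not part of the patch): a minimal sketch of
how a chat turn handler might call the resolver before constructing its LLM
client. The `pick_llm_config_for_turn` wrapper and its argument shapes are
assumptions made for this example, not APIs introduced by the patch.

    from app.services.auto_model_pin_service import (
        resolve_or_get_pinned_llm_config_id,
    )

    async def pick_llm_config_for_turn(session, thread, user_id, requested_llm_id):
        # First auto turn (requested_llm_id == 0) pins one concrete config on
        # the thread row; later auto turns reuse that pin, and explicit picks
        # pass through while clearing any stale auto pin metadata.
        resolution = await resolve_or_get_pinned_llm_config_id(
            session,
            thread_id=thread.id,
            search_space_id=thread.search_space_id,
            user_id=user_id,
            selected_llm_config_id=requested_llm_id,
        )
        return resolution.resolved_llm_config_id

Because `_deterministic_pick` hashes only the mode string and the thread id,
the same thread always maps to the same candidate index, so a repaired or
recreated pin lands on a stable choice as long as the eligible candidate list
is unchanged.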