mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 22:32:39 +02:00
feat(chat): add auto model pin resolution service
This commit is contained in:
parent
57db198919
commit
41849fe10f
2 changed files with 496 additions and 0 deletions
205
surfsense_backend/app/services/auto_model_pin_service.py
Normal file
205
surfsense_backend/app/services/auto_model_pin_service.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
"""Resolve and persist Auto (Fastest) model pins per chat thread.
|
||||
|
||||
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||
persist the chosen config id on ``new_chat_threads`` so subsequent turns are
|
||||
stable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewChatThread
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUTO_FASTEST_ID = 0
|
||||
AUTO_FASTEST_MODE = "auto_fastest"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoPinResolution:
|
||||
resolved_llm_config_id: int
|
||||
resolved_tier: str
|
||||
from_existing_pin: bool
|
||||
|
||||
|
||||
def _is_usable_global_config(cfg: dict) -> bool:
|
||||
return bool(
|
||||
cfg.get("id") is not None
|
||||
and cfg.get("model_name")
|
||||
and cfg.get("provider")
|
||||
and cfg.get("api_key")
|
||||
)
|
||||
|
||||
|
||||
def _global_candidates() -> list[dict]:
|
||||
candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)]
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
||||
def _tier_of(cfg: dict) -> str:
|
||||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
||||
def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict:
|
||||
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
|
||||
idx = int.from_bytes(digest[:8], "big") % len(candidates)
|
||||
return candidates[idx]
|
||||
|
||||
|
||||
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||
if user_id is None:
|
||||
return None
|
||||
if isinstance(user_id, UUID):
|
||||
return user_id
|
||||
try:
|
||||
return UUID(str(user_id))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool:
|
||||
parsed = _to_uuid(user_id)
|
||||
if parsed is None:
|
||||
return False
|
||||
usage = await TokenQuotaService.premium_get_usage(session, parsed)
|
||||
return bool(usage.allowed)
|
||||
|
||||
|
||||
async def resolve_or_get_pinned_llm_config_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
thread_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str | UUID | None,
|
||||
selected_llm_config_id: int,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist pin metadata.
|
||||
|
||||
For non-auto selections, this function clears existing auto pin metadata and
|
||||
returns the selected id as-is.
|
||||
"""
|
||||
thread = (
|
||||
(
|
||||
await session.execute(
|
||||
select(NewChatThread)
|
||||
.where(NewChatThread.id == thread_id)
|
||||
.with_for_update(of=NewChatThread)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if thread is None:
|
||||
raise ValueError(f"Thread {thread_id} not found")
|
||||
if thread.search_space_id != search_space_id:
|
||||
raise ValueError(
|
||||
f"Thread {thread_id} does not belong to search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Explicit model selected: clear stale auto pin metadata.
|
||||
if selected_llm_config_id != AUTO_FASTEST_ID:
|
||||
if (
|
||||
thread.pinned_llm_config_id is not None
|
||||
or thread.pinned_auto_mode is not None
|
||||
or thread.pinned_at is not None
|
||||
):
|
||||
thread.pinned_llm_config_id = None
|
||||
thread.pinned_auto_mode = None
|
||||
thread.pinned_at = None
|
||||
await session.commit()
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=selected_llm_config_id,
|
||||
resolved_tier="explicit",
|
||||
from_existing_pin=False,
|
||||
)
|
||||
|
||||
candidates = _global_candidates()
|
||||
if not candidates:
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse existing valid pin without re-checking current quota (no silent tier switch).
|
||||
pinned_id = thread.pinned_llm_config_id
|
||||
if (
|
||||
thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
and pinned_id is not None
|
||||
and int(pinned_id) in candidate_by_id
|
||||
):
|
||||
pinned_cfg = candidate_by_id[int(pinned_id)]
|
||||
logger.info(
|
||||
"auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
_tier_of(pinned_cfg),
|
||||
)
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=int(pinned_id),
|
||||
resolved_tier=_tier_of(pinned_cfg),
|
||||
from_existing_pin=True,
|
||||
)
|
||||
if pinned_id is not None:
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
thread.pinned_auto_mode,
|
||||
)
|
||||
|
||||
premium_eligible = await _is_premium_eligible(session, user_id)
|
||||
if premium_eligible:
|
||||
eligible = candidates
|
||||
else:
|
||||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||
|
||||
if not eligible:
|
||||
raise ValueError(
|
||||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||
)
|
||||
|
||||
selected_cfg = _deterministic_pick(eligible, thread_id)
|
||||
selected_id = int(selected_cfg["id"])
|
||||
selected_tier = _tier_of(selected_cfg)
|
||||
|
||||
thread.pinned_llm_config_id = selected_id
|
||||
thread.pinned_auto_mode = AUTO_FASTEST_MODE
|
||||
thread.pinned_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
if pinned_id is None:
|
||||
logger.info(
|
||||
"auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
selected_id,
|
||||
selected_tier,
|
||||
premium_eligible,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
selected_id,
|
||||
selected_tier,
|
||||
premium_eligible,
|
||||
)
|
||||
return AutoPinResolution(
|
||||
resolved_llm_config_id=selected_id,
|
||||
resolved_tier=selected_tier,
|
||||
from_existing_pin=False,
|
||||
)
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
AUTO_FASTEST_MODE,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeQuotaResult:
|
||||
allowed: bool
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
self._thread = thread
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
self.thread = thread
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
|
||||
def _thread(
|
||||
*,
|
||||
search_space_id: int = 10,
|
||||
pinned_llm_config_id: int | None = None,
|
||||
pinned_auto_mode: str | None = None,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
search_space_id=search_space_id,
|
||||
pinned_llm_config_id=pinned_llm_config_id,
|
||||
pinned_auto_mode=pinned_auto_mode,
|
||||
pinned_at=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_first_turn_pins_one_model(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id in {-1, -2}
|
||||
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
|
||||
assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE
|
||||
assert session.thread.pinned_at is not None
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_turn_reuses_existing_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _must_not_call(*_args, **_kwargs):
|
||||
raise AssertionError("premium_get_usage should not be called for valid pin reuse")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_must_not_call,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
assert session.commit_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.resolved_tier == "free"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"},
|
||||
{"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=7,
|
||||
)
|
||||
assert result.resolved_llm_config_id == 7
|
||||
assert session.thread.pinned_llm_config_id is None
|
||||
assert session.thread.pinned_auto_mode is None
|
||||
assert session.thread.pinned_at is None
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(
|
||||
_thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
assert session.commit_count == 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue