mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat(auto_pin): quality-aware tier-locked selection with health gate
This commit is contained in:
parent
1eedcaa551
commit
4bef75d298
2 changed files with 387 additions and 5 deletions
|
|
@ -24,6 +24,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import NewChatThread
|
from app.db import NewChatThread
|
||||||
|
from app.services.quality_score import _QUALITY_TOP_K
|
||||||
from app.services.token_quota_service import TokenQuotaService
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -49,8 +50,16 @@ def _is_usable_global_config(cfg: dict) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _global_candidates() -> list[dict]:
    """Collect the global LLM configs that Auto mode may choose from.

    A config is kept only when ``_is_usable_global_config`` accepts it and
    its ``health_gated`` flag is absent or falsy. Gated configs (flagged
    when the best non-null OpenRouter uptime falls below
    ``_HEALTH_GATE_UPTIME_PCT``) are dropped so a chronically broken
    provider can never become a thread's pin.

    Returns:
        The surviving configs ordered by ascending integer ``id``; a
        missing ``id`` sorts as ``0``.
    """
    usable: list[dict] = []
    for cfg in config.GLOBAL_LLM_CONFIGS:
        if not _is_usable_global_config(cfg):
            continue
        if cfg.get("health_gated"):
            continue
        usable.append(cfg)

    # list.sort is stable, so configs with equal ids keep their config order.
    usable.sort(key=lambda c: int(c.get("id", 0)))
    return usable
|
||||||
|
|
||||||
|
|
@ -59,10 +68,26 @@ def _tier_of(cfg: dict) -> str:
|
||||||
return str(cfg.get("billing_tier", "free")).lower()
|
return str(cfg.get("billing_tier", "free")).lower()
|
||||||
|
|
||||||
|
|
||||||
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
    """Pick a config with quality-first ranking + deterministic spread.

    Tier policy is lock-first: prefer Tier A (operator-curated YAML)
    cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
    Tier A cfg is eligible after upstream filters. A cfg without an
    ``auto_pin_tier`` key is treated as Tier A. Within the locked pool,
    sort by ``quality_score`` (descending, missing/falsy scores count as
    0) and pick from the top-K via ``SHA256(thread_id)`` so different new
    threads spread across the best models without ever picking a
    low-ranked one.

    Args:
        eligible: Non-empty list of candidate cfg dicts.
        thread_id: Thread identifier used to seed the deterministic pick.

    Returns:
        ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
        structured logging in the caller.

    Raises:
        ValueError: If ``eligible`` is empty.
    """
    if not eligible:
        # Fail with a clear message instead of a modulo-by-zero further down.
        raise ValueError("_select_pin requires at least one eligible config")

    tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
    pool = tier_a or eligible

    # ``reverse=True`` keeps the sort stable, so cfgs with equal scores stay
    # in their incoming order.
    ranked = sorted(
        pool,
        key=lambda c: int(c.get("quality_score") or 0),
        reverse=True,
    )

    # Clamp K to at least 1: a zero K would leave an empty slice (and a
    # modulo-by-zero below), and a negative K would silently slice off the
    # tail of the pool instead of limiting it to the best entries.
    top_k = ranked[: max(1, _QUALITY_TOP_K)]

    digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
    idx = int.from_bytes(digest[:8], "big") % len(top_k)
    return top_k[idx], len(top_k)
|
||||||
|
|
||||||
|
|
||||||
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||||
|
|
@ -150,6 +175,15 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
pinned_id,
|
pinned_id,
|
||||||
_tier_of(pinned_cfg),
|
_tier_of(pinned_cfg),
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
|
||||||
|
"auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
|
||||||
|
thread_id,
|
||||||
|
pinned_id,
|
||||||
|
_tier_of(pinned_cfg),
|
||||||
|
pinned_cfg.get("auto_pin_tier", "?"),
|
||||||
|
int(pinned_cfg.get("quality_score") or 0),
|
||||||
|
)
|
||||||
return AutoPinResolution(
|
return AutoPinResolution(
|
||||||
resolved_llm_config_id=int(pinned_id),
|
resolved_llm_config_id=int(pinned_id),
|
||||||
resolved_tier=_tier_of(pinned_cfg),
|
resolved_tier=_tier_of(pinned_cfg),
|
||||||
|
|
@ -176,7 +210,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_cfg = _deterministic_pick(eligible, thread_id)
|
selected_cfg, top_k_size = _select_pin(eligible, thread_id)
|
||||||
selected_id = int(selected_cfg["id"])
|
selected_id = int(selected_cfg["id"])
|
||||||
selected_tier = _tier_of(selected_cfg)
|
selected_tier = _tier_of(selected_cfg)
|
||||||
|
|
||||||
|
|
@ -211,6 +245,18 @@ async def resolve_or_get_pinned_llm_config_id(
|
||||||
selected_tier,
|
selected_tier,
|
||||||
premium_eligible,
|
premium_eligible,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
|
||||||
|
"auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
|
||||||
|
thread_id,
|
||||||
|
selected_id,
|
||||||
|
selected_tier,
|
||||||
|
selected_cfg.get("auto_pin_tier", "?"),
|
||||||
|
int(selected_cfg.get("quality_score") or 0),
|
||||||
|
top_k_size,
|
||||||
|
)
|
||||||
|
|
||||||
return AutoPinResolution(
|
return AutoPinResolution(
|
||||||
resolved_llm_config_id=selected_id,
|
resolved_llm_config_id=selected_id,
|
||||||
resolved_tier=selected_tier,
|
resolved_tier=selected_tier,
|
||||||
|
|
|
||||||
|
|
@ -365,3 +365,339 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
||||||
assert result.resolved_llm_config_id == -2
|
assert result.resolved_llm_config_id == -2
|
||||||
assert session.thread.pinned_llm_config_id == -2
|
assert session.thread.pinned_llm_config_id == -2
|
||||||
assert session.commit_count == 1
|
assert session.commit_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quality-aware pin selection (Auto Fastest upgrade)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
    """A ``health_gated`` cfg is never selected, even with the top score.

    Two free Tier C cfgs are offered: the gated one scores 95, the healthy
    one 60. The resolver must return the healthy cfg (id ``-2``).
    """
    from app.config import config

    gated_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "venice/dead-model",
        "api_key": "k1",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 95,
        "health_gated": True,
    }
    healthy_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "google/gemini-flash",
        "api_key": "k1",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 60,
        "health_gated": False,
    }

    session = _FakeSession(_thread())
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [gated_cfg, healthy_cfg])

    async def _deny_premium(*_args, **_kwargs):
        # Quota stub: always reports allowed=False.
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _deny_premium,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
    """Tier A wins over a higher-scored Tier B cfg for a premium user.

    With the quota stub reporting ``allowed=True``, both premium cfgs are
    on offer; the Tier A cfg (score 70) must be chosen over the Tier B
    cfg (score 95).
    """
    from app.config import config

    tier_a_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k-yaml",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 70,
        "health_gated": False,
    }
    tier_b_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-5",
        "api_key": "k-or",
        "billing_tier": "premium",
        "auto_pin_tier": "B",
        "quality_score": 95,
        "health_gated": False,
    }

    session = _FakeSession(_thread())
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [tier_a_cfg, tier_b_cfg])

    async def _grant_premium(*_args, **_kwargs):
        # Quota stub: always reports allowed=True.
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _grant_premium,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -1
    assert resolution.resolved_tier == "premium"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch):
    """Selection falls through to Tier C when no Tier A cfg is selectable.

    The only Tier A cfg is premium and the quota stub reports
    ``allowed=False``, so the free Tier C cfg (id ``-2``) must be picked.
    """
    from app.config import config

    premium_tier_a_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k-yaml",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 100,
        "health_gated": False,
    }
    free_tier_c_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "google/gemini-flash:free",
        "api_key": "k-or",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 60,
        "health_gated": False,
    }

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [premium_tier_a_cfg, free_tier_c_cfg],
    )

    async def _deny_premium(*_args, **_kwargs):
        # Quota stub: always reports allowed=False.
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _deny_premium,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_top_k_picks_only_high_score_models(monkeypatch):
    """Thread ids spread across the top-scored cfgs and skip the trap.

    Five Tier A cfgs score 90 and one Tier A "trap" cfg scores 10. For
    thread ids 1..49 the resolver must always return one of the five and
    must return more than one distinct cfg overall.
    """
    from app.config import config

    # Five high-quality Tier A cfgs with ids -1 .. -5.
    # NOTE(review): the test presumably relies on the top-K size being at
    # most 5, otherwise the trap would enter the pick window - confirm
    # against _QUALITY_TOP_K.
    high_score_cfgs = []
    for i in range(1, 6):
        high_score_cfgs.append(
            {
                "id": -i,
                "provider": "AZURE_OPENAI",
                "model_name": f"gpt-x-{i}",
                "api_key": "k",
                "billing_tier": "premium",
                "auto_pin_tier": "A",
                "quality_score": 90,
                "health_gated": False,
            }
        )

    low_score_trap = {
        "id": -99,
        "provider": "AZURE_OPENAI",
        "model_name": "tiny-legacy",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 10,
        "health_gated": False,
    }
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [*high_score_cfgs, low_score_trap],
    )

    async def _grant_premium(*_args, **_kwargs):
        # Quota stub: always reports allowed=True.
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _grant_premium,
    )

    high_score_ids = {cfg["id"] for cfg in high_score_cfgs}
    picked_ids = set()
    for thread_id in range(1, 50):
        # Fresh unpinned thread per iteration so every call selects anew.
        session = _FakeSession(_thread())
        resolution = await resolve_or_get_pinned_llm_config_id(
            session,
            thread_id=thread_id,
            search_space_id=10,
            user_id="00000000-0000-0000-0000-000000000001",
            selected_llm_config_id=0,
        )
        picked_ids.add(resolution.resolved_llm_config_id)
        assert resolution.resolved_llm_config_id != -99, (
            "low-score trap cfg should never be picked"
        )
        assert resolution.resolved_llm_config_id in high_score_ids

    # Spread across at least a couple of top-K cfgs.
    assert len(picked_ids) > 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
    """An already pinned cfg that becomes ``health_gated`` is not reused.

    Despite the "survives" in the test name, the expected behaviour is a
    repair: the thread starts pinned to cfg ``-1``, which is now gated, so
    the resolver must move it to the healthy cfg ``-2`` and report
    ``from_existing_pin`` as ``False``.

    This keeps a known-broken model from staying in use just because the
    thread happened to be pinned to it before the gate fired.
    """
    from app.config import config

    gated_pinned_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "venice/dead-model",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "B",
        "quality_score": 50,
        "health_gated": True,
    }
    healthy_cfg = {
        "id": -2,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 90,
        "health_gated": False,
    }

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [gated_pinned_cfg, healthy_cfg],
    )

    async def _grant_premium(*_args, **_kwargs):
        # Quota stub: always reports allowed=True.
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _grant_premium,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -2
    assert resolution.from_existing_pin is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
    """A healthy existing pin is reused without re-running selection.

    The thread is pinned to cfg ``-1`` (score 50) while cfg ``-2`` scores
    99. The resolver must keep ``-1``, must not call the premium quota
    lookup, and must not commit anything to the session.
    """
    from app.config import config

    pinned_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 50,  # lower than -2
        "health_gated": False,
    }
    better_cfg = {
        "id": -2,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5-pro",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 99,
        "health_gated": False,
    }

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [pinned_cfg, better_cfg])

    async def _fail_if_called(*_args, **_kwargs):
        # Any quota lookup means the reuse short-circuit was bypassed.
        raise AssertionError("premium_get_usage should not run on pin reuse")

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _fail_if_called,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -1
    assert resolution.from_existing_pin is True
    assert session.commit_count == 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue