mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-14 20:55:15 +02:00
feat(runtime-cooldown): implement Redis-based shared cooldown management for model selection
This commit is contained in:
parent
6d7732436d
commit
4a6a282a46
2 changed files with 242 additions and 26 deletions
|
|
@ -21,6 +21,7 @@ import time
|
|||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
|
@ -39,12 +40,16 @@ AUTO_MODE_ID = 0
|
|||
AUTO_PIN_HASH_NAMESPACE = "auto_fastest"
|
||||
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||
_HEALTHY_TTL_SECONDS = 45
|
||||
_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:"
|
||||
_REDIS_TIMEOUT_SECONDS = 0.2
|
||||
|
||||
# In-memory runtime cooldown map for configs that recently hard-failed at
|
||||
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
|
||||
# the same unhealthy config from being reselected immediately during repair.
|
||||
_runtime_cooldown_until: dict[int, float] = {}
|
||||
_runtime_cooldown_lock = threading.Lock()
|
||||
_runtime_cooldown_redis: redis.Redis | None = None
|
||||
_runtime_cooldown_redis_lock = threading.Lock()
|
||||
|
||||
# Short-TTL "recently healthy" cache for configs that just passed a runtime
|
||||
# preflight ping. Lets back-to-back turns on the same model skip the probe
|
||||
|
|
@ -87,6 +92,79 @@ def _is_runtime_cooled_down(config_id: int) -> bool:
|
|||
return config_id in _runtime_cooldown_until
|
||||
|
||||
|
||||
def _runtime_cooldown_redis_key(config_id: int) -> str:
    """Build the Redis key that holds the shared cooldown marker for a config.

    Args:
        config_id: Model/config identifier; coerced with ``int()`` before
            it is appended to the key prefix.

    Returns:
        ``_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX`` followed by the integer id.
    """
    normalized_id = int(config_id)
    return _RUNTIME_COOLDOWN_REDIS_KEY_PREFIX + str(normalized_id)
|
||||
|
||||
|
||||
def _get_runtime_cooldown_redis() -> redis.Redis:
    """Return the module-wide Redis client used for shared cooldown markers.

    The client is created on first use from ``config.REDIS_APP_URL`` and then
    cached in ``_runtime_cooldown_redis``. Creation happens while holding
    ``_runtime_cooldown_redis_lock`` and the cache is re-checked inside the
    lock, so only one client is constructed even if several threads arrive
    at the same time.

    Returns:
        The cached ``redis.Redis`` instance (string responses decoded,
        connect/read timeouts set to ``_REDIS_TIMEOUT_SECONDS``).
    """
    global _runtime_cooldown_redis

    # Fast path: a client already exists, no locking needed.
    existing = _runtime_cooldown_redis
    if existing is not None:
        return existing

    with _runtime_cooldown_redis_lock:
        # Re-check under the lock: another thread may have won the race.
        if _runtime_cooldown_redis is None:
            _runtime_cooldown_redis = redis.from_url(
                config.REDIS_APP_URL,
                decode_responses=True,
                socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
                socket_timeout=_REDIS_TIMEOUT_SECONDS,
            )
    return _runtime_cooldown_redis
|
||||
|
||||
|
||||
def _mark_shared_runtime_cooldown(
    config_id: int,
    *,
    reason: str,
    cooldown_seconds: int,
) -> None:
    """Best-effort write of a cooldown marker for ``config_id`` to Redis.

    The marker is stored under ``_runtime_cooldown_redis_key(config_id)`` with
    ``reason`` as its value and ``cooldown_seconds`` as its expiry, so it
    disappears on its own once the cooldown has elapsed.

    Args:
        config_id: Identifier of the config being cooled down.
        reason: Stored as the key's value (e.g. a short failure code).
        cooldown_seconds: Expiry in seconds. Non-positive values are skipped:
            Redis rejects ``SET ... EX`` with an expiry <= 0, which would
            otherwise surface here as a misleading "write failed" warning.

    Returns:
        None. Any exception is logged at warning level and swallowed so a
        Redis problem never breaks the caller.
    """
    try:
        ttl_seconds = int(cooldown_seconds)
        if ttl_seconds <= 0:
            # Nothing to share: a zero/negative expiry is not a valid Redis TTL.
            return
        _get_runtime_cooldown_redis().set(
            _runtime_cooldown_redis_key(config_id),
            reason,
            ex=ttl_seconds,
        )
    except Exception:
        logger.warning(
            "auto_pin_runtime_cooldown_redis_write_failed config_id=%s",
            config_id,
            exc_info=True,
        )
|
||||
|
||||
|
||||
def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]:
    """Return the subset of ``config_ids`` that has a cooldown marker in Redis.

    All ids are looked up with a single ``MGET``. If the lookup raises, the
    failure is logged and an empty set is returned, i.e. no config is treated
    as cooled down on the strength of Redis alone.

    Args:
        config_ids: Ids to check; duplicates are collapsed (first occurrence
            wins) and every id is coerced with ``int()``.

    Returns:
        The ids whose Redis key currently holds a value.
    """
    lookup_ids = list(dict.fromkeys(int(cid) for cid in config_ids))
    if not lookup_ids:
        return set()

    try:
        client = _get_runtime_cooldown_redis()
        raw_values = client.mget(
            [_runtime_cooldown_redis_key(cid) for cid in lookup_ids]
        )
    except Exception:
        logger.warning(
            "auto_pin_runtime_cooldown_redis_read_failed count=%s",
            len(lookup_ids),
            exc_info=True,
        )
        return set()

    # MGET answers positionally; a None entry means the key does not exist.
    cooled_down: set[int] = set()
    for cid, raw in zip(lookup_ids, raw_values, strict=False):
        if raw is not None:
            cooled_down.add(cid)
    return cooled_down
|
||||
|
||||
|
||||
def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None:
    """Remove shared cooldown markers from Redis (best effort).

    Args:
        config_id: When given, only that config's marker is deleted. When
            ``None``, every key starting with
            ``_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX`` is scanned and deleted.

    Returns:
        None. Any exception is logged at warning level and swallowed.
    """
    try:
        redis_client = _get_runtime_cooldown_redis()
        if config_id is None:
            # Collect all matching keys first, then remove them in one DELETE.
            matching_keys = [
                *redis_client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*")
            ]
            if matching_keys:
                redis_client.delete(*matching_keys)
        else:
            redis_client.delete(_runtime_cooldown_redis_key(config_id))
    except Exception:
        logger.warning(
            "auto_pin_runtime_cooldown_redis_clear_failed config_id=%s",
            config_id,
            exc_info=True,
        )
|
||||
|
||||
|
||||
def mark_runtime_cooldown(
|
||||
config_id: int,
|
||||
*,
|
||||
|
|
@ -105,6 +183,11 @@ def mark_runtime_cooldown(
|
|||
with _runtime_cooldown_lock:
|
||||
_runtime_cooldown_until[int(config_id)] = until
|
||||
_prune_runtime_cooldowns()
|
||||
_mark_shared_runtime_cooldown(
|
||||
int(config_id),
|
||||
reason=reason,
|
||||
cooldown_seconds=int(cooldown_seconds),
|
||||
)
|
||||
# A cooled cfg can never be "recently healthy"; drop any stale credit so
|
||||
# the next turn that resolves to it (after cooldown) re-runs preflight.
|
||||
clear_healthy(int(config_id))
|
||||
|
|
@ -121,8 +204,9 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None:
|
|||
with _runtime_cooldown_lock:
|
||||
if config_id is None:
|
||||
_runtime_cooldown_until.clear()
|
||||
return
|
||||
_runtime_cooldown_until.pop(int(config_id), None)
|
||||
else:
|
||||
_runtime_cooldown_until.pop(int(config_id), None)
|
||||
_clear_shared_runtime_cooldown(config_id)
|
||||
|
||||
|
||||
def _prune_healthy(now_ts: float | None = None) -> None:
|
||||
|
|
@ -205,6 +289,7 @@ def _global_candidates(
|
|||
*,
|
||||
capability: str = "chat",
|
||||
requires_image_input: bool = False,
|
||||
shared_cooled_down_ids: set[int] | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return Auto-eligible global virtual models.
|
||||
|
||||
|
|
@ -228,9 +313,14 @@ def _global_candidates(
|
|||
if _is_usable_global_config(cfg)
|
||||
}
|
||||
candidates: list[dict] = []
|
||||
shared_cooled_down_ids = shared_cooled_down_ids or set()
|
||||
for model in config.GLOBAL_MODELS:
|
||||
model_id = int(model.get("id", 0))
|
||||
if model_id >= 0 or _is_runtime_cooled_down(model_id):
|
||||
if (
|
||||
model_id >= 0
|
||||
or _is_runtime_cooled_down(model_id)
|
||||
or model_id in shared_cooled_down_ids
|
||||
):
|
||||
continue
|
||||
if not _has_capability(model, capability):
|
||||
continue
|
||||
|
|
@ -287,8 +377,12 @@ async def _db_candidates(
|
|||
.where(Model.enabled.is_(True), Connection.enabled.is_(True))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
models = result.scalars().all()
|
||||
shared_cooled_down_ids = _shared_runtime_cooled_down_ids(
|
||||
[int(model.id) for model in models]
|
||||
)
|
||||
candidates: list[dict] = []
|
||||
for model in result.scalars().all():
|
||||
for model in models:
|
||||
conn = model.connection
|
||||
if not conn:
|
||||
continue
|
||||
|
|
@ -303,7 +397,7 @@ async def _db_candidates(
|
|||
if requires_image_input and not _has_capability(model, "vision"):
|
||||
continue
|
||||
model_id = int(model.id)
|
||||
if _is_runtime_cooled_down(model_id):
|
||||
if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids:
|
||||
continue
|
||||
catalog = model.catalog or {}
|
||||
candidates.append(
|
||||
|
|
@ -337,6 +431,12 @@ async def auto_model_candidates(
|
|||
exclude_model_ids: set[int] | None = None,
|
||||
) -> list[dict]:
|
||||
excluded_ids = {int(mid) for mid in (exclude_model_ids or set())}
|
||||
global_ids = [
|
||||
int(model.get("id", 0))
|
||||
for model in config.GLOBAL_MODELS
|
||||
if int(model.get("id", 0)) < 0
|
||||
]
|
||||
shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids)
|
||||
db_candidates = await _db_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -348,6 +448,7 @@ async def auto_model_candidates(
|
|||
*_global_candidates(
|
||||
capability=capability,
|
||||
requires_image_input=requires_image_input,
|
||||
shared_cooled_down_ids=shared_global_cooled_down_ids,
|
||||
),
|
||||
*db_candidates,
|
||||
]
|
||||
|
|
@ -358,16 +459,6 @@ def _tier_of(cfg: dict) -> str:
|
|||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
||||
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
    """Return True for the operator-preferred premium Auto model.

    A config qualifies only when all of the following hold: its ``source`` is
    ``"global"``, ``_tier_of(cfg)`` is ``"premium"``, and its ``provider`` /
    ``model_name`` (compared case-insensitively) are ``azure`` / ``gpt-5.4``.
    """
    if cfg.get("source") != "global":
        return False
    if _tier_of(cfg) != "premium":
        return False
    if str(cfg.get("provider", "")).lower() != "azure":
        return False
    return str(cfg.get("model_name", "")).lower() == "gpt-5.4"
|
||||
|
||||
|
||||
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
||||
"""Pick a config with quality-first ranking + deterministic spread.
|
||||
|
||||
|
|
@ -546,10 +637,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
byok_candidates = [c for c in candidates if _tier_of(c) == "byok"]
|
||||
if premium_eligible:
|
||||
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
|
||||
preferred_premium = [
|
||||
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
|
||||
]
|
||||
eligible = preferred_premium or premium_candidates or byok_candidates
|
||||
eligible = premium_candidates or byok_candidates
|
||||
else:
|
||||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,39 @@ from app.services.auto_model_pin_service import (
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self):
|
||||
self.values: dict[str, str] = {}
|
||||
self.ttls: dict[str, int] = {}
|
||||
|
||||
def set(self, key: str, value: str, *, ex: int | None = None):
|
||||
self.values[key] = value
|
||||
if ex is not None:
|
||||
self.ttls[key] = ex
|
||||
return True
|
||||
|
||||
def mget(self, keys: list[str]):
|
||||
return [self.values.get(key) for key in keys]
|
||||
|
||||
def delete(self, *keys: str):
|
||||
removed = 0
|
||||
for key in keys:
|
||||
if key in self.values:
|
||||
removed += 1
|
||||
self.values.pop(key, None)
|
||||
self.ttls.pop(key, None)
|
||||
return removed
|
||||
|
||||
def scan_iter(self, pattern: str):
|
||||
prefix = pattern.removesuffix("*")
|
||||
return (key for key in list(self.values) if key.startswith(prefix))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_runtime_cooldown_map():
|
||||
def _clear_runtime_cooldown_map(monkeypatch):
|
||||
import app.services.auto_model_pin_service as svc
|
||||
|
||||
monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis())
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
yield
|
||||
|
|
@ -205,7 +236,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
||||
async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model(
|
||||
monkeypatch,
|
||||
):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
|
|
@ -233,12 +266,39 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -3,
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-5.4",
|
||||
"litellm_provider": "anthropic",
|
||||
"model_name": "claude-opus",
|
||||
"api_key": "k3",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "B",
|
||||
"quality_score": 100,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 99,
|
||||
},
|
||||
{
|
||||
"id": -4,
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-5.3",
|
||||
"api_key": "k4",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 98,
|
||||
},
|
||||
{
|
||||
"id": -5,
|
||||
"litellm_provider": "gemini",
|
||||
"model_name": "gemini-3-pro",
|
||||
"api_key": "k5",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 97,
|
||||
},
|
||||
{
|
||||
"id": -6,
|
||||
"litellm_provider": "xai",
|
||||
"model_name": "grok-5",
|
||||
"api_key": "k6",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 96,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
|
@ -258,7 +318,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6}
|
||||
assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
|
|
@ -932,6 +992,74 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
assert result.from_existing_pin is False
|
||||
|
||||
|
||||
def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch):
    """Marking a cooldown stores the reason and TTL under the shared Redis key."""
    import app.services.auto_model_pin_service as svc

    mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123)

    # The service module's client is inspected directly for the written entry.
    expected_key = "auto:cooldown:llm:-9"
    shared_store = svc._runtime_cooldown_redis
    assert shared_store.values[expected_key] == "provider_rate_limited"
    assert shared_store.ttls[expected_key] == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch):
    """A cooldown present only in Redis must invalidate an existing local pin.

    The thread is pinned to config ``-1``; a marker for ``-1`` is written
    straight into the shared store (as another worker would), without touching
    this process's in-memory cooldown map. Resolution must then move to ``-2``.
    """
    import app.services.auto_model_pin_service as svc
    from app.config import config

    def _free_openrouter_cfg(config_id: int, model_name: str, quality_score: int) -> dict:
        # Both candidates differ only in id, model name and quality score.
        return {
            "id": config_id,
            "litellm_provider": "openrouter",
            "model_name": model_name,
            "api_key": "k",
            "billing_tier": "free",
            "auto_pin_tier": "C",
            "quality_score": quality_score,
            "health_gated": False,
        }

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    global_configs = [
        _free_openrouter_cfg(-1, "google/gemma-4-26b-a4b-it:free", 90),
        _free_openrouter_cfg(-2, "google/gemini-2.5-flash:free", 80),
    ]
    _set_global_llm_configs(monkeypatch, config, global_configs)

    # Simulate the other worker: write the marker directly to the shared store.
    svc._runtime_cooldown_redis.set(
        "auto:cooldown:llm:-1",
        "provider_rate_limited",
        ex=600,
    )

    async def _quota_blocked(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
        _quota_blocked,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
||||
from app.config import config
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue