feat(runtime-cooldown): implement Redis-based shared cooldown management for model selection

This commit is contained in:
Anish Sarkar 2026-06-13 13:53:01 +05:30
parent 6d7732436d
commit 4a6a282a46
2 changed files with 242 additions and 26 deletions

View file

@ -21,6 +21,7 @@ import time
from dataclasses import dataclass
from uuid import UUID
import redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@ -39,12 +40,16 @@ AUTO_MODE_ID = 0
AUTO_PIN_HASH_NAMESPACE = "auto_fastest"
_RUNTIME_COOLDOWN_SECONDS = 600
_HEALTHY_TTL_SECONDS = 45
_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:"
_REDIS_TIMEOUT_SECONDS = 0.2
# In-memory runtime cooldown map for configs that recently hard-failed at
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
# the same unhealthy config from being reselected immediately during repair.
_runtime_cooldown_until: dict[int, float] = {}
_runtime_cooldown_lock = threading.Lock()
_runtime_cooldown_redis: redis.Redis | None = None
_runtime_cooldown_redis_lock = threading.Lock()
# Short-TTL "recently healthy" cache for configs that just passed a runtime
# preflight ping. Lets back-to-back turns on the same model skip the probe
@ -87,6 +92,79 @@ def _is_runtime_cooled_down(config_id: int) -> bool:
return config_id in _runtime_cooldown_until
def _runtime_cooldown_redis_key(config_id: int) -> str:
return f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}{int(config_id)}"
def _get_runtime_cooldown_redis() -> redis.Redis:
global _runtime_cooldown_redis
if _runtime_cooldown_redis is None:
with _runtime_cooldown_redis_lock:
if _runtime_cooldown_redis is None:
_runtime_cooldown_redis = redis.from_url(
config.REDIS_APP_URL,
decode_responses=True,
socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
socket_timeout=_REDIS_TIMEOUT_SECONDS,
)
return _runtime_cooldown_redis
def _mark_shared_runtime_cooldown(
config_id: int,
*,
reason: str,
cooldown_seconds: int,
) -> None:
try:
_get_runtime_cooldown_redis().set(
_runtime_cooldown_redis_key(config_id),
reason,
ex=int(cooldown_seconds),
)
except Exception:
logger.warning(
"auto_pin_runtime_cooldown_redis_write_failed config_id=%s",
config_id,
exc_info=True,
)
def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]:
unique_ids = list(dict.fromkeys(int(cid) for cid in config_ids))
if not unique_ids:
return set()
try:
values = _get_runtime_cooldown_redis().mget(
[_runtime_cooldown_redis_key(cid) for cid in unique_ids]
)
except Exception:
logger.warning(
"auto_pin_runtime_cooldown_redis_read_failed count=%s",
len(unique_ids),
exc_info=True,
)
return set()
return {cid for cid, value in zip(unique_ids, values, strict=False) if value is not None}
def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None:
try:
client = _get_runtime_cooldown_redis()
if config_id is not None:
client.delete(_runtime_cooldown_redis_key(config_id))
return
keys = list(client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*"))
if keys:
client.delete(*keys)
except Exception:
logger.warning(
"auto_pin_runtime_cooldown_redis_clear_failed config_id=%s",
config_id,
exc_info=True,
)
def mark_runtime_cooldown(
config_id: int,
*,
@ -105,6 +183,11 @@ def mark_runtime_cooldown(
with _runtime_cooldown_lock:
_runtime_cooldown_until[int(config_id)] = until
_prune_runtime_cooldowns()
_mark_shared_runtime_cooldown(
int(config_id),
reason=reason,
cooldown_seconds=int(cooldown_seconds),
)
# A cooled cfg can never be "recently healthy"; drop any stale credit so
# the next turn that resolves to it (after cooldown) re-runs preflight.
clear_healthy(int(config_id))
@ -121,8 +204,9 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None:
with _runtime_cooldown_lock:
if config_id is None:
_runtime_cooldown_until.clear()
return
_runtime_cooldown_until.pop(int(config_id), None)
else:
_runtime_cooldown_until.pop(int(config_id), None)
_clear_shared_runtime_cooldown(config_id)
def _prune_healthy(now_ts: float | None = None) -> None:
@ -205,6 +289,7 @@ def _global_candidates(
*,
capability: str = "chat",
requires_image_input: bool = False,
shared_cooled_down_ids: set[int] | None = None,
) -> list[dict]:
"""Return Auto-eligible global virtual models.
@ -228,9 +313,14 @@ def _global_candidates(
if _is_usable_global_config(cfg)
}
candidates: list[dict] = []
shared_cooled_down_ids = shared_cooled_down_ids or set()
for model in config.GLOBAL_MODELS:
model_id = int(model.get("id", 0))
if model_id >= 0 or _is_runtime_cooled_down(model_id):
if (
model_id >= 0
or _is_runtime_cooled_down(model_id)
or model_id in shared_cooled_down_ids
):
continue
if not _has_capability(model, capability):
continue
@ -287,8 +377,12 @@ async def _db_candidates(
.where(Model.enabled.is_(True), Connection.enabled.is_(True))
)
result = await session.execute(stmt)
models = result.scalars().all()
shared_cooled_down_ids = _shared_runtime_cooled_down_ids(
[int(model.id) for model in models]
)
candidates: list[dict] = []
for model in result.scalars().all():
for model in models:
conn = model.connection
if not conn:
continue
@ -303,7 +397,7 @@ async def _db_candidates(
if requires_image_input and not _has_capability(model, "vision"):
continue
model_id = int(model.id)
if _is_runtime_cooled_down(model_id):
if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids:
continue
catalog = model.catalog or {}
candidates.append(
@ -337,6 +431,12 @@ async def auto_model_candidates(
exclude_model_ids: set[int] | None = None,
) -> list[dict]:
excluded_ids = {int(mid) for mid in (exclude_model_ids or set())}
global_ids = [
int(model.get("id", 0))
for model in config.GLOBAL_MODELS
if int(model.get("id", 0)) < 0
]
shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids)
db_candidates = await _db_candidates(
session,
search_space_id=search_space_id,
@ -348,6 +448,7 @@ async def auto_model_candidates(
*_global_candidates(
capability=capability,
requires_image_input=requires_image_input,
shared_cooled_down_ids=shared_global_cooled_down_ids,
),
*db_candidates,
]
@ -358,16 +459,6 @@ def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
"""Return True for the operator-preferred premium Auto model."""
return (
cfg.get("source") == "global"
and _tier_of(cfg) == "premium"
and str(cfg.get("provider", "")).lower() == "azure"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
)
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread.
@ -546,10 +637,7 @@ async def resolve_or_get_pinned_llm_config_id(
byok_candidates = [c for c in candidates if _tier_of(c) == "byok"]
if premium_eligible:
premium_candidates = [c for c in candidates if _tier_of(c) == "premium"]
preferred_premium = [
c for c in premium_candidates if _is_preferred_premium_auto_config(c)
]
eligible = preferred_premium or premium_candidates or byok_candidates
eligible = premium_candidates or byok_candidates
else:
eligible = [c for c in candidates if _tier_of(c) != "premium"]

View file

@ -17,8 +17,39 @@ from app.services.auto_model_pin_service import (
pytestmark = pytest.mark.unit
class _FakeRedis:
def __init__(self):
self.values: dict[str, str] = {}
self.ttls: dict[str, int] = {}
def set(self, key: str, value: str, *, ex: int | None = None):
self.values[key] = value
if ex is not None:
self.ttls[key] = ex
return True
def mget(self, keys: list[str]):
return [self.values.get(key) for key in keys]
def delete(self, *keys: str):
removed = 0
for key in keys:
if key in self.values:
removed += 1
self.values.pop(key, None)
self.ttls.pop(key, None)
return removed
def scan_iter(self, pattern: str):
prefix = pattern.removesuffix("*")
return (key for key in list(self.values) if key.startswith(prefix))
@pytest.fixture(autouse=True)
def _clear_runtime_cooldown_map():
def _clear_runtime_cooldown_map(monkeypatch):
import app.services.auto_model_pin_service as svc
monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis())
clear_runtime_cooldown()
clear_healthy()
yield
@ -205,7 +236,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
@pytest.mark.asyncio
async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model(
monkeypatch,
):
from app.config import config
session = _FakeSession(_thread())
@ -233,12 +266,39 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
},
{
"id": -3,
"litellm_provider": "openrouter",
"model_name": "openai/gpt-5.4",
"litellm_provider": "anthropic",
"model_name": "claude-opus",
"api_key": "k3",
"billing_tier": "premium",
"auto_pin_tier": "B",
"quality_score": 100,
"auto_pin_tier": "A",
"quality_score": 99,
},
{
"id": -4,
"litellm_provider": "openai",
"model_name": "gpt-5.3",
"api_key": "k4",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 98,
},
{
"id": -5,
"litellm_provider": "gemini",
"model_name": "gemini-3-pro",
"api_key": "k5",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 97,
},
{
"id": -6,
"litellm_provider": "xai",
"model_name": "grok-5",
"api_key": "k6",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score": 96,
},
],
)
@ -258,7 +318,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6}
assert result.resolved_tier == "premium"
@ -932,6 +992,74 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
assert result.from_existing_pin is False
def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch):
import app.services.auto_model_pin_service as svc
mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123)
redis_client = svc._runtime_cooldown_redis
assert redis_client.values["auto:cooldown:llm:-9"] == "provider_rate_limited"
assert redis_client.ttls["auto:cooldown:llm:-9"] == 123
@pytest.mark.asyncio
async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch):
"""A Redis cooldown written by another worker should invalidate local pins."""
import app.services.auto_model_pin_service as svc
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
_set_global_llm_configs(
monkeypatch,
config,
[
{
"id": -1,
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 90,
"health_gated": False,
},
{
"id": -2,
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 80,
"health_gated": False,
},
],
)
svc._runtime_cooldown_redis.set(
"auto:cooldown:llm:-1",
"provider_rate_limited",
ex=600,
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage",
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False
@pytest.mark.asyncio
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
from app.config import config