From 4a6a282a469b97e230501d66c6a47f889ce17fad Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 13:53:01 +0530 Subject: [PATCH] feat(runtime-cooldown): implement Redis-based shared cooldown management for model selection --- .../app/services/auto_model_pin_service.py | 126 +++++++++++++--- .../services/test_auto_model_pin_service.py | 142 +++++++++++++++++- 2 files changed, 242 insertions(+), 26 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index dfd7c7be3..2770a99aa 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -21,6 +21,7 @@ import time from dataclasses import dataclass from uuid import UUID +import redis from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -39,12 +40,16 @@ AUTO_MODE_ID = 0 AUTO_PIN_HASH_NAMESPACE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 _HEALTHY_TTL_SECONDS = 45 +_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX = "auto:cooldown:llm:" +_REDIS_TIMEOUT_SECONDS = 0.2 # In-memory runtime cooldown map for configs that recently hard-failed at # provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps # the same unhealthy config from being reselected immediately during repair. _runtime_cooldown_until: dict[int, float] = {} _runtime_cooldown_lock = threading.Lock() +_runtime_cooldown_redis: redis.Redis | None = None +_runtime_cooldown_redis_lock = threading.Lock() # Short-TTL "recently healthy" cache for configs that just passed a runtime # preflight ping. Lets back-to-back turns on the same model skip the probe @@ -87,6 +92,79 @@ def _is_runtime_cooled_down(config_id: int) -> bool: return config_id in _runtime_cooldown_until +def _runtime_cooldown_redis_key(config_id: int) -> str: + return f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}{int(config_id)}" + + +def _get_runtime_cooldown_redis() -> redis.Redis: + global _runtime_cooldown_redis + if _runtime_cooldown_redis is None: + with _runtime_cooldown_redis_lock: + if _runtime_cooldown_redis is None: + _runtime_cooldown_redis = redis.from_url( + config.REDIS_APP_URL, + decode_responses=True, + socket_connect_timeout=_REDIS_TIMEOUT_SECONDS, + socket_timeout=_REDIS_TIMEOUT_SECONDS, + ) + return _runtime_cooldown_redis + + +def _mark_shared_runtime_cooldown( + config_id: int, + *, + reason: str, + cooldown_seconds: int, +) -> None: + try: + _get_runtime_cooldown_redis().set( + _runtime_cooldown_redis_key(config_id), + reason, + ex=int(cooldown_seconds), + ) + except Exception: + logger.warning( + "auto_pin_runtime_cooldown_redis_write_failed config_id=%s", + config_id, + exc_info=True, + ) + + +def _shared_runtime_cooled_down_ids(config_ids: list[int]) -> set[int]: + unique_ids = list(dict.fromkeys(int(cid) for cid in config_ids)) + if not unique_ids: + return set() + try: + values = _get_runtime_cooldown_redis().mget( + [_runtime_cooldown_redis_key(cid) for cid in unique_ids] + ) + except Exception: + logger.warning( + "auto_pin_runtime_cooldown_redis_read_failed count=%s", + len(unique_ids), + exc_info=True, + ) + return set() + return {cid for cid, value in zip(unique_ids, values, strict=False) if value is not None} + + +def _clear_shared_runtime_cooldown(config_id: int | None = None) -> None: + try: + client = _get_runtime_cooldown_redis() + if config_id is not None: + client.delete(_runtime_cooldown_redis_key(config_id)) + return + keys = list(client.scan_iter(f"{_RUNTIME_COOLDOWN_REDIS_KEY_PREFIX}*")) + if keys: + client.delete(*keys) + except Exception: + logger.warning( + "auto_pin_runtime_cooldown_redis_clear_failed config_id=%s", + config_id, + exc_info=True, + ) + + def mark_runtime_cooldown( config_id: int, *, @@ -105,6 +183,11 @@ def mark_runtime_cooldown( with _runtime_cooldown_lock: _runtime_cooldown_until[int(config_id)] = until _prune_runtime_cooldowns() + _mark_shared_runtime_cooldown( + int(config_id), + reason=reason, + cooldown_seconds=int(cooldown_seconds), + ) # A cooled cfg can never be "recently healthy"; drop any stale credit so # the next turn that resolves to it (after cooldown) re-runs preflight. clear_healthy(int(config_id)) @@ -121,8 +204,9 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None: with _runtime_cooldown_lock: if config_id is None: _runtime_cooldown_until.clear() - return - _runtime_cooldown_until.pop(int(config_id), None) + else: + _runtime_cooldown_until.pop(int(config_id), None) + _clear_shared_runtime_cooldown(config_id) def _prune_healthy(now_ts: float | None = None) -> None: @@ -205,6 +289,7 @@ def _global_candidates( *, capability: str = "chat", requires_image_input: bool = False, + shared_cooled_down_ids: set[int] | None = None, ) -> list[dict]: """Return Auto-eligible global virtual models. @@ -228,9 +313,14 @@ def _global_candidates( if _is_usable_global_config(cfg) } candidates: list[dict] = [] + shared_cooled_down_ids = shared_cooled_down_ids or set() for model in config.GLOBAL_MODELS: model_id = int(model.get("id", 0)) - if model_id >= 0 or _is_runtime_cooled_down(model_id): + if ( + model_id >= 0 + or _is_runtime_cooled_down(model_id) + or model_id in shared_cooled_down_ids + ): continue if not _has_capability(model, capability): continue @@ -287,8 +377,12 @@ async def _db_candidates( .where(Model.enabled.is_(True), Connection.enabled.is_(True)) ) result = await session.execute(stmt) + models = result.scalars().all() + shared_cooled_down_ids = _shared_runtime_cooled_down_ids( + [int(model.id) for model in models] + ) candidates: list[dict] = [] - for model in result.scalars().all(): + for model in models: conn = model.connection if not conn: continue @@ -303,7 +397,7 @@ async def _db_candidates( if requires_image_input and not _has_capability(model, "vision"): continue model_id = int(model.id) - if _is_runtime_cooled_down(model_id): + if _is_runtime_cooled_down(model_id) or model_id in shared_cooled_down_ids: continue catalog = model.catalog or {} candidates.append( @@ -337,6 +431,12 @@ async def auto_model_candidates( exclude_model_ids: set[int] | None = None, ) -> list[dict]: excluded_ids = {int(mid) for mid in (exclude_model_ids or set())} + global_ids = [ + int(model.get("id", 0)) + for model in config.GLOBAL_MODELS + if int(model.get("id", 0)) < 0 + ] + shared_global_cooled_down_ids = _shared_runtime_cooled_down_ids(global_ids) db_candidates = await _db_candidates( session, search_space_id=search_space_id, @@ -348,6 +448,7 @@ async def auto_model_candidates( *_global_candidates( capability=capability, requires_image_input=requires_image_input, + shared_cooled_down_ids=shared_global_cooled_down_ids, ), *db_candidates, ] @@ -358,16 +459,6 @@ def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() -def _is_preferred_premium_auto_config(cfg: dict) -> bool: - """Return True for the operator-preferred premium Auto model.""" - return ( - cfg.get("source") == "global" - and _tier_of(cfg) == "premium" - and str(cfg.get("provider", "")).lower() == "azure" - and str(cfg.get("model_name", "")).lower() == "gpt-5.4" - ) - - def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: """Pick a config with quality-first ranking + deterministic spread. @@ -546,10 +637,7 @@ async def resolve_or_get_pinned_llm_config_id( byok_candidates = [c for c in candidates if _tier_of(c) == "byok"] if premium_eligible: premium_candidates = [c for c in candidates if _tier_of(c) == "premium"] - preferred_premium = [ - c for c in premium_candidates if _is_preferred_premium_auto_config(c) - ] - eligible = preferred_premium or premium_candidates or byok_candidates + eligible = premium_candidates or byok_candidates else: eligible = [c for c in candidates if _tier_of(c) != "premium"] diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d7c12a6e0..a53cd78b7 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -17,8 +17,39 @@ from app.services.auto_model_pin_service import ( pytestmark = pytest.mark.unit +class _FakeRedis: + def __init__(self): + self.values: dict[str, str] = {} + self.ttls: dict[str, int] = {} + + def set(self, key: str, value: str, *, ex: int | None = None): + self.values[key] = value + if ex is not None: + self.ttls[key] = ex + return True + + def mget(self, keys: list[str]): + return [self.values.get(key) for key in keys] + + def delete(self, *keys: str): + removed = 0 + for key in keys: + if key in self.values: + removed += 1 + self.values.pop(key, None) + self.ttls.pop(key, None) + return removed + + def scan_iter(self, pattern: str): + prefix = pattern.removesuffix("*") + return (key for key in list(self.values) if key.startswith(prefix)) + + @pytest.fixture(autouse=True) -def _clear_runtime_cooldown_map(): +def _clear_runtime_cooldown_map(monkeypatch): + import app.services.auto_model_pin_service as svc + + monkeypatch.setattr(svc, "_runtime_cooldown_redis", _FakeRedis()) clear_runtime_cooldown() clear_healthy() yield @@ -205,7 +236,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): @pytest.mark.asyncio -async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): +async def test_premium_eligible_auto_uses_quality_pool_not_single_preferred_model( + monkeypatch, +): from app.config import config session = _FakeSession(_thread()) @@ -233,12 +266,39 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -3, - "litellm_provider": "openrouter", - "model_name": "openai/gpt-5.4", + "litellm_provider": "anthropic", + "model_name": "claude-opus", "api_key": "k3", "billing_tier": "premium", - "auto_pin_tier": "B", - "quality_score": 100, + "auto_pin_tier": "A", + "quality_score": 99, + }, + { + "id": -4, + "litellm_provider": "openai", + "model_name": "gpt-5.3", + "api_key": "k4", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 98, + }, + { + "id": -5, + "litellm_provider": "gemini", + "model_name": "gemini-3-pro", + "api_key": "k5", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 97, + }, + { + "id": -6, + "litellm_provider": "xai", + "model_name": "grok-5", + "api_key": "k6", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 96, }, ], ) @@ -258,7 +318,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): user_id="00000000-0000-0000-0000-000000000001", selected_llm_config_id=0, ) - assert result.resolved_llm_config_id == -2 + assert result.resolved_llm_config_id in {-1, -3, -4, -5, -6} assert result.resolved_tier == "premium" @@ -932,6 +992,74 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): assert result.from_existing_pin is False +def test_mark_runtime_cooldown_writes_shared_redis(monkeypatch): + import app.services.auto_model_pin_service as svc + + mark_runtime_cooldown(-9, reason="provider_rate_limited", cooldown_seconds=123) + + redis_client = svc._runtime_cooldown_redis + assert redis_client.values["auto:cooldown:llm:-9"] == "provider_rate_limited" + assert redis_client.ttls["auto:cooldown:llm:-9"] == 123 + + +@pytest.mark.asyncio +async def test_shared_runtime_cooldown_blocks_pin_across_workers(monkeypatch): + """A Redis cooldown written by another worker should invalidate local pins.""" + import app.services.auto_model_pin_service as svc + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + _set_global_llm_configs( + monkeypatch, + config, + [ + { + "id": -1, + "litellm_provider": "openrouter", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "litellm_provider": "openrouter", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + svc._runtime_cooldown_redis.set( + "auto:cooldown:llm:-1", + "provider_rate_limited", + ex=600, + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.credit_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + @pytest.mark.asyncio async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): from app.config import config