From f65b3be1ce72e311dffd03de2d60e0fe73f2aef8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 00:57:52 +0530 Subject: [PATCH] feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection --- .../app/services/auto_model_pin_service.py | 64 ++- .../app/tasks/chat/stream_new_chat.py | 380 ++++++++++++++---- .../services/test_auto_model_pin_service.py | 112 ++++++ .../unit/test_stream_new_chat_contract.py | 16 + 4 files changed, 486 insertions(+), 86 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 94aa6b734..05a54b257 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -16,6 +16,8 @@ from __future__ import annotations import hashlib import logging +import threading +import time from dataclasses import dataclass from uuid import UUID @@ -31,6 +33,13 @@ logger = logging.getLogger(__name__) AUTO_FASTEST_ID = 0 AUTO_FASTEST_MODE = "auto_fastest" +_RUNTIME_COOLDOWN_SECONDS = 600 + +# In-memory runtime cooldown map for configs that recently hard-failed at +# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps +# the same unhealthy config from being reselected immediately during repair. +_runtime_cooldown_until: dict[int, float] = {} +_runtime_cooldown_lock = threading.Lock() @dataclass @@ -49,17 +58,68 @@ def _is_usable_global_config(cfg: dict) -> bool: ) +def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] + for cid in stale: + _runtime_cooldown_until.pop(cid, None) + + +def _is_runtime_cooled_down(config_id: int) -> bool: + with _runtime_cooldown_lock: + _prune_runtime_cooldowns() + return config_id in _runtime_cooldown_until + + +def mark_runtime_cooldown( + config_id: int, + *, + reason: str = "rate_limited", + cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS, +) -> None: + """Temporarily suppress a config from Auto selection. + + Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned + config that is currently unhealthy does not get immediately reused on the + same thread during repair. + """ + if cooldown_seconds <= 0: + cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS + until = time.time() + int(cooldown_seconds) + with _runtime_cooldown_lock: + _runtime_cooldown_until[int(config_id)] = until + _prune_runtime_cooldowns() + logger.info( + "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s", + config_id, + reason, + cooldown_seconds, + ) + + +def clear_runtime_cooldown(config_id: int | None = None) -> None: + """Test/ops helper to clear runtime cooldown entries.""" + with _runtime_cooldown_lock: + if config_id is None: + _runtime_cooldown_until.clear() + return + _runtime_cooldown_until.pop(int(config_id), None) + + def _global_candidates() -> list[dict]: """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers - can't be picked as the thread's pin. + can't be picked as the thread's pin. Also excludes configs currently + in runtime cooldown (e.g. temporary 429 bursts). 
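+
+    Runtime cooldown state is process-local (a module-level map guarded by a
+    threading.Lock), so multi-worker deployments cool a config down per
+    worker rather than globally.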
""" candidates = [ cfg for cfg in config.GLOBAL_LLM_CONFIGS - if _is_usable_global_config(cfg) and not cfg.get("health_gated") + if _is_usable_global_config(cfg) + and not cfg.get("health_gated") + and not _is_runtime_cooled_down(int(cfg.get("id", 0))) ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5abcb63eb..8f596927d 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -64,7 +64,10 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT -from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id +from app.services.auto_model_pin_service import ( + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None: return None +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _is_provider_rate_limited(exc: BaseException) -> bool: + """Best-effort detection for provider-side runtime throttling. + + Covers LiteLLM/OpenRouter shapes like: + - class name contains ``RateLimit`` + - nested payload ``{"error": {"code": 429}}`` + - nested payload ``{"error": {"type": "rate_limit_error"}}`` + """ + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + def _classify_stream_exception( exc: Exception, *, @@ -449,19 +506,7 @@ def _classify_stream_exception( None, ) - parsed = _parse_error_payload(raw) - provider_error_type = "" - if parsed: - top_type = parsed.get("type") - if isinstance(top_type, str): - provider_error_type = top_type.lower() - nested = parsed.get("error") - if isinstance(nested, dict): - nested_type = nested.get("type") - if isinstance(nested_type, str): - provider_error_type = nested_type.lower() - - if provider_error_type == "rate_limit_error": + if _is_provider_rate_limited(exc): return ( "rate_limited", "RATE_LIMITED", @@ -2671,54 +2716,144 @@ async def stream_new_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - 
initial_step_title=initial_title, - initial_step_items=initial_items, - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background task finishes - if title_task is not None and title_task.done() and not title_emitted: - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=input_state, + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking", + initial_step_id=initial_step_id, + initial_step_title=initial_title, + initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title + _first_event_logged = True + yield sse + + # Inject title update mid-stream as soon as the background + # task finishes. 
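+                    # Emitting the update here lets the client apply the new
+                    # thread title while tokens are still streaming, instead
+                    # of waiting for the agent stream to complete.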
+ if title_task is not None and title_task.done() and not title_emitted: + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter( + NewChatThread.id == chat_id + ) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update( + chat_id, generated_title + ) + title_emitted = True + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) + ) + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, ) - title_emitted = True + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + raise stream_exc + + # Title generation uses the initial llm object. After a runtime + # repin we keep the stream focused on response recovery and skip + # title generation for this turn. + if title_task is not None and not title_task.done(): + title_task.cancel() + title_task = None + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_new_chat] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." 
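+                        # (log-record text only; the retry is silent to the
+                        # client because no event had been emitted yet)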
+ ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", @@ -3265,31 +3400,108 @@ async def stream_resume_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume={"decisions": decisions}), - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking-resume", - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=Command(resume={"decisions": decisions}), + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + _first_event_logged = True + yield sse + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) ) - _first_event_logged = True - yield sse + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + ) + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + raise stream_exc + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_resume] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate 
limit; switched to " + "another eligible model and retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", time.perf_counter() - _t_stream_start, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index be9d7f721..8261fdfe0 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,12 +6,21 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( + clear_runtime_cooldown, + mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) pytestmark = pytest.mark.unit +@pytest.fixture(autouse=True) +def _clear_runtime_cooldown_map(): + clear_runtime_cooldown() + yield + clear_runtime_cooldown() + + @dataclass class _FakeQuotaResult: allowed: bool @@ -701,3 +710,106 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): assert result.resolved_llm_config_id == -1 assert result.from_existing_pin is True assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): + """A runtime-cooled config should be excluded from candidate reuse. + + This enables one-shot recovery from transient provider 429 bursts: we can + mark the pinned cfg as cooled down and force a repair to another eligible + cfg on the next resolution. + """ + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on healthy pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + 
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + clear_runtime_cooldown(-1) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 5e6ad6abd..ed69ca348 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -159,6 +159,22 @@ def test_stream_exception_classifies_rate_limited(): assert extra is None +def test_stream_exception_classifies_openrouter_429_payload(): + exc = Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' + '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' + ) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + assert extra is None + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
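
Reviewer note: a minimal sketch of what the new runtime rate-limit detection
accepts, using helpers added in this patch. The error strings are
illustrative, and case 2 assumes _parse_error_payload extracts the trailing
JSON from "OpenrouterException - {...}" messages, as the contract test above
exercises:

    from app.tasks.chat.stream_new_chat import _is_provider_rate_limited

    # 1) Exception class name containing "RateLimit" (e.g. LiteLLM shapes).
    class RateLimitError(Exception):
        pass

    assert _is_provider_rate_limited(RateLimitError("slow down"))

    # 2) Nested provider payload carrying a 429 code.
    assert _is_provider_rate_limited(
        Exception('OpenrouterException - {"error": {"code": 429}}')
    )

    # 3) Plain-text upstream throttling message.
    assert _is_provider_rate_limited(Exception("temporarily rate-limited upstream"))

Cooldown bookkeeping is exposed for handlers and tests:
mark_runtime_cooldown(config_id) suppresses a config (default 600s),
clear_runtime_cooldown() resets the whole map, and
clear_runtime_cooldown(config_id) clears a single entry.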