feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection

This commit is contained in:
Anish Sarkar 2026-05-02 00:57:52 +05:30
parent 4bef75d298
commit f65b3be1ce
4 changed files with 486 additions and 86 deletions

View file

@ -16,6 +16,8 @@ from __future__ import annotations
import hashlib
import logging
import threading
import time
from dataclasses import dataclass
from uuid import UUID
@ -31,6 +33,13 @@ logger = logging.getLogger(__name__)
# Sentinel config id / mode string for the "Auto (fastest)" model selection.
AUTO_FASTEST_ID = 0
AUTO_FASTEST_MODE = "auto_fastest"
# Default suppression window (seconds) applied by mark_runtime_cooldown().
_RUNTIME_COOLDOWN_SECONDS = 600
# In-memory runtime cooldown map for configs that recently hard-failed at
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
# the same unhealthy config from being reselected immediately during repair.
# Maps config id -> absolute time.time() deadline; process-local only, so a
# restart clears all cooldowns.
_runtime_cooldown_until: dict[int, float] = {}
# Guards every read/write of _runtime_cooldown_until (selection and error
# handlers may run on different threads).
_runtime_cooldown_lock = threading.Lock()
@dataclass
@ -49,17 +58,68 @@ def _is_usable_global_config(cfg: dict) -> bool:
)
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
    """Drop expired entries from the runtime cooldown map.

    Callers hold ``_runtime_cooldown_lock`` (see ``_is_runtime_cooled_down``
    and ``mark_runtime_cooldown``); this helper mutates the shared map
    without taking the lock itself.
    """
    reference = now_ts if now_ts is not None else time.time()
    expired = [
        config_id
        for config_id, deadline in _runtime_cooldown_until.items()
        if deadline <= reference
    ]
    for config_id in expired:
        _runtime_cooldown_until.pop(config_id, None)
def _is_runtime_cooled_down(config_id: int) -> bool:
    """Report whether *config_id* is currently suppressed by a runtime cooldown."""
    with _runtime_cooldown_lock:
        # Prune first so a stale (expired) entry never counts as cooled down.
        _prune_runtime_cooldowns()
        suppressed = config_id in _runtime_cooldown_until
    return suppressed
def mark_runtime_cooldown(
    config_id: int,
    *,
    reason: str = "rate_limited",
    cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
) -> None:
    """Temporarily suppress a config from Auto selection.

    Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
    config that is currently unhealthy does not get immediately reused on the
    same thread during repair.
    """
    # Non-positive durations make no sense; fall back to the default window.
    effective_seconds = (
        cooldown_seconds if cooldown_seconds > 0 else _RUNTIME_COOLDOWN_SECONDS
    )
    expires_at = time.time() + int(effective_seconds)
    with _runtime_cooldown_lock:
        _runtime_cooldown_until[int(config_id)] = expires_at
        # Opportunistically drop already-expired entries while we hold the lock.
        _prune_runtime_cooldowns()
    logger.info(
        "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
        config_id,
        reason,
        effective_seconds,
    )
def clear_runtime_cooldown(config_id: int | None = None) -> None:
    """Test/ops helper to clear runtime cooldown entries."""
    with _runtime_cooldown_lock:
        if config_id is not None:
            # Clear just the one entry; missing ids are a no-op.
            _runtime_cooldown_until.pop(int(config_id), None)
        else:
            # No id given: wipe every cooldown.
            _runtime_cooldown_until.clear()
def _global_candidates() -> list[dict]:
    """Return Auto-eligible global cfgs.

    Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
    below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
    can't be picked as the thread's pin. Also excludes configs currently
    in runtime cooldown (e.g. temporary 429 bursts).

    Returns the surviving configs sorted by ascending ``id`` so selection
    is deterministic.
    """
    # NOTE(review): the diff residue duplicated the pre-change filter line
    # next to its replacement; only the updated predicate below is kept.
    candidates = [
        cfg
        for cfg in config.GLOBAL_LLM_CONFIGS
        if _is_usable_global_config(cfg)
        and not cfg.get("health_gated")
        and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
    ]
    return sorted(candidates, key=lambda c: int(c.get("id", 0)))

View file

@ -64,7 +64,10 @@ from app.db import (
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
from app.services.auto_model_pin_service import (
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
return None
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
if not isinstance(parsed, dict):
return None
candidates: list[Any] = [parsed.get("code")]
nested = parsed.get("error")
if isinstance(nested, dict):
candidates.append(nested.get("code"))
for value in candidates:
try:
if value is None:
continue
return int(value)
except Exception:
continue
return None
def _is_provider_rate_limited(exc: BaseException) -> bool:
    """Best-effort detection for provider-side runtime throttling.

    Covers LiteLLM/OpenRouter shapes like:
    - class name contains ``RateLimit``
    - nested payload ``{"error": {"code": 429}}``
    - nested payload ``{"error": {"type": "rate_limit_error"}}``
    """
    raw = str(exc)
    lowered = raw.lower()
    # Exception class name check (e.g. litellm's RateLimitError).
    if "ratelimit" in type(exc).__name__.lower():
        return True
    parsed = _parse_error_payload(raw)
    if _extract_provider_error_code(parsed) == 429:
        return True
    # Provider error-type string; the nested ``error.type`` wins over the
    # top-level ``type`` when both are present.
    provider_error_type = ""
    if parsed:
        top_type = parsed.get("type")
        if isinstance(top_type, str):
            provider_error_type = top_type.lower()
        nested = parsed.get("error")
        if isinstance(nested, dict):
            nested_type = nested.get("type")
            if isinstance(nested_type, str):
                provider_error_type = nested_type.lower()
    if provider_error_type == "rate_limit_error":
        return True
    # Plain-text fallback. The original's extra check for "temporarily
    # rate-limited upstream" was unreachable: "rate-limited" is a substring
    # of that phrase and is already tested here.
    return "rate limited" in lowered or "rate-limited" in lowered
def _classify_stream_exception(
exc: Exception,
*,
@ -449,19 +506,7 @@ def _classify_stream_exception(
None,
)
parsed = _parse_error_payload(raw)
provider_error_type = ""
if parsed:
top_type = parsed.get("type")
if isinstance(top_type, str):
provider_error_type = top_type.lower()
nested = parsed.get("error")
if isinstance(nested, dict):
nested_type = nested.get("type")
if isinstance(nested_type, str):
provider_error_type = nested_type.lower()
if provider_error_type == "rate_limit_error":
if _is_provider_rate_limited(exc):
return (
"rate_limited",
"RATE_LIMITED",
@ -2671,54 +2716,144 @@ async def stream_new_chat(
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=input_state,
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
runtime_rate_limit_recovered = False
while True:
try:
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=input_state,
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(
chat_id, generated_title
_first_event_logged = True
yield sse
# Inject title update mid-stream as soon as the background
# task finishes.
if title_task is not None and title_task.done() and not title_emitted:
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(
NewChatThread.id == chat_id
)
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(
chat_id, generated_title
)
title_emitted = True
break
except Exception as stream_exc:
can_runtime_recover = (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and llm_config_id < 0
and not _first_event_logged
and _is_provider_rate_limited(stream_exc)
)
if not can_runtime_recover:
raise
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
)
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
)
title_emitted = True
).resolved_llm_config_id
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
if llm_load_error:
raise stream_exc
# Title generation uses the initial llm object. After a runtime
# repin we keep the stream focused on response recovery and skip
# title generation for this turn.
if title_task is not None and not title_task.done():
title_task.cancel()
title_task = None
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t0,
)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
continue
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
@ -3265,31 +3400,108 @@ async def stream_resume_chat(
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=Command(resume={"decisions": decisions}),
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
runtime_rate_limit_recovered = False
while True:
try:
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=Command(resume={"decisions": decisions}),
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
break
except Exception as stream_exc:
can_runtime_recover = (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and llm_config_id < 0
and not _first_event_logged
and _is_provider_rate_limited(stream_exc)
)
_first_event_logged = True
yield sse
if not can_runtime_recover:
raise
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
)
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
)
).resolved_llm_config_id
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
if llm_load_error:
raise stream_exc
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t0,
)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
continue
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,