mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-03 12:52:39 +02:00)
feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection
This commit is contained in:
parent 4bef75d298
commit f65b3be1ce
4 changed files with 486 additions and 86 deletions
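At its core, the change keeps a process-local map from config id to a cooldown expiry timestamp; Auto candidate selection skips any config whose entry has not yet expired. A minimal standalone sketch of the pattern (illustrative names, not the service's actual API):

import threading
import time

_cooldown_until: dict[int, float] = {}
_lock = threading.Lock()

def mark(config_id: int, seconds: int = 600) -> None:
    # Record when this config becomes eligible for selection again.
    with _lock:
        _cooldown_until[config_id] = time.time() + seconds

def is_cooled_down(config_id: int) -> bool:
    # Expired entries are pruned lazily on read.
    with _lock:
        until = _cooldown_until.get(config_id)
        if until is not None and until <= time.time():
            del _cooldown_until[config_id]
            return False
        return until is not None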
@@ -16,6 +16,8 @@ from __future__ import annotations
 import hashlib
 import logging
+import threading
+import time
 from dataclasses import dataclass
 from uuid import UUID
 
 
@@ -31,6 +33,13 @@ logger = logging.getLogger(__name__)
 AUTO_FASTEST_ID = 0
 AUTO_FASTEST_MODE = "auto_fastest"
+_RUNTIME_COOLDOWN_SECONDS = 600
+
+# In-memory runtime cooldown map for configs that recently hard-failed at
+# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
+# the same unhealthy config from being reselected immediately during repair.
+_runtime_cooldown_until: dict[int, float] = {}
+_runtime_cooldown_lock = threading.Lock()
 
 
 @dataclass
@@ -49,17 +58,68 @@ def _is_usable_global_config(cfg: dict) -> bool:
     )
 
 
+def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
+    now = time.time() if now_ts is None else now_ts
+    stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
+    for cid in stale:
+        _runtime_cooldown_until.pop(cid, None)
+
+
+def _is_runtime_cooled_down(config_id: int) -> bool:
+    with _runtime_cooldown_lock:
+        _prune_runtime_cooldowns()
+        return config_id in _runtime_cooldown_until
+
+
+def mark_runtime_cooldown(
+    config_id: int,
+    *,
+    reason: str = "rate_limited",
+    cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
+) -> None:
+    """Temporarily suppress a config from Auto selection.
+
+    Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
+    config that is currently unhealthy does not get immediately reused on the
+    same thread during repair.
+    """
+    if cooldown_seconds <= 0:
+        cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
+    until = time.time() + int(cooldown_seconds)
+    with _runtime_cooldown_lock:
+        _runtime_cooldown_until[int(config_id)] = until
+        _prune_runtime_cooldowns()
+    logger.info(
+        "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
+        config_id,
+        reason,
+        cooldown_seconds,
+    )
+
+
+def clear_runtime_cooldown(config_id: int | None = None) -> None:
+    """Test/ops helper to clear runtime cooldown entries."""
+    with _runtime_cooldown_lock:
+        if config_id is None:
+            _runtime_cooldown_until.clear()
+            return
+        _runtime_cooldown_until.pop(int(config_id), None)
+
+
 def _global_candidates() -> list[dict]:
     """Return Auto-eligible global cfgs.
 
     Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
     below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
-    can't be picked as the thread's pin.
+    can't be picked as the thread's pin. Also excludes configs currently
+    in runtime cooldown (e.g. temporary 429 bursts).
     """
     candidates = [
         cfg
         for cfg in config.GLOBAL_LLM_CONFIGS
-        if _is_usable_global_config(cfg) and not cfg.get("health_gated")
+        if _is_usable_global_config(cfg)
+        and not cfg.get("health_gated")
+        and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
     ]
     return sorted(candidates, key=lambda c: int(c.get("id", 0)))
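Callers interact with the cooldown only through the two public helpers added above; a runtime 429 handler would, for example, do something like this (a usage sketch based on the signatures in this diff):

from app.services.auto_model_pin_service import (
    clear_runtime_cooldown,
    mark_runtime_cooldown,
)

# Suppress config -1 from Auto selection for the default 600-second window.
mark_runtime_cooldown(-1, reason="provider_rate_limited")

# Tests and ops tooling can lift the suppression early, per config or globally.
clear_runtime_cooldown(-1)
clear_runtime_cooldown()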
@@ -64,7 +64,10 @@ from app.db import (
     shielded_async_session,
 )
 from app.prompts import TITLE_GENERATION_PROMPT
-from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
+from app.services.auto_model_pin_service import (
+    mark_runtime_cooldown,
+    resolve_or_get_pinned_llm_config_id,
+)
 from app.services.chat_session_state_service import (
     clear_ai_responding,
     set_ai_responding,
@@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
     return None
 
 
+def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
+    if not isinstance(parsed, dict):
+        return None
+    candidates: list[Any] = [parsed.get("code")]
+    nested = parsed.get("error")
+    if isinstance(nested, dict):
+        candidates.append(nested.get("code"))
+    for value in candidates:
+        try:
+            if value is None:
+                continue
+            return int(value)
+        except Exception:
+            continue
+    return None
+
+
+def _is_provider_rate_limited(exc: BaseException) -> bool:
+    """Best-effort detection for provider-side runtime throttling.
+
+    Covers LiteLLM/OpenRouter shapes like:
+    - class name contains ``RateLimit``
+    - nested payload ``{"error": {"code": 429}}``
+    - nested payload ``{"error": {"type": "rate_limit_error"}}``
+    """
+    raw = str(exc)
+    lowered = raw.lower()
+    if "ratelimit" in type(exc).__name__.lower():
+        return True
+    parsed = _parse_error_payload(raw)
+    provider_code = _extract_provider_error_code(parsed)
+    if provider_code == 429:
+        return True
+
+    provider_error_type = ""
+    if parsed:
+        top_type = parsed.get("type")
+        if isinstance(top_type, str):
+            provider_error_type = top_type.lower()
+        nested = parsed.get("error")
+        if isinstance(nested, dict):
+            nested_type = nested.get("type")
+            if isinstance(nested_type, str):
+                provider_error_type = nested_type.lower()
+    if provider_error_type == "rate_limit_error":
+        return True
+
+    return (
+        "rate limited" in lowered
+        or "rate-limited" in lowered
+        or "temporarily rate-limited upstream" in lowered
+    )
+
+
 def _classify_stream_exception(
     exc: Exception,
     *,
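As a quick illustration of the shapes listed in the docstring, each of the following should be detected (assuming _is_provider_rate_limited is imported from this module, and that _parse_error_payload extracts the embedded JSON as the regression test near the end of this commit exercises):

class FakeRateLimitError(Exception):
    pass

# Class name contains "RateLimit".
assert _is_provider_rate_limited(FakeRateLimitError("slow down"))

# Nested payload carrying a 429 code.
assert _is_provider_rate_limited(
    Exception('OpenrouterException - {"error": {"code": 429}}')
)

# Nested payload carrying a rate_limit_error type.
assert _is_provider_rate_limited(
    Exception('{"error": {"type": "rate_limit_error"}}')
)

# Plain-text fallback match.
assert _is_provider_rate_limited(Exception("temporarily rate-limited upstream"))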
@@ -449,19 +506,7 @@ def _classify_stream_exception(
         None,
     )
 
-    parsed = _parse_error_payload(raw)
-    provider_error_type = ""
-    if parsed:
-        top_type = parsed.get("type")
-        if isinstance(top_type, str):
-            provider_error_type = top_type.lower()
-        nested = parsed.get("error")
-        if isinstance(nested, dict):
-            nested_type = nested.get("type")
-            if isinstance(nested_type, str):
-                provider_error_type = nested_type.lower()
-
-    if provider_error_type == "rate_limit_error":
+    if _is_provider_rate_limited(exc):
         return (
             "rate_limited",
             "RATE_LIMITED",
@@ -2671,54 +2716,144 @@ async def stream_new_chat(
 
         _t_stream_start = time.perf_counter()
         _first_event_logged = False
-        async for sse in _stream_agent_events(
-            agent=agent,
-            config=config,
-            input_data=input_state,
-            streaming_service=streaming_service,
-            result=stream_result,
-            step_prefix="thinking",
-            initial_step_id=initial_step_id,
-            initial_step_title=initial_title,
-            initial_step_items=initial_items,
-            fallback_commit_search_space_id=search_space_id,
-            fallback_commit_created_by_id=user_id,
-            fallback_commit_filesystem_mode=(
-                filesystem_selection.mode
-                if filesystem_selection
-                else FilesystemMode.CLOUD
-            ),
-            fallback_commit_thread_id=chat_id,
-        ):
-            if not _first_event_logged:
-                _perf_log.info(
-                    "[stream_new_chat] First agent event in %.3fs (time since stream start), "
-                    "%.3fs (total since request start) (chat_id=%s)",
-                    time.perf_counter() - _t_stream_start,
-                    time.perf_counter() - _t_total,
-                    chat_id,
-                )
-                _first_event_logged = True
-            yield sse
-
-            # Inject title update mid-stream as soon as the background task finishes
-            if title_task is not None and title_task.done() and not title_emitted:
-                generated_title, title_usage = title_task.result()
-                if title_usage:
-                    accumulator.add(**title_usage)
-                if generated_title:
-                    async with shielded_async_session() as title_session:
-                        title_thread_result = await title_session.execute(
-                            select(NewChatThread).filter(NewChatThread.id == chat_id)
-                        )
-                        title_thread = title_thread_result.scalars().first()
-                        if title_thread:
-                            title_thread.title = generated_title
-                            await title_session.commit()
-                    yield streaming_service.format_thread_title_update(
-                        chat_id, generated_title
-                    )
-                    title_emitted = True
+        runtime_rate_limit_recovered = False
+        while True:
+            try:
+                async for sse in _stream_agent_events(
+                    agent=agent,
+                    config=config,
+                    input_data=input_state,
+                    streaming_service=streaming_service,
+                    result=stream_result,
+                    step_prefix="thinking",
+                    initial_step_id=initial_step_id,
+                    initial_step_title=initial_title,
+                    initial_step_items=initial_items,
+                    fallback_commit_search_space_id=search_space_id,
+                    fallback_commit_created_by_id=user_id,
+                    fallback_commit_filesystem_mode=(
+                        filesystem_selection.mode
+                        if filesystem_selection
+                        else FilesystemMode.CLOUD
+                    ),
+                    fallback_commit_thread_id=chat_id,
+                ):
+                    if not _first_event_logged:
+                        _perf_log.info(
+                            "[stream_new_chat] First agent event in %.3fs (time since stream start), "
+                            "%.3fs (total since request start) (chat_id=%s)",
+                            time.perf_counter() - _t_stream_start,
+                            time.perf_counter() - _t_total,
+                            chat_id,
+                        )
+                        _first_event_logged = True
+                    yield sse
+
+                    # Inject title update mid-stream as soon as the background
+                    # task finishes.
+                    if title_task is not None and title_task.done() and not title_emitted:
+                        generated_title, title_usage = title_task.result()
+                        if title_usage:
+                            accumulator.add(**title_usage)
+                        if generated_title:
+                            async with shielded_async_session() as title_session:
+                                title_thread_result = await title_session.execute(
+                                    select(NewChatThread).filter(
+                                        NewChatThread.id == chat_id
+                                    )
+                                )
+                                title_thread = title_thread_result.scalars().first()
+                                if title_thread:
+                                    title_thread.title = generated_title
+                                    await title_session.commit()
+                            yield streaming_service.format_thread_title_update(
+                                chat_id, generated_title
+                            )
+                            title_emitted = True
+                break
+            except Exception as stream_exc:
+                can_runtime_recover = (
+                    not runtime_rate_limit_recovered
+                    and requested_llm_config_id == 0
+                    and llm_config_id < 0
+                    and not _first_event_logged
+                    and _is_provider_rate_limited(stream_exc)
+                )
+                if not can_runtime_recover:
+                    raise
+
+                runtime_rate_limit_recovered = True
+                previous_config_id = llm_config_id
+                mark_runtime_cooldown(
+                    previous_config_id,
+                    reason="provider_rate_limited",
+                )
+
+                llm_config_id = (
+                    await resolve_or_get_pinned_llm_config_id(
+                        session,
+                        thread_id=chat_id,
+                        search_space_id=search_space_id,
+                        user_id=user_id,
+                        selected_llm_config_id=0,
+                    )
+                ).resolved_llm_config_id
+
+                llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+                if llm_load_error:
+                    raise stream_exc
+
+                # Title generation uses the initial llm object. After a runtime
+                # repin we keep the stream focused on response recovery and skip
+                # title generation for this turn.
+                if title_task is not None and not title_task.done():
+                    title_task.cancel()
+                    title_task = None
+
+                _t0 = time.perf_counter()
+                agent = await create_surfsense_deep_agent(
+                    llm=llm,
+                    search_space_id=search_space_id,
+                    db_session=session,
+                    connector_service=connector_service,
+                    checkpointer=checkpointer,
+                    user_id=user_id,
+                    thread_id=chat_id,
+                    agent_config=agent_config,
+                    firecrawl_api_key=firecrawl_api_key,
+                    thread_visibility=visibility,
+                    disabled_tools=disabled_tools,
+                    mentioned_document_ids=mentioned_document_ids,
+                    filesystem_selection=filesystem_selection,
+                )
+                _perf_log.info(
+                    "[stream_new_chat] Runtime rate-limit recovery repinned "
+                    "config_id=%s -> %s and rebuilt agent in %.3fs",
+                    previous_config_id,
+                    llm_config_id,
+                    time.perf_counter() - _t0,
+                )
+                _log_chat_stream_error(
+                    flow=flow,
+                    error_kind="rate_limited",
+                    error_code="RATE_LIMITED",
+                    severity="info",
+                    is_expected=True,
+                    request_id=request_id,
+                    thread_id=chat_id,
+                    search_space_id=search_space_id,
+                    user_id=user_id,
+                    message=(
+                        "Auto-pinned model hit runtime rate limit; switched to "
+                        "another eligible model and retried."
+                    ),
+                    extra={
+                        "auto_runtime_recover": True,
+                        "previous_config_id": previous_config_id,
+                        "fallback_config_id": llm_config_id,
+                    },
+                )
+                continue
 
         _perf_log.info(
             "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
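Stripped of the streaming details, the control flow added here is a one-shot recovery loop: complete cleanly and break, or repair exactly once when a rate limit hits before any event has been emitted, and otherwise re-raise. A condensed, illustrative sketch of that shape (hypothetical helper names, not the router's actual API):

def one_shot_recovery(run_stream, is_rate_limited, switch_model, emit):
    recovered = False
    emitted_any = False
    while True:
        try:
            for event in run_stream():
                emitted_any = True
                emit(event)
            break  # clean completion leaves the loop
        except Exception as exc:
            # Recover at most once, and only if nothing reached the client yet.
            if recovered or emitted_any or not is_rate_limited(exc):
                raise
            recovered = True
            switch_model()  # cooldown the old config, repin, rebuild the agent
            # fall through: the loop re-enters run_stream() with the fallback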
@@ -3265,31 +3400,108 @@ async def stream_resume_chat(
 
         _t_stream_start = time.perf_counter()
         _first_event_logged = False
-        async for sse in _stream_agent_events(
-            agent=agent,
-            config=config,
-            input_data=Command(resume={"decisions": decisions}),
-            streaming_service=streaming_service,
-            result=stream_result,
-            step_prefix="thinking-resume",
-            fallback_commit_search_space_id=search_space_id,
-            fallback_commit_created_by_id=user_id,
-            fallback_commit_filesystem_mode=(
-                filesystem_selection.mode
-                if filesystem_selection
-                else FilesystemMode.CLOUD
-            ),
-            fallback_commit_thread_id=chat_id,
-        ):
-            if not _first_event_logged:
-                _perf_log.info(
-                    "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
-                    time.perf_counter() - _t_stream_start,
-                    time.perf_counter() - _t_total,
-                    chat_id,
-                )
-                _first_event_logged = True
-            yield sse
+        runtime_rate_limit_recovered = False
+        while True:
+            try:
+                async for sse in _stream_agent_events(
+                    agent=agent,
+                    config=config,
+                    input_data=Command(resume={"decisions": decisions}),
+                    streaming_service=streaming_service,
+                    result=stream_result,
+                    step_prefix="thinking-resume",
+                    fallback_commit_search_space_id=search_space_id,
+                    fallback_commit_created_by_id=user_id,
+                    fallback_commit_filesystem_mode=(
+                        filesystem_selection.mode
+                        if filesystem_selection
+                        else FilesystemMode.CLOUD
+                    ),
+                    fallback_commit_thread_id=chat_id,
+                ):
+                    if not _first_event_logged:
+                        _perf_log.info(
+                            "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
+                            time.perf_counter() - _t_stream_start,
+                            time.perf_counter() - _t_total,
+                            chat_id,
+                        )
+                        _first_event_logged = True
+                    yield sse
+                break
+            except Exception as stream_exc:
+                can_runtime_recover = (
+                    not runtime_rate_limit_recovered
+                    and requested_llm_config_id == 0
+                    and llm_config_id < 0
+                    and not _first_event_logged
+                    and _is_provider_rate_limited(stream_exc)
+                )
+                if not can_runtime_recover:
+                    raise
+
+                runtime_rate_limit_recovered = True
+                previous_config_id = llm_config_id
+                mark_runtime_cooldown(
+                    previous_config_id,
+                    reason="provider_rate_limited",
+                )
+                llm_config_id = (
+                    await resolve_or_get_pinned_llm_config_id(
+                        session,
+                        thread_id=chat_id,
+                        search_space_id=search_space_id,
+                        user_id=user_id,
+                        selected_llm_config_id=0,
+                    )
+                ).resolved_llm_config_id
+
+                llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
+                if llm_load_error:
+                    raise stream_exc
+
+                _t0 = time.perf_counter()
+                agent = await create_surfsense_deep_agent(
+                    llm=llm,
+                    search_space_id=search_space_id,
+                    db_session=session,
+                    connector_service=connector_service,
+                    checkpointer=checkpointer,
+                    user_id=user_id,
+                    thread_id=chat_id,
+                    agent_config=agent_config,
+                    firecrawl_api_key=firecrawl_api_key,
+                    thread_visibility=visibility,
+                    filesystem_selection=filesystem_selection,
+                )
+                _perf_log.info(
+                    "[stream_resume] Runtime rate-limit recovery repinned "
+                    "config_id=%s -> %s and rebuilt agent in %.3fs",
+                    previous_config_id,
+                    llm_config_id,
+                    time.perf_counter() - _t0,
+                )
+                _log_chat_stream_error(
+                    flow="resume",
+                    error_kind="rate_limited",
+                    error_code="RATE_LIMITED",
+                    severity="info",
+                    is_expected=True,
+                    request_id=request_id,
+                    thread_id=chat_id,
+                    search_space_id=search_space_id,
+                    user_id=user_id,
+                    message=(
+                        "Auto-pinned model hit runtime rate limit; switched to "
+                        "another eligible model and retried."
+                    ),
+                    extra={
+                        "auto_runtime_recover": True,
+                        "previous_config_id": previous_config_id,
+                        "fallback_config_id": llm_config_id,
+                    },
+                )
+                continue
 
         _perf_log.info(
             "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
             time.perf_counter() - _t_stream_start,
@@ -6,12 +6,21 @@ from types import SimpleNamespace
 import pytest
 
 from app.services.auto_model_pin_service import (
+    clear_runtime_cooldown,
+    mark_runtime_cooldown,
     resolve_or_get_pinned_llm_config_id,
 )
 
 pytestmark = pytest.mark.unit
 
 
+@pytest.fixture(autouse=True)
+def _clear_runtime_cooldown_map():
+    clear_runtime_cooldown()
+    yield
+    clear_runtime_cooldown()
+
+
 @dataclass
 class _FakeQuotaResult:
     allowed: bool
@@ -701,3 +710,106 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
     assert result.resolved_llm_config_id == -1
     assert result.from_existing_pin is True
     assert session.commit_count == 0
+
+
+@pytest.mark.asyncio
+async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
+    """A runtime-cooled config should be excluded from candidate reuse.
+
+    This enables one-shot recovery from transient provider 429 bursts: we can
+    mark the pinned cfg as cooled down and force a repair to another eligible
+    cfg on the next resolution.
+    """
+    from app.config import config
+
+    session = _FakeSession(_thread(pinned_llm_config_id=-1))
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {
+                "id": -1,
+                "provider": "OPENROUTER",
+                "model_name": "google/gemma-4-26b-a4b-it:free",
+                "api_key": "k",
+                "billing_tier": "free",
+                "auto_pin_tier": "C",
+                "quality_score": 90,
+                "health_gated": False,
+            },
+            {
+                "id": -2,
+                "provider": "OPENROUTER",
+                "model_name": "google/gemini-2.5-flash:free",
+                "api_key": "k",
+                "billing_tier": "free",
+                "auto_pin_tier": "C",
+                "quality_score": 80,
+                "health_gated": False,
+            },
+        ],
+    )
+
+    async def _blocked(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=False)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _blocked,
+    )
+
+    mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -2
+    assert result.from_existing_pin is False
+
+
+@pytest.mark.asyncio
+async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
+    from app.config import config
+
+    session = _FakeSession(_thread(pinned_llm_config_id=-1))
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {
+                "id": -1,
+                "provider": "OPENROUTER",
+                "model_name": "google/gemma-4-26b-a4b-it:free",
+                "api_key": "k",
+                "billing_tier": "free",
+                "auto_pin_tier": "C",
+                "quality_score": 90,
+                "health_gated": False,
+            },
+        ],
+    )
+
+    async def _must_not_call(*_args, **_kwargs):
+        raise AssertionError("premium_get_usage should not run on healthy pin reuse")
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _must_not_call,
+    )
+
+    mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
+    clear_runtime_cooldown(-1)
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+    )
+    assert result.resolved_llm_config_id == -1
+    assert result.from_existing_pin is True
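To run just the new cooldown tests, an invocation along these lines should work (the selection flags are a suggestion, not part of the commit; the tests are marked via pytestmark = pytest.mark.unit):

pytest -m unit -k "cooldown or cooled_down"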
@@ -159,6 +159,22 @@ def test_stream_exception_classifies_rate_limited():
     assert extra is None
 
 
+def test_stream_exception_classifies_openrouter_429_payload():
+    exc = Exception(
+        'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
+        '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
+    )
+    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
+        exc, flow_label="chat"
+    )
+    assert kind == "rate_limited"
+    assert code == "RATE_LIMITED"
+    assert severity == "warn"
+    assert is_expected is True
+    assert "temporarily rate-limited" in user_message
+    assert extra is None
+
+
 def test_stream_exception_classifies_thread_busy():
     exc = BusyError(request_id="thread-123")
     kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(