mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 05:12:38 +02:00
feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection
This commit is contained in:
parent
4bef75d298
commit
f65b3be1ce
4 changed files with 486 additions and 86 deletions
|
|
@ -16,6 +16,8 @@ from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
@ -31,6 +33,13 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AUTO_FASTEST_ID = 0
|
AUTO_FASTEST_ID = 0
|
||||||
AUTO_FASTEST_MODE = "auto_fastest"
|
AUTO_FASTEST_MODE = "auto_fastest"
|
||||||
|
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||||
|
|
||||||
|
# In-memory runtime cooldown map for configs that recently hard-failed at
|
||||||
|
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
|
||||||
|
# the same unhealthy config from being reselected immediately during repair.
|
||||||
|
_runtime_cooldown_until: dict[int, float] = {}
|
||||||
|
_runtime_cooldown_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -49,17 +58,68 @@ def _is_usable_global_config(cfg: dict) -> bool:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
    """Remove cooldown entries whose deadline is at or before the reference time.

    Args:
        now_ts: Reference timestamp in ``time.time()`` units. When ``None``
            the current wall-clock time is used.

    This helper does not acquire ``_runtime_cooldown_lock`` itself; the
    callers in this module invoke it while already holding the lock.
    """
    if now_ts is None:
        reference = time.time()
    else:
        reference = now_ts

    # Iterate over a snapshot so entries can be removed while scanning.
    for expired_id, deadline in list(_runtime_cooldown_until.items()):
        if deadline <= reference:
            _runtime_cooldown_until.pop(expired_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_runtime_cooled_down(config_id: int) -> bool:
    """Return True while ``config_id`` has an unexpired runtime cooldown.

    Only the entry for ``config_id`` is inspected. ``_global_candidates``
    calls this once per config, so pruning the whole map on every call made
    candidate selection do quadratic work while holding the lock. Expired
    entries for other ids are still swept by ``mark_runtime_cooldown``.

    Args:
        config_id: Config id to look up in the cooldown map.

    Returns:
        True if a cooldown deadline exists for ``config_id`` and lies strictly
        in the future; False otherwise.
    """
    now = time.time()
    with _runtime_cooldown_lock:
        until = _runtime_cooldown_until.get(config_id)
        if until is None:
            return False
        if until <= now:
            # Lazily drop the expired entry so the map does not keep it around.
            _runtime_cooldown_until.pop(config_id, None)
            return False
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def mark_runtime_cooldown(
    config_id: int,
    *,
    reason: str = "rate_limited",
    cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
) -> None:
    """Temporarily suppress a config from Auto selection.

    Runtime error handlers (e.g. on an OpenRouter 429) call this so that an
    already pinned but currently unhealthy config is not immediately reused
    on the same thread during repair.

    Args:
        config_id: Config to suppress; stored under ``int(config_id)``.
        reason: Free-form label that is only written to the log line.
        cooldown_seconds: Suppression duration. Non-positive values fall back
            to ``_RUNTIME_COOLDOWN_SECONDS``; the value is truncated with
            ``int()`` before being added to the current time.
    """
    effective_seconds = (
        _RUNTIME_COOLDOWN_SECONDS if cooldown_seconds <= 0 else cooldown_seconds
    )
    expires_at = time.time() + int(effective_seconds)

    with _runtime_cooldown_lock:
        _runtime_cooldown_until[int(config_id)] = expires_at
        # Sweep expired entries while the lock is already held.
        _prune_runtime_cooldowns()

    logger.info(
        "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
        config_id,
        reason,
        effective_seconds,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def clear_runtime_cooldown(config_id: int | None = None) -> None:
    """Test/ops helper that removes runtime cooldown entries.

    Args:
        config_id: When ``None`` every entry is removed; otherwise only the
            entry stored under ``int(config_id)`` is removed (a missing entry
            is ignored).
    """
    with _runtime_cooldown_lock:
        if config_id is not None:
            _runtime_cooldown_until.pop(int(config_id), None)
        else:
            _runtime_cooldown_until.clear()
|
||||||
|
|
||||||
|
|
||||||
def _global_candidates() -> list[dict]:
|
def _global_candidates() -> list[dict]:
|
||||||
"""Return Auto-eligible global cfgs.
|
"""Return Auto-eligible global cfgs.
|
||||||
|
|
||||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||||
can't be picked as the thread's pin.
|
can't be picked as the thread's pin. Also excludes configs currently
|
||||||
|
in runtime cooldown (e.g. temporary 429 bursts).
|
||||||
"""
|
"""
|
||||||
candidates = [
|
candidates = [
|
||||||
cfg
|
cfg
|
||||||
for cfg in config.GLOBAL_LLM_CONFIGS
|
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||||
if _is_usable_global_config(cfg) and not cfg.get("health_gated")
|
if _is_usable_global_config(cfg)
|
||||||
|
and not cfg.get("health_gated")
|
||||||
|
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||||
]
|
]
|
||||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,10 @@ from app.db import (
|
||||||
shielded_async_session,
|
shielded_async_session,
|
||||||
)
|
)
|
||||||
from app.prompts import TITLE_GENERATION_PROMPT
|
from app.prompts import TITLE_GENERATION_PROMPT
|
||||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
from app.services.auto_model_pin_service import (
|
||||||
|
mark_runtime_cooldown,
|
||||||
|
resolve_or_get_pinned_llm_config_id,
|
||||||
|
)
|
||||||
from app.services.chat_session_state_service import (
|
from app.services.chat_session_state_service import (
|
||||||
clear_ai_responding,
|
clear_ai_responding,
|
||||||
set_ai_responding,
|
set_ai_responding,
|
||||||
|
|
@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
return None
|
||||||
|
candidates: list[Any] = [parsed.get("code")]
|
||||||
|
nested = parsed.get("error")
|
||||||
|
if isinstance(nested, dict):
|
||||||
|
candidates.append(nested.get("code"))
|
||||||
|
for value in candidates:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
return int(value)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_provider_rate_limited(exc: BaseException) -> bool:
    """Best-effort detection for provider-side runtime throttling.

    Covers LiteLLM/OpenRouter shapes like:
    - class name contains ``RateLimit``
    - payload code 429, e.g. ``{"error": {"code": 429}}``
    - payload type ``rate_limit_error`` at the top level or under ``"error"``
    - message text containing ``rate limited`` / ``rate-limited``

    Args:
        exc: Exception raised while streaming from the provider.

    Returns:
        True when any of the signals above is present, otherwise False.
    """
    if "ratelimit" in type(exc).__name__.lower():
        return True

    raw = str(exc)
    parsed = _parse_error_payload(raw)
    if _extract_provider_error_code(parsed) == 429:
        return True

    if isinstance(parsed, dict):
        # Check both levels independently: a nested type must not hide a
        # top-level ``rate_limit_error`` (or vice versa).
        type_values = [parsed.get("type")]
        nested = parsed.get("error")
        if isinstance(nested, dict):
            type_values.append(nested.get("type"))
        if any(
            isinstance(value, str) and value.lower() == "rate_limit_error"
            for value in type_values
        ):
            return True

    # "temporarily rate-limited upstream" is already matched by "rate-limited".
    lowered = raw.lower()
    return "rate limited" in lowered or "rate-limited" in lowered
|
||||||
|
|
||||||
|
|
||||||
def _classify_stream_exception(
|
def _classify_stream_exception(
|
||||||
exc: Exception,
|
exc: Exception,
|
||||||
*,
|
*,
|
||||||
|
|
@ -449,19 +506,7 @@ def _classify_stream_exception(
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed = _parse_error_payload(raw)
|
if _is_provider_rate_limited(exc):
|
||||||
provider_error_type = ""
|
|
||||||
if parsed:
|
|
||||||
top_type = parsed.get("type")
|
|
||||||
if isinstance(top_type, str):
|
|
||||||
provider_error_type = top_type.lower()
|
|
||||||
nested = parsed.get("error")
|
|
||||||
if isinstance(nested, dict):
|
|
||||||
nested_type = nested.get("type")
|
|
||||||
if isinstance(nested_type, str):
|
|
||||||
provider_error_type = nested_type.lower()
|
|
||||||
|
|
||||||
if provider_error_type == "rate_limit_error":
|
|
||||||
return (
|
return (
|
||||||
"rate_limited",
|
"rate_limited",
|
||||||
"RATE_LIMITED",
|
"RATE_LIMITED",
|
||||||
|
|
@ -2671,54 +2716,144 @@ async def stream_new_chat(
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
async for sse in _stream_agent_events(
|
runtime_rate_limit_recovered = False
|
||||||
agent=agent,
|
while True:
|
||||||
config=config,
|
try:
|
||||||
input_data=input_state,
|
async for sse in _stream_agent_events(
|
||||||
streaming_service=streaming_service,
|
agent=agent,
|
||||||
result=stream_result,
|
config=config,
|
||||||
step_prefix="thinking",
|
input_data=input_state,
|
||||||
initial_step_id=initial_step_id,
|
streaming_service=streaming_service,
|
||||||
initial_step_title=initial_title,
|
result=stream_result,
|
||||||
initial_step_items=initial_items,
|
step_prefix="thinking",
|
||||||
fallback_commit_search_space_id=search_space_id,
|
initial_step_id=initial_step_id,
|
||||||
fallback_commit_created_by_id=user_id,
|
initial_step_title=initial_title,
|
||||||
fallback_commit_filesystem_mode=(
|
initial_step_items=initial_items,
|
||||||
filesystem_selection.mode
|
fallback_commit_search_space_id=search_space_id,
|
||||||
if filesystem_selection
|
fallback_commit_created_by_id=user_id,
|
||||||
else FilesystemMode.CLOUD
|
fallback_commit_filesystem_mode=(
|
||||||
),
|
filesystem_selection.mode
|
||||||
fallback_commit_thread_id=chat_id,
|
if filesystem_selection
|
||||||
):
|
else FilesystemMode.CLOUD
|
||||||
if not _first_event_logged:
|
),
|
||||||
_perf_log.info(
|
fallback_commit_thread_id=chat_id,
|
||||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
):
|
||||||
"%.3fs (total since request start) (chat_id=%s)",
|
if not _first_event_logged:
|
||||||
time.perf_counter() - _t_stream_start,
|
_perf_log.info(
|
||||||
time.perf_counter() - _t_total,
|
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||||
chat_id,
|
"%.3fs (total since request start) (chat_id=%s)",
|
||||||
)
|
time.perf_counter() - _t_stream_start,
|
||||||
_first_event_logged = True
|
time.perf_counter() - _t_total,
|
||||||
yield sse
|
chat_id,
|
||||||
|
|
||||||
# Inject title update mid-stream as soon as the background task finishes
|
|
||||||
if title_task is not None and title_task.done() and not title_emitted:
|
|
||||||
generated_title, title_usage = title_task.result()
|
|
||||||
if title_usage:
|
|
||||||
accumulator.add(**title_usage)
|
|
||||||
if generated_title:
|
|
||||||
async with shielded_async_session() as title_session:
|
|
||||||
title_thread_result = await title_session.execute(
|
|
||||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
|
||||||
)
|
)
|
||||||
title_thread = title_thread_result.scalars().first()
|
_first_event_logged = True
|
||||||
if title_thread:
|
yield sse
|
||||||
title_thread.title = generated_title
|
|
||||||
await title_session.commit()
|
# Inject title update mid-stream as soon as the background
|
||||||
yield streaming_service.format_thread_title_update(
|
# task finishes.
|
||||||
chat_id, generated_title
|
if title_task is not None and title_task.done() and not title_emitted:
|
||||||
|
generated_title, title_usage = title_task.result()
|
||||||
|
if title_usage:
|
||||||
|
accumulator.add(**title_usage)
|
||||||
|
if generated_title:
|
||||||
|
async with shielded_async_session() as title_session:
|
||||||
|
title_thread_result = await title_session.execute(
|
||||||
|
select(NewChatThread).filter(
|
||||||
|
NewChatThread.id == chat_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
title_thread = title_thread_result.scalars().first()
|
||||||
|
if title_thread:
|
||||||
|
title_thread.title = generated_title
|
||||||
|
await title_session.commit()
|
||||||
|
yield streaming_service.format_thread_title_update(
|
||||||
|
chat_id, generated_title
|
||||||
|
)
|
||||||
|
title_emitted = True
|
||||||
|
break
|
||||||
|
except Exception as stream_exc:
|
||||||
|
can_runtime_recover = (
|
||||||
|
not runtime_rate_limit_recovered
|
||||||
|
and requested_llm_config_id == 0
|
||||||
|
and llm_config_id < 0
|
||||||
|
and not _first_event_logged
|
||||||
|
and _is_provider_rate_limited(stream_exc)
|
||||||
|
)
|
||||||
|
if not can_runtime_recover:
|
||||||
|
raise
|
||||||
|
|
||||||
|
runtime_rate_limit_recovered = True
|
||||||
|
previous_config_id = llm_config_id
|
||||||
|
mark_runtime_cooldown(
|
||||||
|
previous_config_id,
|
||||||
|
reason="provider_rate_limited",
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_config_id = (
|
||||||
|
await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
selected_llm_config_id=0,
|
||||||
)
|
)
|
||||||
title_emitted = True
|
).resolved_llm_config_id
|
||||||
|
|
||||||
|
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||||
|
if llm_load_error:
|
||||||
|
raise stream_exc
|
||||||
|
|
||||||
|
# Title generation uses the initial llm object. After a runtime
|
||||||
|
# repin we keep the stream focused on response recovery and skip
|
||||||
|
# title generation for this turn.
|
||||||
|
if title_task is not None and not title_task.done():
|
||||||
|
title_task.cancel()
|
||||||
|
title_task = None
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
agent = await create_surfsense_deep_agent(
|
||||||
|
llm=llm,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
db_session=session,
|
||||||
|
connector_service=connector_service,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=chat_id,
|
||||||
|
agent_config=agent_config,
|
||||||
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||||
|
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||||
|
previous_config_id,
|
||||||
|
llm_config_id,
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
)
|
||||||
|
_log_chat_stream_error(
|
||||||
|
flow=flow,
|
||||||
|
error_kind="rate_limited",
|
||||||
|
error_code="RATE_LIMITED",
|
||||||
|
severity="info",
|
||||||
|
is_expected=True,
|
||||||
|
request_id=request_id,
|
||||||
|
thread_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
message=(
|
||||||
|
"Auto-pinned model hit runtime rate limit; switched to "
|
||||||
|
"another eligible model and retried."
|
||||||
|
),
|
||||||
|
extra={
|
||||||
|
"auto_runtime_recover": True,
|
||||||
|
"previous_config_id": previous_config_id,
|
||||||
|
"fallback_config_id": llm_config_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||||
|
|
@ -3265,31 +3400,108 @@ async def stream_resume_chat(
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
async for sse in _stream_agent_events(
|
runtime_rate_limit_recovered = False
|
||||||
agent=agent,
|
while True:
|
||||||
config=config,
|
try:
|
||||||
input_data=Command(resume={"decisions": decisions}),
|
async for sse in _stream_agent_events(
|
||||||
streaming_service=streaming_service,
|
agent=agent,
|
||||||
result=stream_result,
|
config=config,
|
||||||
step_prefix="thinking-resume",
|
input_data=Command(resume={"decisions": decisions}),
|
||||||
fallback_commit_search_space_id=search_space_id,
|
streaming_service=streaming_service,
|
||||||
fallback_commit_created_by_id=user_id,
|
result=stream_result,
|
||||||
fallback_commit_filesystem_mode=(
|
step_prefix="thinking-resume",
|
||||||
filesystem_selection.mode
|
fallback_commit_search_space_id=search_space_id,
|
||||||
if filesystem_selection
|
fallback_commit_created_by_id=user_id,
|
||||||
else FilesystemMode.CLOUD
|
fallback_commit_filesystem_mode=(
|
||||||
),
|
filesystem_selection.mode
|
||||||
fallback_commit_thread_id=chat_id,
|
if filesystem_selection
|
||||||
):
|
else FilesystemMode.CLOUD
|
||||||
if not _first_event_logged:
|
),
|
||||||
_perf_log.info(
|
fallback_commit_thread_id=chat_id,
|
||||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
):
|
||||||
time.perf_counter() - _t_stream_start,
|
if not _first_event_logged:
|
||||||
time.perf_counter() - _t_total,
|
_perf_log.info(
|
||||||
chat_id,
|
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||||
|
time.perf_counter() - _t_stream_start,
|
||||||
|
time.perf_counter() - _t_total,
|
||||||
|
chat_id,
|
||||||
|
)
|
||||||
|
_first_event_logged = True
|
||||||
|
yield sse
|
||||||
|
break
|
||||||
|
except Exception as stream_exc:
|
||||||
|
can_runtime_recover = (
|
||||||
|
not runtime_rate_limit_recovered
|
||||||
|
and requested_llm_config_id == 0
|
||||||
|
and llm_config_id < 0
|
||||||
|
and not _first_event_logged
|
||||||
|
and _is_provider_rate_limited(stream_exc)
|
||||||
)
|
)
|
||||||
_first_event_logged = True
|
if not can_runtime_recover:
|
||||||
yield sse
|
raise
|
||||||
|
|
||||||
|
runtime_rate_limit_recovered = True
|
||||||
|
previous_config_id = llm_config_id
|
||||||
|
mark_runtime_cooldown(
|
||||||
|
previous_config_id,
|
||||||
|
reason="provider_rate_limited",
|
||||||
|
)
|
||||||
|
llm_config_id = (
|
||||||
|
await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
).resolved_llm_config_id
|
||||||
|
|
||||||
|
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||||
|
if llm_load_error:
|
||||||
|
raise stream_exc
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
agent = await create_surfsense_deep_agent(
|
||||||
|
llm=llm,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
db_session=session,
|
||||||
|
connector_service=connector_service,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=chat_id,
|
||||||
|
agent_config=agent_config,
|
||||||
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||||
|
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||||
|
previous_config_id,
|
||||||
|
llm_config_id,
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
)
|
||||||
|
_log_chat_stream_error(
|
||||||
|
flow="resume",
|
||||||
|
error_kind="rate_limited",
|
||||||
|
error_code="RATE_LIMITED",
|
||||||
|
severity="info",
|
||||||
|
is_expected=True,
|
||||||
|
request_id=request_id,
|
||||||
|
thread_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
message=(
|
||||||
|
"Auto-pinned model hit runtime rate limit; switched to "
|
||||||
|
"another eligible model and retried."
|
||||||
|
),
|
||||||
|
extra={
|
||||||
|
"auto_runtime_recover": True,
|
||||||
|
"previous_config_id": previous_config_id,
|
||||||
|
"fallback_config_id": llm_config_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
continue
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
||||||
time.perf_counter() - _t_stream_start,
|
time.perf_counter() - _t_stream_start,
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,21 @@ from types import SimpleNamespace
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.services.auto_model_pin_service import (
|
from app.services.auto_model_pin_service import (
|
||||||
|
clear_runtime_cooldown,
|
||||||
|
mark_runtime_cooldown,
|
||||||
resolve_or_get_pinned_llm_config_id,
|
resolve_or_get_pinned_llm_config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _clear_runtime_cooldown_map():
    """Reset the runtime cooldown map before and after every test."""
    # Setup: start each test without leftover cooldown entries.
    clear_runtime_cooldown()
    yield
    # Teardown: do not leak entries created by the test.
    clear_runtime_cooldown()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _FakeQuotaResult:
|
class _FakeQuotaResult:
|
||||||
allowed: bool
|
allowed: bool
|
||||||
|
|
@ -701,3 +710,106 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
||||||
assert result.resolved_llm_config_id == -1
|
assert result.resolved_llm_config_id == -1
|
||||||
assert result.from_existing_pin is True
|
assert result.from_existing_pin is True
|
||||||
assert session.commit_count == 0
|
assert session.commit_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
    """A runtime-cooled config should be excluded from candidate reuse.

    This enables one-shot recovery from transient provider 429 bursts: the
    pinned cfg is marked as cooled down, which forces a repair to another
    eligible cfg on the next resolution.
    """
    from app.config import config

    def _free_openrouter_cfg(cfg_id, model_name, quality_score):
        # Both configs differ only in id, model name and quality score.
        return {
            "id": cfg_id,
            "provider": "OPENROUTER",
            "model_name": model_name,
            "api_key": "k",
            "billing_tier": "free",
            "auto_pin_tier": "C",
            "quality_score": quality_score,
            "health_gated": False,
        }

    async def _blocked(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    global_cfgs = [
        _free_openrouter_cfg(-1, "google/gemma-4-26b-a4b-it:free", 90),
        _free_openrouter_cfg(-2, "google/gemini-2.5-flash:free", 80),
    ]
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", global_cfgs)
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _blocked,
    )

    # The thread is pinned to -1; cooling it down must push resolution to -2.
    mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -2
    assert resolution.from_existing_pin is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
    """Clearing a cooldown makes the previously pinned config reusable again."""
    from app.config import config

    async def _must_not_call(*_args, **_kwargs):
        raise AssertionError("premium_get_usage should not run on healthy pin reuse")

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    only_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "google/gemma-4-26b-a4b-it:free",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 90,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [only_cfg])
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _must_not_call,
    )

    # Cool the pinned config down, then lift the cooldown again.
    mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
    clear_runtime_cooldown(-1)

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -1
    assert resolution.from_existing_pin is True
|
||||||
|
|
|
||||||
|
|
@ -159,6 +159,22 @@ def test_stream_exception_classifies_rate_limited():
|
||||||
assert extra is None
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_openrouter_429_payload():
    """An OpenRouter-style message carrying code 429 classifies as rate limited."""
    message = (
        'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
        '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
    )

    classification = _classify_stream_exception(Exception(message), flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert kind == "rate_limited"
    assert code == "RATE_LIMITED"
    assert severity == "warn"
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy():
|
def test_stream_exception_classifies_thread_busy():
|
||||||
exc = BusyError(request_id="thread-123")
|
exc = BusyError(request_id="thread-123")
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue