feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection

This commit is contained in:
Anish Sarkar 2026-05-02 00:57:52 +05:30
parent 4bef75d298
commit f65b3be1ce
4 changed files with 486 additions and 86 deletions

View file

@ -64,7 +64,10 @@ from app.db import (
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
from app.services.auto_model_pin_service import (
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
return None
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
if not isinstance(parsed, dict):
return None
candidates: list[Any] = [parsed.get("code")]
nested = parsed.get("error")
if isinstance(nested, dict):
candidates.append(nested.get("code"))
for value in candidates:
try:
if value is None:
continue
return int(value)
except Exception:
continue
return None
def _is_provider_rate_limited(exc: BaseException) -> bool:
"""Best-effort detection for provider-side runtime throttling.
Covers LiteLLM/OpenRouter shapes like:
- class name contains ``RateLimit``
- nested payload ``{"error": {"code": 429}}``
- nested payload ``{"error": {"type": "rate_limit_error"}}``
"""
raw = str(exc)
lowered = raw.lower()
if "ratelimit" in type(exc).__name__.lower():
return True
parsed = _parse_error_payload(raw)
provider_code = _extract_provider_error_code(parsed)
if provider_code == 429:
return True
provider_error_type = ""
if parsed:
top_type = parsed.get("type")
if isinstance(top_type, str):
provider_error_type = top_type.lower()
nested = parsed.get("error")
if isinstance(nested, dict):
nested_type = nested.get("type")
if isinstance(nested_type, str):
provider_error_type = nested_type.lower()
if provider_error_type == "rate_limit_error":
return True
return (
"rate limited" in lowered
or "rate-limited" in lowered
or "temporarily rate-limited upstream" in lowered
)
def _classify_stream_exception(
exc: Exception,
*,
@ -449,19 +506,7 @@ def _classify_stream_exception(
None,
)
parsed = _parse_error_payload(raw)
provider_error_type = ""
if parsed:
top_type = parsed.get("type")
if isinstance(top_type, str):
provider_error_type = top_type.lower()
nested = parsed.get("error")
if isinstance(nested, dict):
nested_type = nested.get("type")
if isinstance(nested_type, str):
provider_error_type = nested_type.lower()
if provider_error_type == "rate_limit_error":
if _is_provider_rate_limited(exc):
return (
"rate_limited",
"RATE_LIMITED",
@ -2671,54 +2716,144 @@ async def stream_new_chat(
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=input_state,
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
# Inject title update mid-stream as soon as the background task finishes
if title_task is not None and title_task.done() and not title_emitted:
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
runtime_rate_limit_recovered = False
while True:
try:
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=input_state,
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking",
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
"%.3fs (total since request start) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(
chat_id, generated_title
_first_event_logged = True
yield sse
# Inject title update mid-stream as soon as the background
# task finishes.
if title_task is not None and title_task.done() and not title_emitted:
generated_title, title_usage = title_task.result()
if title_usage:
accumulator.add(**title_usage)
if generated_title:
async with shielded_async_session() as title_session:
title_thread_result = await title_session.execute(
select(NewChatThread).filter(
NewChatThread.id == chat_id
)
)
title_thread = title_thread_result.scalars().first()
if title_thread:
title_thread.title = generated_title
await title_session.commit()
yield streaming_service.format_thread_title_update(
chat_id, generated_title
)
title_emitted = True
break
except Exception as stream_exc:
can_runtime_recover = (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and llm_config_id < 0
and not _first_event_logged
and _is_provider_rate_limited(stream_exc)
)
if not can_runtime_recover:
raise
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
)
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
)
title_emitted = True
).resolved_llm_config_id
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
if llm_load_error:
raise stream_exc
# Title generation uses the initial llm object. After a runtime
# repin we keep the stream focused on response recovery and skip
# title generation for this turn.
if title_task is not None and not title_task.done():
title_task.cancel()
title_task = None
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t0,
)
_log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
continue
_perf_log.info(
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
@ -3265,31 +3400,108 @@ async def stream_resume_chat(
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=Command(resume={"decisions": decisions}),
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
runtime_rate_limit_recovered = False
while True:
try:
async for sse in _stream_agent_events(
agent=agent,
config=config,
input_data=Command(resume={"decisions": decisions}),
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
time.perf_counter() - _t_stream_start,
time.perf_counter() - _t_total,
chat_id,
)
_first_event_logged = True
yield sse
break
except Exception as stream_exc:
can_runtime_recover = (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and llm_config_id < 0
and not _first_event_logged
and _is_provider_rate_limited(stream_exc)
)
_first_event_logged = True
yield sse
if not can_runtime_recover:
raise
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
)
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
)
).resolved_llm_config_id
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
if llm_load_error:
raise stream_exc
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=chat_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "
"config_id=%s -> %s and rebuilt agent in %.3fs",
previous_config_id,
llm_config_id,
time.perf_counter() - _t0,
)
_log_chat_stream_error(
flow="resume",
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": llm_config_id,
},
)
continue
_perf_log.info(
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
time.perf_counter() - _t_stream_start,