mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection
This commit is contained in:
parent
4bef75d298
commit
f65b3be1ce
4 changed files with 486 additions and 86 deletions
|
|
@ -64,7 +64,10 @@ from app.db import (
|
|||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT
|
||||
from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id
|
||||
from app.services.auto_model_pin_service import (
|
||||
mark_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
from app.services.chat_session_state_service import (
|
||||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
|
|
@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None:
|
|||
return None
|
||||
|
||||
|
||||
def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None:
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
candidates: list[Any] = [parsed.get("code")]
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
candidates.append(nested.get("code"))
|
||||
for value in candidates:
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
return int(value)
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _is_provider_rate_limited(exc: BaseException) -> bool:
|
||||
"""Best-effort detection for provider-side runtime throttling.
|
||||
|
||||
Covers LiteLLM/OpenRouter shapes like:
|
||||
- class name contains ``RateLimit``
|
||||
- nested payload ``{"error": {"code": 429}}``
|
||||
- nested payload ``{"error": {"type": "rate_limit_error"}}``
|
||||
"""
|
||||
raw = str(exc)
|
||||
lowered = raw.lower()
|
||||
if "ratelimit" in type(exc).__name__.lower():
|
||||
return True
|
||||
parsed = _parse_error_payload(raw)
|
||||
provider_code = _extract_provider_error_code(parsed)
|
||||
if provider_code == 429:
|
||||
return True
|
||||
|
||||
provider_error_type = ""
|
||||
if parsed:
|
||||
top_type = parsed.get("type")
|
||||
if isinstance(top_type, str):
|
||||
provider_error_type = top_type.lower()
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
nested_type = nested.get("type")
|
||||
if isinstance(nested_type, str):
|
||||
provider_error_type = nested_type.lower()
|
||||
if provider_error_type == "rate_limit_error":
|
||||
return True
|
||||
|
||||
return (
|
||||
"rate limited" in lowered
|
||||
or "rate-limited" in lowered
|
||||
or "temporarily rate-limited upstream" in lowered
|
||||
)
|
||||
|
||||
|
||||
def _classify_stream_exception(
|
||||
exc: Exception,
|
||||
*,
|
||||
|
|
@ -449,19 +506,7 @@ def _classify_stream_exception(
|
|||
None,
|
||||
)
|
||||
|
||||
parsed = _parse_error_payload(raw)
|
||||
provider_error_type = ""
|
||||
if parsed:
|
||||
top_type = parsed.get("type")
|
||||
if isinstance(top_type, str):
|
||||
provider_error_type = top_type.lower()
|
||||
nested = parsed.get("error")
|
||||
if isinstance(nested, dict):
|
||||
nested_type = nested.get("type")
|
||||
if isinstance(nested_type, str):
|
||||
provider_error_type = nested_type.lower()
|
||||
|
||||
if provider_error_type == "rate_limit_error":
|
||||
if _is_provider_rate_limited(exc):
|
||||
return (
|
||||
"rate_limited",
|
||||
"RATE_LIMITED",
|
||||
|
|
@ -2671,54 +2716,144 @@ async def stream_new_chat(
|
|||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_state,
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking",
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
|
||||
# Inject title update mid-stream as soon as the background task finishes
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
runtime_rate_limit_recovered = False
|
||||
while True:
|
||||
try:
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=input_state,
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking",
|
||||
initial_step_id=initial_step_id,
|
||||
initial_step_title=initial_title,
|
||||
initial_step_items=initial_items,
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||
"%.3fs (total since request start) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(
|
||||
chat_id, generated_title
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
|
||||
# Inject title update mid-stream as soon as the background
|
||||
# task finishes.
|
||||
if title_task is not None and title_task.done() and not title_emitted:
|
||||
generated_title, title_usage = title_task.result()
|
||||
if title_usage:
|
||||
accumulator.add(**title_usage)
|
||||
if generated_title:
|
||||
async with shielded_async_session() as title_session:
|
||||
title_thread_result = await title_session.execute(
|
||||
select(NewChatThread).filter(
|
||||
NewChatThread.id == chat_id
|
||||
)
|
||||
)
|
||||
title_thread = title_thread_result.scalars().first()
|
||||
if title_thread:
|
||||
title_thread.title = generated_title
|
||||
await title_session.commit()
|
||||
yield streaming_service.format_thread_title_update(
|
||||
chat_id, generated_title
|
||||
)
|
||||
title_emitted = True
|
||||
break
|
||||
except Exception as stream_exc:
|
||||
can_runtime_recover = (
|
||||
not runtime_rate_limit_recovered
|
||||
and requested_llm_config_id == 0
|
||||
and llm_config_id < 0
|
||||
and not _first_event_logged
|
||||
and _is_provider_rate_limited(stream_exc)
|
||||
)
|
||||
if not can_runtime_recover:
|
||||
raise
|
||||
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id,
|
||||
reason="provider_rate_limited",
|
||||
)
|
||||
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
title_emitted = True
|
||||
).resolved_llm_config_id
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
raise stream_exc
|
||||
|
||||
# Title generation uses the initial llm object. After a runtime
|
||||
# repin we keep the stream focused on response recovery and skip
|
||||
# title generation for this turn.
|
||||
if title_task is not None and not title_task.done():
|
||||
title_task.cancel()
|
||||
title_task = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
_log_chat_stream_error(
|
||||
flow=flow,
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model hit runtime rate limit; switched to "
|
||||
"another eligible model and retried."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
|
|
@ -3265,31 +3400,108 @@ async def stream_resume_chat(
|
|||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=Command(resume={"decisions": decisions}),
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
runtime_rate_limit_recovered = False
|
||||
while True:
|
||||
try:
|
||||
async for sse in _stream_agent_events(
|
||||
agent=agent,
|
||||
config=config,
|
||||
input_data=Command(resume={"decisions": decisions}),
|
||||
streaming_service=streaming_service,
|
||||
result=stream_result,
|
||||
step_prefix="thinking-resume",
|
||||
fallback_commit_search_space_id=search_space_id,
|
||||
fallback_commit_created_by_id=user_id,
|
||||
fallback_commit_filesystem_mode=(
|
||||
filesystem_selection.mode
|
||||
if filesystem_selection
|
||||
else FilesystemMode.CLOUD
|
||||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
"[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
time.perf_counter() - _t_total,
|
||||
chat_id,
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
break
|
||||
except Exception as stream_exc:
|
||||
can_runtime_recover = (
|
||||
not runtime_rate_limit_recovered
|
||||
and requested_llm_config_id == 0
|
||||
and llm_config_id < 0
|
||||
and not _first_event_logged
|
||||
and _is_provider_rate_limited(stream_exc)
|
||||
)
|
||||
_first_event_logged = True
|
||||
yield sse
|
||||
if not can_runtime_recover:
|
||||
raise
|
||||
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id,
|
||||
reason="provider_rate_limited",
|
||||
)
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
|
||||
llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id)
|
||||
if llm_load_error:
|
||||
raise stream_exc
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=chat_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||
"config_id=%s -> %s and rebuilt agent in %.3fs",
|
||||
previous_config_id,
|
||||
llm_config_id,
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
_log_chat_stream_error(
|
||||
flow="resume",
|
||||
error_kind="rate_limited",
|
||||
error_code="RATE_LIMITED",
|
||||
severity="info",
|
||||
is_expected=True,
|
||||
request_id=request_id,
|
||||
thread_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
message=(
|
||||
"Auto-pinned model hit runtime rate limit; switched to "
|
||||
"another eligible model and retried."
|
||||
),
|
||||
extra={
|
||||
"auto_runtime_recover": True,
|
||||
"previous_config_id": previous_config_id,
|
||||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
continue
|
||||
_perf_log.info(
|
||||
"[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_stream_start,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue