From b54b803dc9a844d74700f3dc27eb00282d63b081 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Mon, 25 May 2026 21:49:27 +0200 Subject: [PATCH] refactor(chat): add streaming/flows/shared/ rate-limit recovery + stream loop Two cooperating modules that wrap stream_agent_events with in-stream recovery from provider 429s: * rate_limit_recovery: can_recover_provider_rate_limit truth-table guard, reroute_to_next_auto_pin (selects the next eligible auto-pin config and reloads the LLM bundle), log_rate_limit_recovered. * stream_loop: run_stream_loop drives stream_agent_events in a while-True loop, delegating recovery to a flow-supplied RecoverFn callback so new_chat and resume_chat can share the same loop while keeping their own nonlocal state. Add-only; not yet wired into any orchestrator. --- .../flows/shared/rate_limit_recovery.py | 129 ++++++++++++++++++ .../streaming/flows/shared/stream_loop.py | 85 ++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py create mode 100644 surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py new file mode 100644 index 000000000..6b3857594 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/rate_limit_recovery.py @@ -0,0 +1,129 @@ +"""Shared steps for the in-stream provider rate-limit recovery loop. + +Both flows wrap ``run_stream_loop`` with a flow-specific ``recover`` closure; +the *guard*, the *auto-pin reroute*, and the *post-recovery telemetry* are the +same on both sides and live here so behaviour can't drift. + +The orchestrator owns the parts that genuinely diverge: + + * cancelling the title task (new_chat only), + * passing ``mentioned_document_ids`` to ``build_main_agent_for_thread``, + * the log prefix (``stream_new_chat`` vs ``stream_resume``). +""" + +from __future__ import annotations + +from typing import Literal + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.middleware.busy_mutex import end_turn +from app.observability import otel as ot +from app.services.auto_model_pin_service import ( + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) +from app.tasks.chat.streaming.errors.classifier import ( + is_provider_rate_limited, + log_chat_stream_error, +) + + +def can_recover_provider_rate_limit( + exc: BaseException, + *, + first_event_seen: bool, + runtime_rate_limit_recovered: bool, + requested_llm_config_id: int, + current_llm_config_id: int, +) -> bool: + """Guard: only the first auto-pin → provider-rate-limited failure recovers. + + All conditions must hold: + + * ``runtime_rate_limit_recovered is False`` — at most one recovery per turn. + * ``requested_llm_config_id == 0`` — caller opted into auto-pin (id=0). + * ``current_llm_config_id < 0`` — currently on a YAML config (the only + kind the auto-pin pool draws from). + * ``first_event_seen is False`` — we haven't sent any SSE to the user yet, + so a silent rebuild + retry is invisible. + * The exception is provider-side rate-limited (HTTP 429 or known shape). + """ + return ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and current_llm_config_id < 0 + and not first_event_seen + and is_provider_rate_limited(exc) + ) + + +async def reroute_to_next_auto_pin( + session: AsyncSession, + *, + chat_id: int, + search_space_id: int, + user_id: str | None, + current_llm_config_id: int, + requires_image_input: bool, +) -> int: + """Release lock, cool down the failing config, pick a new auto-pin id. + + Returns the new ``llm_config_id``. ``end_turn`` is called because the failed + attempt may still hold the per-thread busy mutex (middleware teardown can + lag behind raised provider errors) — the same-request retry would otherwise + bounce on ``BusyError``. + """ + end_turn(str(chat_id)) + mark_runtime_cooldown(current_llm_config_id, reason="provider_rate_limited") + pinned = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={current_llm_config_id}, + requires_image_input=requires_image_input, + ) + return pinned.resolved_llm_config_id + + +def log_rate_limit_recovered( + *, + flow: Literal["new", "regenerate", "resume"], + request_id: str | None, + chat_id: int, + search_space_id: int, + user_id: str | None, + previous_config_id: int, + new_config_id: int, +) -> None: + """Emit the OTEL event + structured ``[chat_stream_error]`` log line.""" + ot.add_event( + "chat.rate_limit.recovered", + { + "recovery.reason": "provider_rate_limited", + "recovery.previous_config_id": previous_config_id, + "recovery.fallback_config_id": new_config_id, + }, + ) + log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." + ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": new_config_id, + }, + ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py new file mode 100644 index 000000000..6cf0df855 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/stream_loop.py @@ -0,0 +1,85 @@ +"""Drive ``stream_agent_events`` with in-stream rate-limit recovery. + +Both ``stream_new_chat`` and ``stream_resume_chat`` wrap the agent event loop +in a ``while True`` that catches the *first* provider rate-limit error +(``can_runtime_recover``) before any SSE event reaches the user, rebuilds the +agent on an alternative auto-pin, and retries the turn. + +The recovery callback is flow-specific (different ``mentioned_document_ids`` +contract, different logging label, etc.) — this module owns the loop shape, +the caller owns the rebuild. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.streaming.agent.event_loop import stream_agent_events +from app.tasks.chat.streaming.shared.stream_result import StreamResult + +# Returns the rebuilt agent on a successful recovery, or ``None`` to re-raise +# the original exception (and let the orchestrator's terminal-error path +# handle it). +RecoverFn = Callable[[BaseException, bool], Awaitable[Any | None]] + + +async def run_stream_loop( + *, + agent: Any, + streaming_service: VercelStreamingService, + config: dict[str, Any], + input_data: Any, + stream_result: StreamResult, + step_prefix: str = "thinking", + initial_step_id: str | None = None, + initial_step_title: str = "", + initial_step_items: list[str] | None = None, + fallback_commit_search_space_id: int | None, + fallback_commit_created_by_id: str | None, + fallback_commit_filesystem_mode: FilesystemMode, + fallback_commit_thread_id: int | None, + runtime_context: Any, + content_builder: Any | None, + recover: RecoverFn, + on_first_event: Callable[[], None] | None = None, +) -> AsyncGenerator[str, None]: + """Yield SSE frames; rebuild and retry once on a pre-first-event rate limit. + + ``on_first_event`` fires after the first frame is observed (used by both + flows to write a one-time ``First agent event in N.NNNs`` perf line). + """ + first_event_logged = False + while True: + try: + async for sse in stream_agent_events( + agent=agent, + config=config, + input_data=input_data, + streaming_service=streaming_service, + result=stream_result, + step_prefix=step_prefix, + initial_step_id=initial_step_id, + initial_step_title=initial_step_title, + initial_step_items=initial_step_items, + fallback_commit_search_space_id=fallback_commit_search_space_id, + fallback_commit_created_by_id=fallback_commit_created_by_id, + fallback_commit_filesystem_mode=fallback_commit_filesystem_mode, + fallback_commit_thread_id=fallback_commit_thread_id, + runtime_context=runtime_context, + content_builder=content_builder, + ): + if not first_event_logged: + if on_first_event is not None: + on_first_event() + first_event_logged = True + yield sse + return + except Exception as exc: + new_agent = await recover(exc, first_event_logged) + if new_agent is None: + raise + agent = new_agent + continue