refactor(chat): add streaming/flows/shared/ rate-limit recovery + stream loop

Two cooperating modules that wrap stream_agent_events with in-stream
recovery from provider 429s:

* rate_limit_recovery: can_recover_provider_rate_limit truth-table
  guard, reroute_to_next_auto_pin (selects the next eligible auto-pin
  config and reloads the LLM bundle), log_rate_limit_recovered.
* stream_loop: run_stream_loop drives stream_agent_events in a
  while-True loop, delegating recovery to a flow-supplied RecoverFn
  callback so new_chat and resume_chat can share the same loop while
  keeping their own nonlocal state.

Add-only; not yet wired into any orchestrator.
This commit is contained in:
CREDO23 2026-05-25 21:49:27 +02:00
parent 2c3edb7c84
commit b54b803dc9
2 changed files with 214 additions and 0 deletions

View file

@ -0,0 +1,129 @@
"""Shared steps for the in-stream provider rate-limit recovery loop.
Both flows wrap ``run_stream_loop`` with a flow-specific ``recover`` closure;
the *guard*, the *auto-pin reroute*, and the *post-recovery telemetry* are the
same on both sides and live here so behaviour can't drift.
The orchestrator owns the parts that genuinely diverge:
* cancelling the title task (new_chat only),
* passing ``mentioned_document_ids`` to ``build_main_agent_for_thread``,
* the log prefix (``stream_new_chat`` vs ``stream_resume``).
"""
from __future__ import annotations
from typing import Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.middleware.busy_mutex import end_turn
from app.observability import otel as ot
from app.services.auto_model_pin_service import (
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
from app.tasks.chat.streaming.errors.classifier import (
is_provider_rate_limited,
log_chat_stream_error,
)
def can_recover_provider_rate_limit(
exc: BaseException,
*,
first_event_seen: bool,
runtime_rate_limit_recovered: bool,
requested_llm_config_id: int,
current_llm_config_id: int,
) -> bool:
"""Guard: only the first auto-pin → provider-rate-limited failure recovers.
All conditions must hold:
* ``runtime_rate_limit_recovered is False`` at most one recovery per turn.
* ``requested_llm_config_id == 0`` caller opted into auto-pin (id=0).
* ``current_llm_config_id < 0`` currently on a YAML config (the only
kind the auto-pin pool draws from).
* ``first_event_seen is False`` we haven't sent any SSE to the user yet,
so a silent rebuild + retry is invisible.
* The exception is provider-side rate-limited (HTTP 429 or known shape).
"""
return (
not runtime_rate_limit_recovered
and requested_llm_config_id == 0
and current_llm_config_id < 0
and not first_event_seen
and is_provider_rate_limited(exc)
)
async def reroute_to_next_auto_pin(
session: AsyncSession,
*,
chat_id: int,
search_space_id: int,
user_id: str | None,
current_llm_config_id: int,
requires_image_input: bool,
) -> int:
"""Release lock, cool down the failing config, pick a new auto-pin id.
Returns the new ``llm_config_id``. ``end_turn`` is called because the failed
attempt may still hold the per-thread busy mutex (middleware teardown can
lag behind raised provider errors) the same-request retry would otherwise
bounce on ``BusyError``.
"""
end_turn(str(chat_id))
mark_runtime_cooldown(current_llm_config_id, reason="provider_rate_limited")
pinned = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={current_llm_config_id},
requires_image_input=requires_image_input,
)
return pinned.resolved_llm_config_id
def log_rate_limit_recovered(
*,
flow: Literal["new", "regenerate", "resume"],
request_id: str | None,
chat_id: int,
search_space_id: int,
user_id: str | None,
previous_config_id: int,
new_config_id: int,
) -> None:
"""Emit the OTEL event + structured ``[chat_stream_error]`` log line."""
ot.add_event(
"chat.rate_limit.recovered",
{
"recovery.reason": "provider_rate_limited",
"recovery.previous_config_id": previous_config_id,
"recovery.fallback_config_id": new_config_id,
},
)
log_chat_stream_error(
flow=flow,
error_kind="rate_limited",
error_code="RATE_LIMITED",
severity="info",
is_expected=True,
request_id=request_id,
thread_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
message=(
"Auto-pinned model hit runtime rate limit; switched to "
"another eligible model and retried."
),
extra={
"auto_runtime_recover": True,
"previous_config_id": previous_config_id,
"fallback_config_id": new_config_id,
},
)

View file

@ -0,0 +1,85 @@
"""Drive ``stream_agent_events`` with in-stream rate-limit recovery.
Both ``stream_new_chat`` and ``stream_resume_chat`` wrap the agent event loop
in a ``while True`` that catches the *first* provider rate-limit error
(``can_runtime_recover``) before any SSE event reaches the user, rebuilds the
agent on an alternative auto-pin, and retries the turn.
The recovery callback is flow-specific (different ``mentioned_document_ids``
contract, different logging label, etc.) this module owns the loop shape,
the caller owns the rebuild.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
from app.tasks.chat.streaming.shared.stream_result import StreamResult
# Returns the rebuilt agent on a successful recovery, or ``None`` to re-raise
# the original exception (and let the orchestrator's terminal-error path
# handle it).
RecoverFn = Callable[[BaseException, bool], Awaitable[Any | None]]
async def run_stream_loop(
*,
agent: Any,
streaming_service: VercelStreamingService,
config: dict[str, Any],
input_data: Any,
stream_result: StreamResult,
step_prefix: str = "thinking",
initial_step_id: str | None = None,
initial_step_title: str = "",
initial_step_items: list[str] | None = None,
fallback_commit_search_space_id: int | None,
fallback_commit_created_by_id: str | None,
fallback_commit_filesystem_mode: FilesystemMode,
fallback_commit_thread_id: int | None,
runtime_context: Any,
content_builder: Any | None,
recover: RecoverFn,
on_first_event: Callable[[], None] | None = None,
) -> AsyncGenerator[str, None]:
"""Yield SSE frames; rebuild and retry once on a pre-first-event rate limit.
``on_first_event`` fires after the first frame is observed (used by both
flows to write a one-time ``First agent event in N.NNNs`` perf line).
"""
first_event_logged = False
while True:
try:
async for sse in stream_agent_events(
agent=agent,
config=config,
input_data=input_data,
streaming_service=streaming_service,
result=stream_result,
step_prefix=step_prefix,
initial_step_id=initial_step_id,
initial_step_title=initial_step_title,
initial_step_items=initial_step_items,
fallback_commit_search_space_id=fallback_commit_search_space_id,
fallback_commit_created_by_id=fallback_commit_created_by_id,
fallback_commit_filesystem_mode=fallback_commit_filesystem_mode,
fallback_commit_thread_id=fallback_commit_thread_id,
runtime_context=runtime_context,
content_builder=content_builder,
):
if not first_event_logged:
if on_first_event is not None:
on_first_event()
first_event_logged = True
yield sse
return
except Exception as exc:
new_agent = await recover(exc, first_event_logged)
if new_agent is None:
raise
agent = new_agent
continue