mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Relocate the entire new_chat/middleware/ package to the shared kernel as one cohesive unit (it is live shared infrastructure: the multi-agent stack wraps nearly every middleware via multi_agent_chat/middleware/main_agent/*, and anonymous_agent consumes it too). Flip 69 live importers across both the package-path and submodule-path forms. Shims left for the frozen single-agent stack: a package __init__ re-export plus submodule shims for permission, skills_backends, and scoped_model_fallback (the three imported via submodule path by chat_deepagent/subagents). Cycle break: importing shared.middleware previously reached back into new_chat.tools at module load, which dragged in new_chat.__init__ -> chat_deepagent -> the middleware shim -> half-initialized shared.middleware. Made action_log's ToolDefinition import TYPE_CHECKING-only and tool_call_repair's INVALID_TOOL_NAME import function-local. These tools-package back-edges fully resolve in slice 6. Asset note: skills_backends._default_builtin_root now walks to app/agents/new_chat/skills/builtin (the skills/ tree migrates in slice 7).
277 lines
10 KiB
Python
277 lines
10 KiB
Python
"""
|
|
RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing.
|
|
|
|
LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores
|
|
the ``Retry-After`` HTTP header — it just runs its own exponential backoff.
|
|
That wastes time when a provider has explicitly told us how long to wait.
|
|
This middleware honors the header (mirroring OpenCode's
|
|
``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE
|
|
event so the UI can show "rate-limited, retrying in Ns".
|
|
|
|
We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a
|
|
module-level ``calculate_delay`` inline (no overridable
|
|
``_calculate_delay`` hook), so this is a standalone implementation.
|
|
|
|
Behaviour:
|
|
- Extracts ``Retry-After`` / ``retry-after-ms`` from
|
|
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
|
|
exposing a similar shape).
|
|
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
|
|
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
|
|
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
|
|
the LangChain summarization fallback path) handles those instead.
|
|
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
|
|
so ``stream_new_chat`` can forward it to clients as an SSE event.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import random
|
|
import re
|
|
import time
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from langchain.agents.middleware.types import (
|
|
AgentMiddleware,
|
|
AgentState,
|
|
ContextT,
|
|
ModelRequest,
|
|
ModelResponse,
|
|
ResponseT,
|
|
)
|
|
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
|
from langchain_core.messages import AIMessage
|
|
|
|
from app.observability import metrics as ot_metrics, otel as ot
|
|
|
|
logger = logging.getLogger(__name__)


# Fragments of exception class names that mark a failure as pointless to
# retry: a context overflow has to be compacted first, and an auth or
# permission failure needs a human. Matching is done on the class *name*
# so this module never has to import LiteLLM/Anthropic (optional deps).
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
    "ContextWindowExceeded",
    "ContextOverflow",
    "AuthenticationError",
    "InvalidRequestError",
    "PermissionDenied",
    "InvalidApiKey",
    "ContextLimit",
)


def _is_non_retryable(exc: BaseException) -> bool:
    """Tell whether ``exc`` belongs to a class that retrying cannot fix.

    The decision is a case-sensitive substring test of the exception's
    class name against ``_NON_RETRYABLE_NAME_HINTS``.

    Args:
        exc: The exception raised by the model call.

    Returns:
        ``True`` when any hint occurs in ``type(exc).__name__``,
        ``False`` otherwise.
    """
    class_name = type(exc).__name__
    for fragment in _NON_RETRYABLE_NAME_HINTS:
        if fragment in class_name:
            return True
    return False
|
|
|
|
|
|
# Matches "retry after <number>" in an exception message, with an optional
# millisecond unit so "retry after 500ms" is not mistaken for 500 seconds.
_RETRY_AFTER_MESSAGE_RE = re.compile(
    r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)\s*(ms\b|milliseconds?\b)?",
    re.IGNORECASE,
)


def _coerce_delay(raw: Any, divisor: float) -> float | None:
    """Convert ``raw`` to a usable delay in seconds, or ``None``.

    Args:
        raw: Header value or regex capture (string or number).
        divisor: ``1000.0`` when ``raw`` is in milliseconds, ``1.0`` for seconds.

    Returns:
        ``float(raw) / divisor`` when that is a finite, non-negative number;
        ``None`` when ``raw`` is missing, unparsable, negative, NaN or infinite.
    """
    try:
        value = float(raw) / divisor
    except (TypeError, ValueError, OverflowError):
        return None
    # The chained comparison is False for NaN and rejects +/-inf and negatives,
    # none of which can be slept on (``int(inf * 1000)`` raises OverflowError).
    if 0.0 <= value < float("inf"):
        return value
    return None


def _http_date_delay(raw: str) -> float | None:
    """Return seconds from now until the HTTP-date ``raw``, or ``None``.

    ``Retry-After`` may carry an absolute HTTP-date instead of a number of
    seconds. Dates already in the past yield ``0.0``.
    """
    # Function-scope imports: only this rarely-taken path needs them.
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    try:
        when = parsedate_to_datetime(raw)
    except (TypeError, ValueError):
        # Older Pythons raise TypeError, newer ones ValueError, on bad input.
        return None
    if when is None:
        return None
    if when.tzinfo is None:
        # A "-0000" zone parses as naive; such timestamps are UTC.
        when = when.replace(tzinfo=timezone.utc)
    remaining = (when - datetime.now(timezone.utc)).total_seconds()
    return max(remaining, 0.0)


def _extract_retry_after_seconds(exc: BaseException) -> float | None:
    """Return seconds-to-wait suggested by the provider, if any.

    Sources are consulted in this order and the first usable one wins:

    1. ``retry-after-ms`` header (milliseconds).
    2. ``retry-after`` header as a number of seconds.
    3. ``retry-after`` header as an HTTP-date.
    4. A ``"retry after <N>"`` phrase in ``str(exc)``; ``<N>`` is seconds
       unless directly followed by ``ms`` / ``millisecond(s)``.

    Headers are read from ``exc.response.headers`` and, failing that, from
    ``exc.headers``. Any ``Mapping`` is accepted (not only ``dict``), so
    header containers that are mappings without subclassing ``dict`` are
    honored. Header names are matched case-insensitively.

    Args:
        exc: The exception raised by the model call.

    Returns:
        A finite, non-negative delay in seconds, or ``None`` when no usable
        hint was found.
    """
    from collections.abc import Mapping

    response = getattr(exc, "response", None)
    headers = getattr(response, "headers", None) if response is not None else None
    if headers is None:
        headers = getattr(exc, "headers", None)

    if isinstance(headers, Mapping):
        # Normalize keys to lowercase for case-insensitive matching.
        norm = {str(key).lower(): value for key, value in headers.items()}

        delay = _coerce_delay(norm.get("retry-after-ms"), 1000.0)
        if delay is not None:
            return delay

        raw_seconds = norm.get("retry-after")
        delay = _coerce_delay(raw_seconds, 1.0)
        if delay is not None:
            return delay

        if isinstance(raw_seconds, str):
            delay = _http_date_delay(raw_seconds)
            if delay is not None:
                return delay

    # Last resort: scan the message for "retry after <N>[ms]".
    match = _RETRY_AFTER_MESSAGE_RE.search(str(exc))
    if match is None:
        return None
    divisor = 1000.0 if match.group(2) else 1.0
    return _coerce_delay(match.group(1), divisor)
|
|
|
|
|
|
def _exponential_delay(
    attempt: int,
    *,
    initial_delay: float,
    backoff_factor: float,
    max_delay: float,
    jitter: bool,
) -> float:
    """Return the backoff delay (seconds) for a zero-based ``attempt``.

    Args:
        attempt: Zero-based retry index; used as the exponent.
        initial_delay: Delay for ``attempt == 0``.
        backoff_factor: Growth base. A falsy value (``0``) disables growth
            and every attempt uses ``initial_delay``.
        max_delay: Upper bound applied *before* jitter.
        jitter: When true and the capped delay is positive, scale it by a
            random factor in ``[0.75, 1.25]``.

    Returns:
        The delay, never below ``0.0``. Because jitter is applied after the
        cap, the result may exceed ``max_delay`` by up to 25%.
    """
    if backoff_factor:
        wait = initial_delay * (backoff_factor**attempt)
    else:
        wait = initial_delay

    wait = min(wait, max_delay)

    if jitter and wait > 0:
        wait *= 1 + random.uniform(-0.25, 0.25)

    return max(wait, 0.0)
|
|
|
|
|
|
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Retry middleware that honors provider-issued Retry-After hints.

    Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
    when working with LiteLLM/Anthropic/OpenAI providers that surface
    rate-limit hints in headers. Before every retry it records a
    ``model.retry.scheduled`` telemetry event and dispatches a
    ``surfsense.retrying`` custom event so the UI can show a friendly
    "rate limited, retrying in Xs" indicator.

    The wait before a retry is ``max(exponential_backoff, provider_hint)``.
    Note that only the backoff term is capped by ``max_delay``; a larger
    provider hint is honored as-is.

    The model handler is always invoked at least once: a ``max_retries``
    below zero behaves like ``0`` (single attempt, no retries).

    Args:
        max_retries: Maximum retries after the initial attempt (default 3).
        initial_delay: Initial backoff delay in seconds.
        backoff_factor: Exponential growth factor for backoff.
        max_delay: Cap on the per-attempt backoff delay in seconds.
        jitter: Whether to add ±25% jitter.
        retry_on: Optional callable that returns True for retryable
            exceptions. The default retries everything except known
            non-retryable classes (context overflow, auth, etc.).
    """

    def __init__(
        self,
        *,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
        max_delay: float = 60.0,
        jitter: bool = True,
        retry_on: Callable[[BaseException], bool] | None = None,
    ) -> None:
        super().__init__()
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay
        self.jitter = jitter
        self._retry_on: Callable[[BaseException], bool] = retry_on or (
            lambda exc: not _is_non_retryable(exc)
        )

    def _should_retry(self, exc: BaseException) -> bool:
        """Ask ``retry_on`` whether ``exc`` is retryable; a crashing predicate means no."""
        try:
            return bool(self._retry_on(exc))
        except Exception:
            logger.exception("retry_on callable raised; defaulting to False")
            return False

    def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
        """Return the wait in seconds: the larger of backoff and provider hint."""
        backoff = _exponential_delay(
            attempt,
            initial_delay=self.initial_delay,
            backoff_factor=self.backoff_factor,
            max_delay=self.max_delay,
            jitter=self.jitter,
        )
        header = _extract_retry_after_seconds(exc) or 0.0
        return max(backoff, header)

    def _attempt_budget(self) -> int:
        """Total number of handler invocations allowed (never fewer than one).

        Clamping guards against a negative ``max_retries``, which would
        otherwise make the retry loop empty and skip the model call entirely.
        """
        return max(self.max_retries, 0) + 1

    def _plan_retry(
        self, attempt: int, exc: BaseException
    ) -> tuple[float, dict[str, Any]] | None:
        """Decide what to do after a failed attempt.

        Shared by the sync and async paths so both make the same decision
        and report the same numbers.

        Args:
            attempt: Zero-based index of the attempt that just failed.
            exc: The exception it raised.

        Returns:
            ``None`` when the caller must re-raise ``exc`` (not retryable, or
            retries exhausted). Otherwise ``(delay_seconds, event_payload)``
            where ``event_payload`` is the body of the ``surfsense.retrying``
            event. The ``model.retry.scheduled`` telemetry event has already
            been recorded when a tuple is returned.
        """
        if not self._should_retry(exc) or attempt >= self.max_retries:
            return None

        delay = self._delay_for_attempt(attempt, exc)
        delay_ms = int(delay * 1000)
        retry_number = attempt + 1  # 1-based for humans and dashboards
        ot.add_event(
            "model.retry.scheduled",
            {
                "retry.attempt": retry_number,
                "retry.max": self.max_retries,
                "retry.delay_ms": delay_ms,
                "retry.reason": ot_metrics.categorize_exception(exc),
            },
        )
        payload: dict[str, Any] = {
            "attempt": retry_number,
            "max_retries": self.max_retries,
            "delay_ms": delay_ms,
            "reason": type(exc).__name__,
        }
        return delay, payload

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Synchronously invoke ``handler``, retrying on retryable exceptions.

        Raises:
            Exception: The handler's own exception, re-raised unchanged when
                it is not retryable or retries are exhausted.
        """
        for attempt in range(self._attempt_budget()):
            try:
                return handler(request)
            except Exception as exc:
                plan = self._plan_retry(attempt, exc)
                if plan is None:
                    raise
                delay, payload = plan
                try:
                    dispatch_custom_event("surfsense.retrying", payload)
                except Exception:
                    # Best effort: a UI notification must never abort a retry.
                    logger.debug(
                        "dispatch_custom_event failed; suppressed", exc_info=True
                    )
                if delay > 0:
                    time.sleep(delay)
        # Unreachable: every iteration returns or raises; kept for type checkers.
        raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Asynchronously invoke ``handler``, retrying on retryable exceptions.

        Uses ``asyncio.sleep`` between attempts so the event loop stays free.

        Raises:
            Exception: The handler's own exception, re-raised unchanged when
                it is not retryable or retries are exhausted.
        """
        for attempt in range(self._attempt_budget()):
            try:
                return await handler(request)
            except Exception as exc:
                plan = self._plan_retry(attempt, exc)
                if plan is None:
                    raise
                delay, payload = plan
                try:
                    await adispatch_custom_event("surfsense.retrying", payload)
                except Exception:
                    # Best effort: a UI notification must never abort a retry.
                    logger.debug(
                        "adispatch_custom_event failed; suppressed", exc_info=True
                    )
                if delay > 0:
                    await asyncio.sleep(delay)
        # Unreachable: every iteration returns or raises; kept for type checkers.
        raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
|
|
|
|
|
# Names exported by ``from <this module> import *``. The two
# underscore-prefixed helpers are listed explicitly because a star-import
# skips private names unless ``__all__`` includes them.
# NOTE(review): presumably exported for tests or sibling middleware —
# confirm against importers before removing.
__all__ = [
    "RetryAfterMiddleware",
    "_extract_retry_after_seconds",
    "_is_non_retryable",
]
|