From 1745d7dccf250fe489c1f9ff422491ca11661d4d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Tue, 5 May 2026 18:04:47 +0200 Subject: [PATCH] feat(middleware): scope model fallback to provider/network errors only --- .../graph/middleware/deepagent_stack.py | 10 +- .../app/agents/new_chat/chat_deepagent.py | 10 +- .../middleware/scoped_model_fallback.py | 106 +++++++++++++ .../subagents/shared/test_subagent_builder.py | 18 +-- .../agents/new_chat/middleware/__init__.py | 0 .../middleware/test_scoped_model_fallback.py | 148 ++++++++++++++++++ 6 files changed, 275 insertions(+), 17 deletions(-) create mode 100644 surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/middleware/__init__.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py index 8b7e3d0b0..e490b6b47 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py +++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py @@ -14,7 +14,6 @@ from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT from langchain.agents.middleware import ( LLMToolSelectorMiddleware, ModelCallLimitMiddleware, - ModelFallbackMiddleware, TodoListMiddleware, ToolCallLimitMiddleware, ) @@ -56,6 +55,9 @@ from app.agents.new_chat.middleware import ( create_surfsense_compaction_middleware, default_skills_sources, ) +from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) from app.agents.new_chat.permissions import Rule, Ruleset from app.agents.new_chat.plugin_loader import ( PluginContext, @@ -217,15 +219,15 @@ def build_main_agent_deepagent_middleware( if flags.enable_retry_after and not flags.disable_new_agent_stack else None ) - fallback_mw: ModelFallbackMiddleware | None = None + fallback_mw: ScopedModelFallbackMiddleware | None = None if flags.enable_model_fallback and not flags.disable_new_agent_stack: try: - fallback_mw = ModelFallbackMiddleware( + fallback_mw = ScopedModelFallbackMiddleware( "openai:gpt-4o-mini", "anthropic:claude-3-5-haiku-20241022", ) except Exception: - logging.warning("ModelFallbackMiddleware init failed; skipping.") + logging.warning("ScopedModelFallbackMiddleware init failed; skipping.") fallback_mw = None registry_subagents: list[SubAgent] = [] diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 1f4024d9d..605c31416 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -31,7 +31,6 @@ from langchain.agents import create_agent from langchain.agents.middleware import ( LLMToolSelectorMiddleware, ModelCallLimitMiddleware, - ModelFallbackMiddleware, TodoListMiddleware, ToolCallLimitMiddleware, ) @@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import ( create_surfsense_compaction_middleware, default_skills_sources, ) +from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, +) from app.agents.new_chat.permissions import Rule, Ruleset from app.agents.new_chat.plugin_loader import ( PluginContext, @@ -792,15 +794,15 @@ def _build_compiled_agent_blocking( # 
Fallback chain — primary is the agent's own model; we add cheap # alternatives. Off by default; only the first call site that # configures the chain via env should enable it. - fallback_mw: ModelFallbackMiddleware | None = None + fallback_mw: ScopedModelFallbackMiddleware | None = None if flags.enable_model_fallback and not flags.disable_new_agent_stack: try: - fallback_mw = ModelFallbackMiddleware( + fallback_mw = ScopedModelFallbackMiddleware( "openai:gpt-4o-mini", "anthropic:claude-3-5-haiku-20241022", ) except Exception: - logging.warning("ModelFallbackMiddleware init failed; skipping.") + logging.warning("ScopedModelFallbackMiddleware init failed; skipping.") fallback_mw = None model_call_limit_mw = ( ModelCallLimitMiddleware( diff --git a/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py b/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py new file mode 100644 index 000000000..de367fda9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/scoped_model_fallback.py @@ -0,0 +1,106 @@ +"""Fallback only on provider/network errors; let programming bugs raise. + +Upstream :class:`langchain.agents.middleware.ModelFallbackMiddleware` catches +every ``Exception``. With a non-provider bug (``KeyError``, ``TypeError``, +``AttributeError`` from middleware/state), every fallback model in the chain +hits the same bug — burning latency and tokens before the real cause finally +surfaces. Scoping the catch to provider-style exception types lets bugs fail +fast with clean tracebacks. + +Class-name matching (instead of ``isinstance`` against imported provider +types) keeps the dependency surface flat: openai, anthropic, google, +mistral, etc. all ship their own ``RateLimitError`` and we don't want to +import them all. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import ModelFallbackMiddleware + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from langchain.agents.middleware.types import ModelRequest, ModelResponse + from langchain_core.messages import AIMessage + + +_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset( + { + # Rate / quota + "RateLimitError", + # Server-side + "APIStatusError", + "InternalServerError", + "ServiceUnavailableError", + "BadGatewayError", + "GatewayTimeoutError", + # Network + "APIConnectionError", + "APITimeoutError", + "ConnectError", + "ConnectTimeout", + "ReadTimeout", + "RemoteProtocolError", + "TimeoutError", + "TimeoutException", + } +) + + +def _is_fallback_eligible(exc: BaseException) -> bool: + """Eligible if the exception or any base in its MRO matches by class name.""" + return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__) + + +class ScopedModelFallbackMiddleware(ModelFallbackMiddleware): + """``ModelFallbackMiddleware`` that re-raises non-provider exceptions.""" + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[Any], + handler: Callable[[ModelRequest[Any]], ModelResponse[Any]], + ) -> ModelResponse[Any] | AIMessage: + last_exception: Exception + try: + return handler(request) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + + for fallback_model in self.models: + try: + return handler(request.override(model=fallback_model)) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + continue + + raise last_exception + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[Any], + handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]], + ) -> ModelResponse[Any] | AIMessage: + last_exception: Exception + try: + return await handler(request) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + + for fallback_model in self.models: + try: + return await handler(request.override(model=fallback_model)) + except Exception as e: + if not _is_fallback_eligible(e): + raise + last_exception = e + continue + + raise last_exception diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py index 82f66891a..859833f1c 100644 --- a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py @@ -31,12 +31,12 @@ from app.agents.multi_agent_chat.subagents.shared.subagent_builder import ( ) -class _AlwaysFailingChatModel(BaseChatModel): - """Mimics a provider hard-failing on every call (rate limit / empty stream). +class RateLimitError(Exception): + """Provider-style 429; matches the scoped-fallback eligibility allowlist by name.""" - ``ModelFallbackMiddleware`` triggers on any ``Exception``, so the exact - error type doesn't matter for the contract under test. 
- """ + +class _AlwaysFailingChatModel(BaseChatModel): + """Mimics a provider hard-failing on every call (rate limit / empty stream).""" @property def _llm_type(self) -> str: @@ -50,7 +50,7 @@ class _AlwaysFailingChatModel(BaseChatModel): **kwargs: Any, ) -> ChatResult: msg = "primary llm exploded" - raise RuntimeError(msg) + raise RateLimitError(msg) async def _agenerate( self, @@ -60,17 +60,17 @@ class _AlwaysFailingChatModel(BaseChatModel): **kwargs: Any, ) -> ChatResult: msg = "primary llm exploded" - raise RuntimeError(msg) + raise RateLimitError(msg) def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]: msg = "primary llm exploded" - raise RuntimeError(msg) + raise RateLimitError(msg) async def _astream( self, *args: Any, **kwargs: Any ) -> AsyncIterator[ChatGeneration]: msg = "primary llm exploded" - raise RuntimeError(msg) + raise RateLimitError(msg) yield # pragma: no cover - unreachable, satisfies async generator typing diff --git a/surfsense_backend/tests/unit/agents/new_chat/middleware/__init__.py b/surfsense_backend/tests/unit/agents/new_chat/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py b/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py new file mode 100644 index 000000000..af464d1dc --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/middleware/test_scoped_model_fallback.py @@ -0,0 +1,148 @@ +"""Exception-scope contract for ``ScopedModelFallbackMiddleware``. + +Upstream ``ModelFallbackMiddleware`` catches every ``Exception`` and walks +the fallback chain. That means a programming bug (``KeyError`` from a +botched tool config, ``TypeError`` from middleware, ...) burns 1+N model +round-trips and ~Nx tokens before its real cause surfaces. The scoped +variant only falls back on provider/network exception types so bugs fail +fast, with clean tracebacks. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator +from typing import Any + +import pytest +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.outputs import ChatGeneration, ChatResult + + +class _RaisingChatModel(BaseChatModel): + """LLM that raises a configurable exception on every invocation.""" + + exc_to_raise: Any + + @property + def _llm_type(self) -> str: + return "raising-test-model" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + raise self.exc_to_raise + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + raise self.exc_to_raise + + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]: + raise self.exc_to_raise + + async def _astream( + self, *args: Any, **kwargs: Any + ) -> AsyncIterator[ChatGeneration]: + raise self.exc_to_raise + yield # pragma: no cover - unreachable + + +class _RecordingChatModel(BaseChatModel): + """Returns a fixed message and counts how often it was called.""" + + response_text: str = "fallback-ok" + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "recording-test-model" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + self.call_count += 1 + return ChatResult( + generations=[ + ChatGeneration(message=AIMessage(content=self.response_text)) + ] + ) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, None, **kwargs) + + +# Locally defined provider-style error: importing openai/anthropic/etc. +# would couple the test to provider SDKs the contract intentionally avoids. 
+class RateLimitError(Exception): + """Mimics ``openai.RateLimitError`` for name-based eligibility.""" + + +def _build_agent(primary: BaseChatModel, fallback: BaseChatModel): + """Compile a no-tools agent with the scoped fallback wired in.""" + from langchain.agents import create_agent + + from app.agents.new_chat.middleware.scoped_model_fallback import ( + ScopedModelFallbackMiddleware, + ) + + return create_agent( + model=primary, + tools=[], + middleware=[ScopedModelFallbackMiddleware(fallback)], + system_prompt="be helpful", + ) + + +@pytest.mark.asyncio +async def test_provider_errors_trigger_fallback(): + """Class names matching the provider allowlist drive the fallback chain.""" + primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider")) + fallback = _RecordingChatModel(response_text="recovered") + + agent = _build_agent(primary, fallback) + result = await agent.ainvoke({"messages": [("user", "hi")]}) + + final = result["messages"][-1] + assert isinstance(final, AIMessage) + assert final.content == "recovered" + assert fallback.call_count == 1 + + +@pytest.mark.asyncio +async def test_programming_errors_propagate_without_invoking_fallback(): + """``KeyError`` from agent-side bugs must surface immediately, no fallback retry.""" + primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field")) + fallback = _RecordingChatModel(response_text="should-never-arrive") + + agent = _build_agent(primary, fallback) + + with pytest.raises(KeyError, match="missing_state_field"): + await agent.ainvoke({"messages": [("user", "hi")]}) + + assert fallback.call_count == 0, ( + "fallback was invoked for a programming error; " + "scoping rule is broken" + )
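
Reviewer-facing sketch of the eligibility contract (illustrative only; it assumes the module path introduced above and is not part of the applied diff): fallback is driven purely by class-name matching over the exception's MRO, so a locally defined RateLimitError behaves like a provider SDK's 429 while a KeyError propagates untouched.

    # Illustrative sketch -- mirrors the name-based helper added in
    # scoped_model_fallback.py; RateLimitError here is a local stand-in
    # for a provider SDK's 429 error, no provider package is imported.
    from app.agents.new_chat.middleware.scoped_model_fallback import (
        _is_fallback_eligible,
    )


    class RateLimitError(Exception):
        """Stand-in whose class name matches the allowlist (like ``openai.RateLimitError``)."""


    assert _is_fallback_eligible(RateLimitError("429"))        # provider-style error -> walk the fallback chain
    assert not _is_fallback_eligible(KeyError("state_field"))  # programming bug -> re-raise immediately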