mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
feat(middleware): scope model fallback to provider/network errors only
This commit is contained in:
parent
f695298d30
commit
1745d7dccf
6 changed files with 275 additions and 17 deletions
|
|
@ -14,7 +14,6 @@ from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||||
from langchain.agents.middleware import (
|
from langchain.agents.middleware import (
|
||||||
LLMToolSelectorMiddleware,
|
LLMToolSelectorMiddleware,
|
||||||
ModelCallLimitMiddleware,
|
ModelCallLimitMiddleware,
|
||||||
ModelFallbackMiddleware,
|
|
||||||
TodoListMiddleware,
|
TodoListMiddleware,
|
||||||
ToolCallLimitMiddleware,
|
ToolCallLimitMiddleware,
|
||||||
)
|
)
|
||||||
|
|
@ -56,6 +55,9 @@ from app.agents.new_chat.middleware import (
|
||||||
create_surfsense_compaction_middleware,
|
create_surfsense_compaction_middleware,
|
||||||
default_skills_sources,
|
default_skills_sources,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
from app.agents.new_chat.plugin_loader import (
|
from app.agents.new_chat.plugin_loader import (
|
||||||
PluginContext,
|
PluginContext,
|
||||||
|
|
@ -217,15 +219,15 @@ def build_main_agent_deepagent_middleware(
|
||||||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
fallback_mw: ModelFallbackMiddleware | None = None
|
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||||
try:
|
try:
|
||||||
fallback_mw = ModelFallbackMiddleware(
|
fallback_mw = ScopedModelFallbackMiddleware(
|
||||||
"openai:gpt-4o-mini",
|
"openai:gpt-4o-mini",
|
||||||
"anthropic:claude-3-5-haiku-20241022",
|
"anthropic:claude-3-5-haiku-20241022",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||||
fallback_mw = None
|
fallback_mw = None
|
||||||
|
|
||||||
registry_subagents: list[SubAgent] = []
|
registry_subagents: list[SubAgent] = []
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,6 @@ from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import (
|
from langchain.agents.middleware import (
|
||||||
LLMToolSelectorMiddleware,
|
LLMToolSelectorMiddleware,
|
||||||
ModelCallLimitMiddleware,
|
ModelCallLimitMiddleware,
|
||||||
ModelFallbackMiddleware,
|
|
||||||
TodoListMiddleware,
|
TodoListMiddleware,
|
||||||
ToolCallLimitMiddleware,
|
ToolCallLimitMiddleware,
|
||||||
)
|
)
|
||||||
|
|
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
|
||||||
create_surfsense_compaction_middleware,
|
create_surfsense_compaction_middleware,
|
||||||
default_skills_sources,
|
default_skills_sources,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
from app.agents.new_chat.plugin_loader import (
|
from app.agents.new_chat.plugin_loader import (
|
||||||
PluginContext,
|
PluginContext,
|
||||||
|
|
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
|
||||||
# Fallback chain — primary is the agent's own model; we add cheap
|
# Fallback chain — primary is the agent's own model; we add cheap
|
||||||
# alternatives. Off by default; only the first call site that
|
# alternatives. Off by default; only the first call site that
|
||||||
# configures the chain via env should enable it.
|
# configures the chain via env should enable it.
|
||||||
fallback_mw: ModelFallbackMiddleware | None = None
|
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||||
try:
|
try:
|
||||||
fallback_mw = ModelFallbackMiddleware(
|
fallback_mw = ScopedModelFallbackMiddleware(
|
||||||
"openai:gpt-4o-mini",
|
"openai:gpt-4o-mini",
|
||||||
"anthropic:claude-3-5-haiku-20241022",
|
"anthropic:claude-3-5-haiku-20241022",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||||
fallback_mw = None
|
fallback_mw = None
|
||||||
model_call_limit_mw = (
|
model_call_limit_mw = (
|
||||||
ModelCallLimitMiddleware(
|
ModelCallLimitMiddleware(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
"""Fallback only on provider/network errors; let programming bugs raise.
|
||||||
|
|
||||||
|
Upstream :class:`langchain.agents.middleware.ModelFallbackMiddleware` catches
|
||||||
|
every ``Exception``. With a non-provider bug (``KeyError``, ``TypeError``,
|
||||||
|
``AttributeError`` from middleware/state), every fallback model in the chain
|
||||||
|
hits the same bug — burning latency and tokens before the real cause finally
|
||||||
|
surfaces. Scoping the catch to provider-style exception types lets bugs fail
|
||||||
|
fast with clean tracebacks.
|
||||||
|
|
||||||
|
Class-name matching (instead of ``isinstance`` against imported provider
|
||||||
|
types) keeps the dependency surface flat: openai, anthropic, google,
|
||||||
|
mistral, etc. all ship their own ``RateLimitError`` and we don't want to
|
||||||
|
import them all.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
# Rate / quota
|
||||||
|
"RateLimitError",
|
||||||
|
# Server-side
|
||||||
|
"APIStatusError",
|
||||||
|
"InternalServerError",
|
||||||
|
"ServiceUnavailableError",
|
||||||
|
"BadGatewayError",
|
||||||
|
"GatewayTimeoutError",
|
||||||
|
# Network
|
||||||
|
"APIConnectionError",
|
||||||
|
"APITimeoutError",
|
||||||
|
"ConnectError",
|
||||||
|
"ConnectTimeout",
|
||||||
|
"ReadTimeout",
|
||||||
|
"RemoteProtocolError",
|
||||||
|
"TimeoutError",
|
||||||
|
"TimeoutException",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||||
|
"""Eligible if the exception or any base in its MRO matches by class name."""
|
||||||
|
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||||
|
|
||||||
|
|
||||||
|
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||||
|
"""``ModelFallbackMiddleware`` that re-raises non-provider exceptions."""
|
||||||
|
|
||||||
|
def wrap_model_call( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
request: ModelRequest[Any],
|
||||||
|
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
|
||||||
|
) -> ModelResponse[Any] | AIMessage:
|
||||||
|
last_exception: Exception
|
||||||
|
try:
|
||||||
|
return handler(request)
|
||||||
|
except Exception as e:
|
||||||
|
if not _is_fallback_eligible(e):
|
||||||
|
raise
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
for fallback_model in self.models:
|
||||||
|
try:
|
||||||
|
return handler(request.override(model=fallback_model))
|
||||||
|
except Exception as e:
|
||||||
|
if not _is_fallback_eligible(e):
|
||||||
|
raise
|
||||||
|
last_exception = e
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
async def awrap_model_call( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
request: ModelRequest[Any],
|
||||||
|
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
|
||||||
|
) -> ModelResponse[Any] | AIMessage:
|
||||||
|
last_exception: Exception
|
||||||
|
try:
|
||||||
|
return await handler(request)
|
||||||
|
except Exception as e:
|
||||||
|
if not _is_fallback_eligible(e):
|
||||||
|
raise
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
for fallback_model in self.models:
|
||||||
|
try:
|
||||||
|
return await handler(request.override(model=fallback_model))
|
||||||
|
except Exception as e:
|
||||||
|
if not _is_fallback_eligible(e):
|
||||||
|
raise
|
||||||
|
last_exception = e
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise last_exception
|
||||||
|
|
@ -31,12 +31,12 @@ from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _AlwaysFailingChatModel(BaseChatModel):
|
class RateLimitError(Exception):
|
||||||
"""Mimics a provider hard-failing on every call (rate limit / empty stream).
|
"""Provider-style 429; matches the scoped-fallback eligibility allowlist by name."""
|
||||||
|
|
||||||
``ModelFallbackMiddleware`` triggers on any ``Exception``, so the exact
|
|
||||||
error type doesn't matter for the contract under test.
|
class _AlwaysFailingChatModel(BaseChatModel):
|
||||||
"""
|
"""Mimics a provider hard-failing on every call (rate limit / empty stream)."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
|
@ -50,7 +50,7 @@ class _AlwaysFailingChatModel(BaseChatModel):
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
msg = "primary llm exploded"
|
msg = "primary llm exploded"
|
||||||
raise RuntimeError(msg)
|
raise RateLimitError(msg)
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -60,17 +60,17 @@ class _AlwaysFailingChatModel(BaseChatModel):
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
msg = "primary llm exploded"
|
msg = "primary llm exploded"
|
||||||
raise RuntimeError(msg)
|
raise RateLimitError(msg)
|
||||||
|
|
||||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
||||||
msg = "primary llm exploded"
|
msg = "primary llm exploded"
|
||||||
raise RuntimeError(msg)
|
raise RateLimitError(msg)
|
||||||
|
|
||||||
async def _astream(
|
async def _astream(
|
||||||
self, *args: Any, **kwargs: Any
|
self, *args: Any, **kwargs: Any
|
||||||
) -> AsyncIterator[ChatGeneration]:
|
) -> AsyncIterator[ChatGeneration]:
|
||||||
msg = "primary llm exploded"
|
msg = "primary llm exploded"
|
||||||
raise RuntimeError(msg)
|
raise RateLimitError(msg)
|
||||||
yield # pragma: no cover - unreachable, satisfies async generator typing
|
yield # pragma: no cover - unreachable, satisfies async generator typing
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,148 @@
|
||||||
|
"""Exception-scope contract for ``ScopedModelFallbackMiddleware``.
|
||||||
|
|
||||||
|
Upstream ``ModelFallbackMiddleware`` catches every ``Exception`` and walks
|
||||||
|
the fallback chain. That means a programming bug (``KeyError`` from a
|
||||||
|
botched tool config, ``TypeError`` from middleware, ...) burns 1+N model
|
||||||
|
round-trips and ~Nx tokens before its real cause surfaces. The scoped
|
||||||
|
variant only falls back on provider/network exception types so bugs fail
|
||||||
|
fast, with clean tracebacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
class _RaisingChatModel(BaseChatModel):
|
||||||
|
"""LLM that raises a configurable exception on every invocation."""
|
||||||
|
|
||||||
|
exc_to_raise: Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "raising-test-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
|
||||||
|
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> AsyncIterator[ChatGeneration]:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
yield # pragma: no cover - unreachable
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingChatModel(BaseChatModel):
|
||||||
|
"""Returns a fixed message and counts how often it was called."""
|
||||||
|
|
||||||
|
response_text: str = "fallback-ok"
|
||||||
|
call_count: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "recording-test-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
self.call_count += 1
|
||||||
|
return ChatResult(
|
||||||
|
generations=[
|
||||||
|
ChatGeneration(message=AIMessage(content=self.response_text))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
return self._generate(messages, stop, None, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Locally defined provider-style error: importing openai/anthropic/etc.
|
||||||
|
# would couple the test to provider SDKs the contract intentionally avoids.
|
||||||
|
class RateLimitError(Exception):
|
||||||
|
"""Mimics ``openai.RateLimitError`` for name-based eligibility."""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
|
||||||
|
"""Compile a no-tools agent with the scoped fallback wired in."""
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_agent(
|
||||||
|
model=primary,
|
||||||
|
tools=[],
|
||||||
|
middleware=[ScopedModelFallbackMiddleware(fallback)],
|
||||||
|
system_prompt="be helpful",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_errors_trigger_fallback():
|
||||||
|
"""Class names matching the provider allowlist drive the fallback chain."""
|
||||||
|
primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
|
||||||
|
fallback = _RecordingChatModel(response_text="recovered")
|
||||||
|
|
||||||
|
agent = _build_agent(primary, fallback)
|
||||||
|
result = await agent.ainvoke({"messages": [("user", "hi")]})
|
||||||
|
|
||||||
|
final = result["messages"][-1]
|
||||||
|
assert isinstance(final, AIMessage)
|
||||||
|
assert final.content == "recovered"
|
||||||
|
assert fallback.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_programming_errors_propagate_without_invoking_fallback():
|
||||||
|
"""``KeyError`` from agent-side bugs must surface immediately, no fallback retry."""
|
||||||
|
primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
|
||||||
|
fallback = _RecordingChatModel(response_text="should-never-arrive")
|
||||||
|
|
||||||
|
agent = _build_agent(primary, fallback)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="missing_state_field"):
|
||||||
|
await agent.ainvoke({"messages": [("user", "hi")]})
|
||||||
|
|
||||||
|
assert fallback.call_count == 0, (
|
||||||
|
"fallback was invoked for a programming error; "
|
||||||
|
"scoping rule is broken"
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue