mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
feat(middleware): scope model fallback to provider/network errors only
This commit is contained in:
parent
f695298d30
commit
1745d7dccf
6 changed files with 275 additions and 17 deletions
|
|
@ -14,7 +14,6 @@ from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
|||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
|
@ -56,6 +55,9 @@ from app.agents.new_chat.middleware import (
|
|||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
|
|
@ -217,15 +219,15 @@ def build_main_agent_deepagent_middleware(
|
|||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
fallback_mw = ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
|
||||
registry_subagents: list[SubAgent] = []
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from langchain.agents import create_agent
|
|||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
|
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
|
|||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
|
|
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
|
|||
# Fallback chain — primary is the agent's own model; we add cheap
|
||||
# alternatives. Off by default; only the first call site that
|
||||
# configures the chain via env should enable it.
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
fallback_mw = ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
model_call_limit_mw = (
|
||||
ModelCallLimitMiddleware(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,106 @@
|
|||
"""Fallback only on provider/network errors; let programming bugs raise.
|
||||
|
||||
Upstream :class:`langchain.agents.middleware.ModelFallbackMiddleware` catches
|
||||
every ``Exception``. With a non-provider bug (``KeyError``, ``TypeError``,
|
||||
``AttributeError`` from middleware/state), every fallback model in the chain
|
||||
hits the same bug — burning latency and tokens before the real cause finally
|
||||
surfaces. Scoping the catch to provider-style exception types lets bugs fail
|
||||
fast with clean tracebacks.
|
||||
|
||||
Class-name matching (instead of ``isinstance`` against imported provider
|
||||
types) keeps the dependency surface flat: openai, anthropic, google,
|
||||
mistral, etc. all ship their own ``RateLimitError`` and we don't want to
|
||||
import them all.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
# Rate / quota
|
||||
"RateLimitError",
|
||||
# Server-side
|
||||
"APIStatusError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"BadGatewayError",
|
||||
"GatewayTimeoutError",
|
||||
# Network
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"TimeoutError",
|
||||
"TimeoutException",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||
"""Eligible if the exception or any base in its MRO matches by class name."""
|
||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
    """``ModelFallbackMiddleware`` that re-raises non-provider exceptions."""

    def _candidate_requests(self, request: ModelRequest[Any]):
        """Yield the primary request, then one override per fallback model.

        Overrides are built lazily, so a fallback request is only
        constructed after the previous attempt has actually failed.
        """
        yield request
        for fallback_model in self.models:
            yield request.override(model=fallback_model)

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
    ) -> ModelResponse[Any] | AIMessage:
        """Run the model, advancing along the fallback chain on eligible errors.

        Exceptions that are not provider/network-style (see
        :func:`_is_fallback_eligible`) propagate immediately with their
        original traceback instead of burning through every fallback model.
        """
        failure: Exception | None = None
        for candidate in self._candidate_requests(request):
            try:
                return handler(candidate)
            except Exception as err:
                if not _is_fallback_eligible(err):
                    raise
                failure = err
        # The primary attempt always runs, so ``failure`` is set whenever we
        # reach this point without returning.
        raise failure  # type: ignore[misc]

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
    ) -> ModelResponse[Any] | AIMessage:
        """Async variant of :meth:`wrap_model_call` with identical semantics."""
        failure: Exception | None = None
        for candidate in self._candidate_requests(request):
            try:
                return await handler(candidate)
            except Exception as err:
                if not _is_fallback_eligible(err):
                    raise
                failure = err
        raise failure  # type: ignore[misc]
|
||||
Loading…
Add table
Add a link
Reference in a new issue