mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 22:32:39 +02:00
feat(multi-agent): wire model fallback and retry into subagent middleware
This commit is contained in:
parent
fa6f3015a9
commit
f695298d30
4 changed files with 129 additions and 15 deletions
|
|
@ -208,6 +208,26 @@ def build_main_agent_deepagent_middleware(
|
||||||
)
|
)
|
||||||
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
|
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
|
||||||
|
|
||||||
|
# Defined here (instead of further down with the other ``wrap_model_call``
|
||||||
|
# middlewares) so subagents share the same instances as the parent —
|
||||||
|
# otherwise a connector subagent would die on the first provider hiccup
|
||||||
|
# while the parent stays resilient.
|
||||||
|
retry_mw = (
|
||||||
|
RetryAfterMiddleware(max_retries=3)
|
||||||
|
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
fallback_mw: ModelFallbackMiddleware | None = None
|
||||||
|
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||||
|
try:
|
||||||
|
fallback_mw = ModelFallbackMiddleware(
|
||||||
|
"openai:gpt-4o-mini",
|
||||||
|
"anthropic:claude-3-5-haiku-20241022",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||||
|
fallback_mw = None
|
||||||
|
|
||||||
registry_subagents: list[SubAgent] = []
|
registry_subagents: list[SubAgent] = []
|
||||||
try:
|
try:
|
||||||
subagent_extra_middleware: list[Any] = [
|
subagent_extra_middleware: list[Any] = [
|
||||||
|
|
@ -222,6 +242,10 @@ def build_main_agent_deepagent_middleware(
|
||||||
]
|
]
|
||||||
if subagent_deny_permission_mw is not None:
|
if subagent_deny_permission_mw is not None:
|
||||||
subagent_extra_middleware.append(subagent_deny_permission_mw)
|
subagent_extra_middleware.append(subagent_deny_permission_mw)
|
||||||
|
if retry_mw is not None:
|
||||||
|
subagent_extra_middleware.append(retry_mw)
|
||||||
|
if fallback_mw is not None:
|
||||||
|
subagent_extra_middleware.append(fallback_mw)
|
||||||
registry_subagents = build_subagents(
|
registry_subagents = build_subagents(
|
||||||
dependencies=subagent_dependencies,
|
dependencies=subagent_dependencies,
|
||||||
model=llm,
|
model=llm,
|
||||||
|
|
@ -268,21 +292,6 @@ def build_main_agent_deepagent_middleware(
|
||||||
backend_resolver=backend_resolver,
|
backend_resolver=backend_resolver,
|
||||||
)
|
)
|
||||||
|
|
||||||
retry_mw = (
|
|
||||||
RetryAfterMiddleware(max_retries=3)
|
|
||||||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
fallback_mw: ModelFallbackMiddleware | None = None
|
|
||||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
|
||||||
try:
|
|
||||||
fallback_mw = ModelFallbackMiddleware(
|
|
||||||
"openai:gpt-4o-mini",
|
|
||||||
"anthropic:claude-3-5-haiku-20241022",
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
|
||||||
fallback_mw = None
|
|
||||||
model_call_limit_mw = (
|
model_call_limit_mw = (
|
||||||
ModelCallLimitMiddleware(
|
ModelCallLimitMiddleware(
|
||||||
thread_limit=120,
|
thread_limit=120,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""Resilience contract for subagents built via ``pack_subagent``.
|
||||||
|
|
||||||
|
Subagents (jira, linear, notion, ...) run on the same LLM as the parent. When
|
||||||
|
the provider rate-limits or returns an empty stream, a single hiccup must not
|
||||||
|
abort the user's HITL flow — the connector subagent has to keep moving. This
|
||||||
|
relies on ``ModelFallbackMiddleware`` being usable as a subagent
|
||||||
|
``extra_middleware`` so the production builder can wire it in.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.language_models.fake_chat_models import (
|
||||||
|
FakeMessagesListChatModel,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||||
|
pack_subagent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _AlwaysFailingChatModel(BaseChatModel):
    """Chat model that raises on every sync/async generate and stream call.

    Stands in for a provider that is hard-down (rate limit, empty stream).
    ``ModelFallbackMiddleware`` reacts to any ``Exception``, so the concrete
    error type is irrelevant to the contract exercised by this test.
    """

    @property
    def _llm_type(self) -> str:
        # Identifier only; never reaches a real provider.
        return "always-failing-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise RuntimeError("primary llm exploded")

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise RuntimeError("primary llm exploded")

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
        raise RuntimeError("primary llm exploded")

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGeneration]:
        raise RuntimeError("primary llm exploded")
        # The bare yield is unreachable but turns this coroutine into an
        # async generator, matching the ``_astream`` protocol.
        yield  # pragma: no cover - unreachable, satisfies async generator typing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
    """A dead primary model must not end the turn: the fallback finishes it.

    Wires ``ModelFallbackMiddleware`` in through ``pack_subagent``'s
    ``extra_middleware`` — the same slot the production builder uses — and
    asserts the fallback model's reply becomes the final message.
    """
    broken_primary = _AlwaysFailingChatModel()
    recovery_model = FakeMessagesListChatModel(
        responses=[AIMessage(content="recovered via fallback")]
    )

    subagent_spec = pack_subagent(
        name="resilience_test",
        description="test subagent",
        system_prompt="be helpful",
        tools=[],
        model=broken_primary,
        extra_middleware=[ModelFallbackMiddleware(recovery_model)],
    )

    agent = create_agent(
        model=subagent_spec["model"],
        tools=subagent_spec["tools"],
        middleware=subagent_spec["middleware"],
        system_prompt=subagent_spec["system_prompt"],
    )

    turn = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})

    last_message = turn["messages"][-1]
    assert isinstance(last_message, AIMessage)
    assert last_message.content == "recovered via fallback"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue