diff --git a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py
index 74e47cfab..8b7e3d0b0 100644
--- a/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py
+++ b/surfsense_backend/app/agents/multi_agent_chat/main_agent/graph/middleware/deepagent_stack.py
@@ -208,6 +208,26 @@ def build_main_agent_deepagent_middleware(
         )
         gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
 
+    # Defined here (instead of further down with the other ``wrap_model_call``
+    # middlewares) so subagents share the same instances as the parent —
+    # otherwise a connector subagent would die on the first provider hiccup
+    # while the parent stays resilient.
+    retry_mw = (
+        RetryAfterMiddleware(max_retries=3)
+        if flags.enable_retry_after and not flags.disable_new_agent_stack
+        else None
+    )
+    fallback_mw: ModelFallbackMiddleware | None = None
+    if flags.enable_model_fallback and not flags.disable_new_agent_stack:
+        try:
+            fallback_mw = ModelFallbackMiddleware(
+                "openai:gpt-4o-mini",
+                "anthropic:claude-3-5-haiku-20241022",
+            )
+        except Exception:
+            logging.warning("ModelFallbackMiddleware init failed; skipping.")
+            fallback_mw = None
+
     registry_subagents: list[SubAgent] = []
     try:
         subagent_extra_middleware: list[Any] = [
@@ -222,6 +242,10 @@ def build_main_agent_deepagent_middleware(
         ]
         if subagent_deny_permission_mw is not None:
             subagent_extra_middleware.append(subagent_deny_permission_mw)
+        if retry_mw is not None:
+            subagent_extra_middleware.append(retry_mw)
+        if fallback_mw is not None:
+            subagent_extra_middleware.append(fallback_mw)
         registry_subagents = build_subagents(
             dependencies=subagent_dependencies,
             model=llm,
@@ -268,21 +292,6 @@ def build_main_agent_deepagent_middleware(
         backend_resolver=backend_resolver,
     )
 
-    retry_mw = (
-        RetryAfterMiddleware(max_retries=3)
-        if flags.enable_retry_after and not flags.disable_new_agent_stack
-        else None
-    )
-    fallback_mw: ModelFallbackMiddleware | None = None
-    if flags.enable_model_fallback and not flags.disable_new_agent_stack:
-        try:
-            fallback_mw = ModelFallbackMiddleware(
-                "openai:gpt-4o-mini",
-                "anthropic:claude-3-5-haiku-20241022",
-            )
-        except Exception:
-            logging.warning("ModelFallbackMiddleware init failed; skipping.")
-            fallback_mw = None
     model_call_limit_mw = (
         ModelCallLimitMiddleware(
             thread_limit=120,
diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/__init__.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py
new file mode 100644
index 000000000..82f66891a
--- /dev/null
+++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/subagents/shared/test_subagent_builder.py
@@ -0,0 +1,142 @@
+"""Resilience contract for subagents built via ``pack_subagent``.
+
+Subagents (jira, linear, notion, ...) run on the same LLM as the parent. When
+the provider rate-limits or returns an empty stream, a single hiccup must not
+abort the user's HITL flow — the connector subagent has to keep moving. This
+relies on ``ModelFallbackMiddleware`` being usable as a subagent
+``extra_middleware`` so the production builder can wire it in.
+"""
+
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Iterator
+from typing import Any
+
+import pytest
+from langchain.agents import create_agent
+from langchain.agents.middleware import ModelFallbackMiddleware
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.fake_chat_models import (
+    FakeMessagesListChatModel,
+)
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
+
+from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
+    pack_subagent,
+)
+
+
+class _AlwaysFailingChatModel(BaseChatModel):
+    """Mimics a provider hard-failing on every call (rate limit / empty stream).
+
+    ``ModelFallbackMiddleware`` triggers on any ``Exception``, so the exact
+    error type doesn't matter for the contract under test.
+    """
+
+    @property
+    def _llm_type(self) -> str:
+        return "always-failing-test-model"
+
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        msg = "primary llm exploded"
+        raise RuntimeError(msg)
+
+    async def _agenerate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        msg = "primary llm exploded"
+        raise RuntimeError(msg)
+
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        msg = "primary llm exploded"
+        raise RuntimeError(msg)
+
+    async def _astream(
+        self, *args: Any, **kwargs: Any
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        msg = "primary llm exploded"
+        raise RuntimeError(msg)
+        yield  # pragma: no cover - unreachable, satisfies async generator typing
+
+
+@pytest.mark.asyncio
+async def test_subagent_recovers_when_primary_llm_fails():
+    """Primary blows up → fallback in extra_middleware finishes the turn."""
+    primary = _AlwaysFailingChatModel()
+    fallback = FakeMessagesListChatModel(
+        responses=[AIMessage(content="recovered via fallback")]
+    )
+
+    spec = pack_subagent(
+        name="resilience_test",
+        description="test subagent",
+        system_prompt="be helpful",
+        tools=[],
+        model=primary,
+        extra_middleware=[ModelFallbackMiddleware(fallback)],
+    )
+
+    agent = create_agent(
+        model=spec["model"],
+        tools=spec["tools"],
+        middleware=spec["middleware"],
+        system_prompt=spec["system_prompt"],
+    )
+
+    result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
+
+    final = result["messages"][-1]
+    assert isinstance(final, AIMessage)
+    assert final.content == "recovered via fallback"
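+
+
+# Companion sketch, not part of the original resilience contract: a
+# happy-path check that the fallback stays idle when the primary answers
+# cleanly. Reuses only the doubles imported above; the subagent name and
+# canned response strings are illustrative.
+@pytest.mark.asyncio
+async def test_subagent_prefers_primary_when_it_succeeds():
+    """Healthy primary answers; the fallback must not shadow it."""
+    primary = FakeMessagesListChatModel(
+        responses=[AIMessage(content="answered by primary")]
+    )
+    fallback = FakeMessagesListChatModel(
+        responses=[AIMessage(content="should never be reached")]
+    )
+
+    spec = pack_subagent(
+        name="happy_path_test",
+        description="test subagent",
+        system_prompt="be helpful",
+        tools=[],
+        model=primary,
+        extra_middleware=[ModelFallbackMiddleware(fallback)],
+    )
+
+    agent = create_agent(
+        model=spec["model"],
+        tools=spec["tools"],
+        middleware=spec["middleware"],
+        system_prompt=spec["system_prompt"],
+    )
+
+    result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
+
+    final = result["messages"][-1]
+    assert isinstance(final, AIMessage)
+    assert final.content == "answered by primary"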