Merge pull request #1351 from CREDO23/feature/multi-agent

[Improvement] Modular middleware stack + agent/prompt caching + subagent resilience + unit tests
This commit is contained in:
Rohan Verma 2026-05-05 16:21:48 -07:00 committed by GitHub
commit a4fc812b85
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
70 changed files with 2037 additions and 547 deletions

View file

@@ -2,6 +2,6 @@
from __future__ import annotations
from .main_agent import create_surfsense_deep_agent
from .main_agent import create_multi_agent_chat_deep_agent
__all__ = ["create_surfsense_deep_agent"]
__all__ = ["create_multi_agent_chat_deep_agent"]

View file

@@ -2,6 +2,6 @@
from __future__ import annotations
from .runtime import create_surfsense_deep_agent
from .runtime import create_multi_agent_chat_deep_agent
__all__ = ["create_surfsense_deep_agent"]
__all__ = ["create_multi_agent_chat_deep_agent"]

View file

@@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.middleware import (
build_main_agent_deepagent_middleware,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions,
)
@@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from .middleware import build_main_agent_deepagent_middleware
def build_compiled_agent_graph_sync(
*,

View file

@@ -1,7 +0,0 @@
"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents)."""
from __future__ import annotations
from .deepagent_stack import build_main_agent_deepagent_middleware
__all__ = ["build_main_agent_deepagent_middleware"]

View file

@@ -1,506 +0,0 @@
"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents import SubAgent
from deepagents.backends import StateBackend
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.skills import SkillsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from langchain.agents.middleware import (
LLMToolSelectorMiddleware,
ModelCallLimitMiddleware,
ModelFallbackMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
ToolsPermissions,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import (
ActionLogMiddleware,
AnonymousDocumentMiddleware,
BusyMutexMiddleware,
ClearToolUsesEdit,
DedupHITLToolCallsMiddleware,
DoomLoopMiddleware,
FileIntentMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
MemoryInjectionMiddleware,
NoopInjectionMiddleware,
OtelSpanMiddleware,
PermissionMiddleware,
RetryAfterMiddleware,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
SurfSenseFilesystemMiddleware,
ToolCallNameRepairMiddleware,
build_skills_backend_factory,
create_surfsense_compaction_middleware,
default_skills_sources,
)
from app.agents.new_chat.permissions import Rule, Ruleset
from app.agents.new_chat.plugin_loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from app.db import ChatVisibility
from ...context_prune.prune_tool_names import safe_exclude_tools
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
def build_main_agent_deepagent_middleware(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
subagent_dependencies: dict[str, Any],
checkpointer: Checkpointer,
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Build ordered middleware for ``create_agent`` (Nones already stripped)."""
_memory_middleware = MemoryInjectionMiddleware(
user_id=user_id,
search_space_id=search_space_id,
thread_visibility=visibility,
)
gp_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
create_surfsense_compaction_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
# Build permission rulesets up front so the GP subagent can mirror ``ask``
# rules into ``interrupt_on``: tool calls emitted from within ``task`` runs
# never reach the parent's ``PermissionMiddleware``.
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
permission_rulesets: list[Ruleset] = []
if permission_enabled or is_desktop_fs:
permission_rulesets.append(
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
)
)
if is_desktop_fs:
permission_rulesets.append(
Ruleset(
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
origin="desktop_safety",
)
)
# Tools that self-prompt via ``request_approval`` must not also appear
# as ``ask`` rules — that would double-prompt the user for one call.
_tool_names_in_use = {t.name for t in tools}
# Deny parent-bound tools whose ``required_connector`` is missing.
# No-op today (connector subagents are pruned upstream); guards future
# additions to the parent's tool list.
if permission_enabled:
_available_set = set(available_connectors or [])
_synthesized: list[Rule] = []
for tool_def in BUILTIN_TOOLS:
if tool_def.name not in _tool_names_in_use:
continue
rc = tool_def.required_connector
if rc and rc not in _available_set:
_synthesized.append(
Rule(permission=tool_def.name, pattern="*", action="deny")
)
if _synthesized:
permission_rulesets.append(
Ruleset(rules=_synthesized, origin="connector_synthesized")
)
gp_interrupt_on: dict[str, bool] = {
rule.permission: True
for rs in permission_rulesets
for rule in rs.rules
if rule.action == "ask" and rule.permission in _tool_names_in_use
}
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
**GENERAL_PURPOSE_SUBAGENT,
"model": llm,
"tools": tools,
"middleware": gp_middleware,
}
if gp_interrupt_on:
general_purpose_spec["interrupt_on"] = gp_interrupt_on
# Deny-only on subagents: ``task`` runs bypass the parent's
# PermissionMiddleware, while bucket-based ask gates own the ask path.
subagent_deny_rulesets: list[Ruleset] = [
Ruleset(
rules=[r for r in rs.rules if r.action == "deny"],
origin=rs.origin,
)
for rs in permission_rulesets
]
subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules]
subagent_deny_permission_mw: PermissionMiddleware | None = (
PermissionMiddleware(rulesets=subagent_deny_rulesets)
if subagent_deny_rulesets
else None
)
if subagent_deny_permission_mw is not None:
# Run deny check on already-repaired tool calls; insert before
# PatchToolCallsMiddleware (append if the slot moves).
_patch_idx = next(
(
i
for i, m in enumerate(gp_middleware)
if isinstance(m, PatchToolCallsMiddleware)
),
len(gp_middleware),
)
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
registry_subagents: list[SubAgent] = []
try:
subagent_extra_middleware: list[Any] = [
TodoListMiddleware(),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
]
if subagent_deny_permission_mw is not None:
subagent_extra_middleware.append(subagent_deny_permission_mw)
registry_subagents = build_subagents(
dependencies=subagent_dependencies,
model=llm,
extra_middleware=subagent_extra_middleware,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
disabled_tools=disabled_tools,
)
logging.info(
"Registry subagents: %s",
[s["name"] for s in registry_subagents],
)
except Exception:
logging.exception("Registry subagent build failed")
raise
subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents]
summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
context_edit_mw = None
if (
flags.enable_context_editing
and not flags.disable_new_agent_stack
and max_input_tokens
):
spill_edit = SpillToBackendEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
)
clear_edit = ClearToolUsesEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
placeholder="[cleared - older tool output trimmed for context]",
)
context_edit_mw = SpillingContextEditingMiddleware(
edits=[spill_edit, clear_edit],
backend_resolver=backend_resolver,
)
retry_mw = (
RetryAfterMiddleware(max_retries=3)
if flags.enable_retry_after and not flags.disable_new_agent_stack
else None
)
fallback_mw: ModelFallbackMiddleware | None = None
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
try:
fallback_mw = ModelFallbackMiddleware(
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
)
except Exception:
logging.warning("ModelFallbackMiddleware init failed; skipping.")
fallback_mw = None
model_call_limit_mw = (
ModelCallLimitMiddleware(
thread_limit=120,
run_limit=80,
exit_behavior="end",
)
if flags.enable_model_call_limit and not flags.disable_new_agent_stack
else None
)
tool_call_limit_mw = (
ToolCallLimitMiddleware(
thread_limit=300, run_limit=80, exit_behavior="continue"
)
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
else None
)
noop_mw = (
NoopInjectionMiddleware()
if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
else None
)
repair_mw = None
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
registered_names: set[str] = {t.name for t in tools}
registered_names |= {
"write_todos",
"ls",
"read_file",
"write_file",
"edit_file",
"glob",
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
repair_mw = ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
fuzzy_match_threshold=None,
)
doom_loop_mw = (
DoomLoopMiddleware(threshold=3)
if flags.enable_doom_loop and not flags.disable_new_agent_stack
else None
)
permission_mw: PermissionMiddleware | None = (
PermissionMiddleware(rulesets=permission_rulesets)
if permission_rulesets
else None
)
action_log_mw: ActionLogMiddleware | None = None
if (
flags.enable_action_log
and not flags.disable_new_agent_stack
and thread_id is not None
):
try:
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
action_log_mw = ActionLogMiddleware(
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
tool_definitions=tool_defs_by_name,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"ActionLogMiddleware init failed; running without it.",
exc_info=True,
)
action_log_mw = None
busy_mutex_mw: BusyMutexMiddleware | None = (
BusyMutexMiddleware()
if flags.enable_busy_mutex and not flags.disable_new_agent_stack
else None
)
otel_mw: OtelSpanMiddleware | None = (
OtelSpanMiddleware()
if flags.enable_otel and not flags.disable_new_agent_stack
else None
)
plugin_middlewares: list[Any] = []
if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
try:
allowed_names = load_allowed_plugin_names_from_env()
if allowed_names:
plugin_middlewares = load_plugin_middlewares(
PluginContext.build(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=visibility,
llm=llm,
),
allowed_plugin_names=allowed_names,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"Plugin loader failed; continuing without plugins.",
exc_info=True,
)
plugin_middlewares = []
skills_mw: SkillsMiddleware | None = None
if flags.enable_skills and not flags.disable_new_agent_stack:
try:
skills_factory = build_skills_backend_factory(
search_space_id=search_space_id
if filesystem_mode == FilesystemMode.CLOUD
else None,
)
skills_mw = SkillsMiddleware(
backend=skills_factory,
sources=default_skills_sources(),
)
except Exception as exc: # pragma: no cover - defensive
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
skills_mw = None
selector_mw: LLMToolSelectorMiddleware | None = None
if (
flags.enable_llm_tool_selector
and not flags.disable_new_agent_stack
and len(tools) > 30
):
try:
selector_mw = LLMToolSelectorMiddleware(
model="openai:gpt-4o-mini",
max_tools=12,
always_include=[
name
for name in (
"update_memory",
"get_connected_accounts",
"scrape_webpage",
)
if name in {t.name for t in tools}
],
)
except Exception:
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
selector_mw = None
deepagent_middleware = [
busy_mutex_mw,
otel_mw,
TodoListMiddleware(),
_memory_middleware,
AnonymousDocumentMiddleware(
anon_session_id=anon_session_id,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
llm=llm,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
KnowledgePriorityMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
),
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
),
KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
skills_mw,
SurfSenseCheckpointedSubAgentMiddleware(
checkpointer=checkpointer,
backend=StateBackend,
subagents=subagent_specs,
),
selector_mw,
model_call_limit_mw,
tool_call_limit_mw,
context_edit_mw,
summarization_mw,
noop_mw,
retry_mw,
fallback_mw,
repair_mw,
permission_mw,
doom_loop_mw,
action_log_mw,
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
*plugin_middlewares,
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
return [m for m in deepagent_middleware if m is not None]

View file

@@ -2,6 +2,6 @@
from __future__ import annotations
from .factory import create_surfsense_deep_agent
from .factory import create_multi_agent_chat_deep_agent
__all__ = ["create_surfsense_deep_agent"]
__all__ = ["create_multi_agent_chat_deep_agent"]

View file

@@ -0,0 +1,117 @@
"""Compiled agent graph caching for the multi-agent path."""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
from app.agents.new_chat.agent_cache import (
flags_signature,
get_cache,
stable_hash,
system_prompt_hash,
tools_signature,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
rows = []
for agent_name in sorted(mcp_tools_by_agent.keys()):
perms = mcp_tools_by_agent[agent_name]
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
rows.append((agent_name, allow_names, ask_names))
return stable_hash(rows)
async def build_agent_with_cache(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
final_system_prompt: str,
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str],
available_document_types: list[str],
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
checkpointer: Checkpointer,
subagent_dependencies: dict[str, Any],
mcp_tools_by_agent: dict[str, ToolsPermissions],
disabled_tools: list[str] | None,
config_id: str | None,
) -> Any:
"""Compile the multi-agent graph, serving from cache when key components are stable."""
async def _build() -> Any:
return await asyncio.to_thread(
build_compiled_agent_graph_sync,
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
visibility=visibility,
anon_session_id=anon_session_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=max_input_tokens,
flags=flags,
checkpointer=checkpointer,
subagent_dependencies=subagent_dependencies,
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
)
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
return await _build()
# Every per-request value any middleware closes over at __init__ must be in
# the key; otherwise a hit will leak state across threads. Bump the schema
# version when the component list changes shape.
cache_key = stable_hash(
"multi-agent-v1",
config_id,
thread_id,
user_id,
search_space_id,
visibility,
filesystem_mode,
anon_session_id,
tools_signature(
tools,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
mcp_signature(mcp_tools_by_agent),
flags_signature(flags),
system_prompt_hash(final_system_prompt),
max_input_tokens,
sorted(disabled_tools) if disabled_tools else None,
)
return await get_cache().get_or_build(cache_key, builder=_build)
__all__ = ["build_agent_with_cache", "mcp_signature"]

View file

@@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Sequence
@@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
from ..system_prompt import build_main_agent_system_prompt
from ..tools import (
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
)
from .agent_cache import build_agent_with_cache
_perf_log = get_perf_logger()
async def create_surfsense_deep_agent(
async def create_multi_agent_chat_deep_agent(
llm: BaseChatModel,
search_space_id: int,
db_session: AsyncSession,
@@ -62,6 +62,9 @@ async def create_surfsense_deep_agent(
):
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
_t_agent_total = time.perf_counter()
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(
filesystem_selection,
@@ -85,7 +88,18 @@ )
)
except Exception as e:
logging.warning("Failed to discover available connectors/document types: %s", e)
logging.warning(
"Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
e,
)
# Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
# nothing", which would silently advertise every connector specialist on a flaky
# discovery call. Empty list excludes connector-gated subagents while keeping builtins.
if available_connectors is None:
available_connectors = []
if available_document_types is None:
available_document_types = []
_perf_log.info(
"[create_agent] Connector/doc-type discovery in %.3fs",
time.perf_counter() - _t0,
@@ -115,7 +129,16 @@ }
}
_t0 = time.perf_counter()
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
try:
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
except Exception as e:
# Degrade to builtins-only rather than aborting the turn: a transient
# DB or MCP-server hiccup should not deny the user a response.
logging.warning(
"MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
e,
)
mcp_tools_by_agent = {}
_perf_log.info(
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
time.perf_counter() - _t0,
@@ -195,9 +218,10 @@ async def create_surfsense_deep_agent(
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
config_id = agent_config.config_id if agent_config is not None else None
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
build_compiled_agent_graph_sync,
agent = await build_agent_with_cache(
llm=llm,
tools=tools,
final_system_prompt=final_system_prompt,
@@ -217,6 +241,7 @@ subagent_dependencies=dependencies,
subagent_dependencies=dependencies,
mcp_tools_by_agent=mcp_tools_by_agent,
disabled_tools=disabled_tools,
config_id=config_id,
)
_perf_log.info(
"[create_agent] Middleware stack + graph compiled in %.3fs",

View file

@@ -0,0 +1,7 @@
"""Multi-agent middleware stack assembly."""
from __future__ import annotations
from .stack import build_main_agent_deepagent_middleware
__all__ = ["build_main_agent_deepagent_middleware"]

View file

@@ -0,0 +1,36 @@
"""Audit row per tool call (reversibility metadata)."""
from __future__ import annotations
import logging
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import ActionLogMiddleware
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from ..shared.flags import enabled
def build_action_log_mw(
*,
flags: AgentFeatureFlags,
thread_id: int | None,
search_space_id: int,
user_id: str | None,
) -> ActionLogMiddleware | None:
if not enabled(flags, "enable_action_log") or thread_id is None:
return None
try:
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
return ActionLogMiddleware(
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
tool_definitions=tool_defs_by_name,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"ActionLogMiddleware init failed; running without it.",
exc_info=True,
)
return None

View file

@@ -0,0 +1,16 @@
"""Anonymous document hydration from Redis (cloud only)."""
from __future__ import annotations
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
def build_anonymous_doc_mw(
*,
filesystem_mode: FilesystemMode,
anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)

View file

@@ -0,0 +1,12 @@
"""Per-thread cooperative lock around the whole turn."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import BusyMutexMiddleware
from ..shared.flags import enabled
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None

View file

@@ -69,9 +69,16 @@ def build_task_tool_with_parent_config(
raise ValueError(msg)
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
message_text = (
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
)
messages = result["messages"]
if not messages:
msg = (
"CompiledSubAgent returned an empty 'messages' list. "
"Subagents must produce at least one message so the parent has "
"output to forward back to the user."
)
raise ValueError(msg)
last_text = getattr(messages[-1], "text", None) or ""
message_text = last_text.rstrip()
return Command(
update={
**state_update,

View file

@@ -0,0 +1,50 @@
"""Spill + clear-tool-uses passes to keep payloads under budget."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from langchain_core.tools import BaseTool
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
safe_exclude_tools,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import (
ClearToolUsesEdit,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
)
from ..shared.flags import enabled
def build_context_editing_mw(
*,
flags: AgentFeatureFlags,
max_input_tokens: int | None,
tools: Sequence[BaseTool],
backend_resolver: Any,
) -> SpillingContextEditingMiddleware | None:
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
return None
spill_edit = SpillToBackendEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
)
clear_edit = ClearToolUsesEdit(
trigger=int(max_input_tokens * 0.55),
clear_at_least=int(max_input_tokens * 0.15),
keep=5,
exclude_tools=safe_exclude_tools(tools),
clear_tool_inputs=True,
placeholder="[cleared - older tool output trimmed for context]",
)
return SpillingContextEditingMiddleware(
edits=[spill_edit, clear_edit],
backend_resolver=backend_resolver,
)
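
The budget arithmetic behind those fractions, worked for an illustrative 200k-token window (the window size is an assumption; the 0.55/0.15 split comes from the code above):

```python
max_input_tokens = 200_000  # illustrative window size
trigger = int(max_input_tokens * 0.55)         # 110_000: editing kicks in here
clear_at_least = int(max_input_tokens * 0.15)  # 30_000: minimum tokens reclaimed
assert (trigger, clear_at_least) == (110_000, 30_000)
# Both edits share the trigger and presumably run in list order: the spill
# pass moves large tool outputs to the backend first, then the clear pass
# trims what remains, always keeping the 5 newest tool results (keep=5).
```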

View file

@@ -0,0 +1,13 @@
"""Drop duplicate HITL tool calls before execution."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))

View file

@@ -0,0 +1,12 @@
"""Stop N identical tool calls in a row via interrupt."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import DoomLoopMiddleware
from ..shared.flags import enabled
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
return DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None

View file

@@ -0,0 +1,23 @@
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
from __future__ import annotations
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
def build_kb_persistence_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
) -> KnowledgeBasePersistenceMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)

View file

@@ -0,0 +1,27 @@
"""KB priority planner: <priority_documents> injection."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
def build_knowledge_priority_mw(
*,
llm: BaseChatModel,
search_space_id: int,
filesystem_mode: FilesystemMode,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
) -> KnowledgePriorityMiddleware:
return KnowledgePriorityMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
)

View file

@@ -0,0 +1,23 @@
"""<workspace_tree> injection (cloud only)."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
def build_knowledge_tree_mw(
*,
filesystem_mode: FilesystemMode,
search_space_id: int,
llm: BaseChatModel,
) -> KnowledgeTreeMiddleware | None:
if filesystem_mode != FilesystemMode.CLOUD:
return None
return KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
llm=llm,
)

View file

@@ -0,0 +1,12 @@
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import NoopInjectionMiddleware
from ..shared.flags import enabled
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None

View file

@@ -0,0 +1,12 @@
"""OTel spans on model and tool calls."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import OtelSpanMiddleware
from ..shared.flags import enabled
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None

View file

@@ -0,0 +1,49 @@
"""Tail-of-stack plugin slot driven by env allowlist."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.plugin_loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.db import ChatVisibility
from ..shared.flags import enabled
def build_plugin_middlewares(
*,
flags: AgentFeatureFlags,
search_space_id: int,
user_id: str | None,
visibility: ChatVisibility,
llm: BaseChatModel,
) -> list[Any]:
if not enabled(flags, "enable_plugin_loader"):
return []
try:
allowed_names = load_allowed_plugin_names_from_env()
if not allowed_names:
return []
return load_plugin_middlewares(
PluginContext.build(
search_space_id=search_space_id,
user_id=user_id,
thread_visibility=visibility,
llm=llm,
),
allowed_plugin_names=allowed_names,
)
except Exception: # pragma: no cover - defensive
logging.warning(
"Plugin loader failed; continuing without plugins.",
exc_info=True,
)
return []

View file

@@ -0,0 +1,50 @@
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
from __future__ import annotations
from collections.abc import Sequence
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
from ..shared.flags import enabled
# deepagents-built-in tool names the repair pass treats as known.
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
{
"write_todos",
"ls",
"read_file",
"write_file",
"edit_file",
"glob",
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
)
def build_repair_mw(
*,
flags: AgentFeatureFlags,
tools: Sequence[BaseTool],
) -> ToolCallNameRepairMiddleware | None:
if not enabled(flags, "enable_tool_call_repair"):
return None
registered_names: set[str] = {t.name for t in tools}
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
return ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
fuzzy_match_threshold=None,
)

View file

@@ -0,0 +1,39 @@
"""LLM-based tool subset selection (only when >30 tools)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from ..shared.flags import enabled
def build_selector_mw(
*,
flags: AgentFeatureFlags,
tools: Sequence[BaseTool],
) -> LLMToolSelectorMiddleware | None:
if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
return None
try:
return LLMToolSelectorMiddleware(
model="openai:gpt-4o-mini",
max_tools=12,
always_include=[
name
for name in (
"update_memory",
"get_connected_accounts",
"scrape_webpage",
)
if name in {t.name for t in tools}
],
)
except Exception:
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
return None

View file

@@ -0,0 +1,39 @@
"""Skill discovery + injection."""
from __future__ import annotations
import logging
from deepagents.middleware.skills import SkillsMiddleware
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import (
build_skills_backend_factory,
default_skills_sources,
)
from ..shared.flags import enabled
def build_skills_mw(
*,
flags: AgentFeatureFlags,
filesystem_mode: FilesystemMode,
search_space_id: int,
) -> SkillsMiddleware | None:
if not enabled(flags, "enable_skills"):
return None
try:
skills_factory = build_skills_backend_factory(
search_space_id=search_space_id
if filesystem_mode == FilesystemMode.CLOUD
else None,
)
return SkillsMiddleware(
backend=skills_factory,
sources=default_skills_sources(),
)
except Exception as exc: # pragma: no cover - defensive
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
return None

View file

@@ -0,0 +1,9 @@
"""Anthropic prompt caching annotations on system/tool/message blocks."""
from __future__ import annotations
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware:
return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")

View file

@@ -0,0 +1,14 @@
"""Context-window summarization with SurfSense protected sections."""
from __future__ import annotations
from typing import Any
from deepagents.backends import StateBackend
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.middleware import create_surfsense_compaction_middleware
def build_compaction_mw(llm: BaseChatModel) -> Any:
return create_surfsense_compaction_middleware(llm, StateBackend)

View file

@@ -0,0 +1,11 @@
"""File-intent classifier that gates strict write contracts."""
from __future__ import annotations
from langchain_core.language_models import BaseChatModel
from app.agents.new_chat.middleware import FileIntentMiddleware
def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware:
return FileIntentMiddleware(llm=llm)

View file

@@ -0,0 +1,25 @@
"""SurfSense filesystem tools/middleware."""
from __future__ import annotations
from typing import Any
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware
def build_filesystem_mw(
*,
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
) -> SurfSenseFilesystemMiddleware:
return SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
created_by_id=user_id,
thread_id=thread_id,
)

View file

@@ -0,0 +1,10 @@
"""Single source of truth for the feature-flag predicate."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
"""``flags.<attr>`` is on AND the new-agent-stack kill switch is off."""
return getattr(flags, attr) and not flags.disable_new_agent_stack
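
A quick truth-table check of the predicate, with a stand-in flags object (the real `AgentFeatureFlags` carries many more fields):

```python
from dataclasses import dataclass


@dataclass
class FakeFlags:  # stand-in for AgentFeatureFlags
    enable_otel: bool = True
    disable_new_agent_stack: bool = False


def enabled(flags, attr: str) -> bool:
    return getattr(flags, attr) and not flags.disable_new_agent_stack


assert enabled(FakeFlags(), "enable_otel")
# The kill switch beats every individual flag:
assert not enabled(FakeFlags(disable_new_agent_stack=True), "enable_otel")
assert not enabled(FakeFlags(enable_otel=False), "enable_otel")
```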

View file

@@ -0,0 +1,19 @@
"""User/team memory injection prepended to the conversation."""
from __future__ import annotations
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
from app.db import ChatVisibility
def build_memory_mw(
*,
user_id: str | None,
search_space_id: int,
visibility: ChatVisibility,
) -> MemoryInjectionMiddleware:
return MemoryInjectionMiddleware(
user_id=user_id,
search_space_id=search_space_id,
thread_visibility=visibility,
)

View file

@@ -0,0 +1,9 @@
"""Repair dangling tool-call sequences before each agent turn."""
from __future__ import annotations
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware:
return PatchToolCallsMiddleware()

View file

@@ -0,0 +1,12 @@
"""Permission rulesets fanned out to parent / general-purpose / subagent stacks."""
from __future__ import annotations
from .context import PermissionContext, build_permission_context
from .middleware import build_full_permission_mw
__all__ = [
"PermissionContext",
"build_full_permission_mw",
"build_permission_context",
]

View file

@@ -0,0 +1,109 @@
"""Derive shared permission context once; fan out to all three stack layers.
The context carries:
- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware.
- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents
``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents
bypass the main-agent permission middleware).
- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared
across the general-purpose and registry subagent stacks.
"""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import PermissionMiddleware
from app.agents.new_chat.permissions import Rule, Ruleset
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from ..flags import enabled
@dataclass(frozen=True)
class PermissionContext:
rulesets: list[Ruleset]
general_purpose_interrupt_on: dict[str, bool]
subagent_deny_mw: PermissionMiddleware | None
def build_permission_context(
*,
flags: AgentFeatureFlags,
filesystem_mode: FilesystemMode,
tools: Sequence[BaseTool],
available_connectors: list[str] | None,
) -> PermissionContext:
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
permission_enabled = enabled(flags, "enable_permission")
rulesets: list[Ruleset] = []
if permission_enabled or is_desktop_fs:
rulesets.append(
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
)
)
if is_desktop_fs:
rulesets.append(
Ruleset(
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
origin="desktop_safety",
)
)
tool_names_in_use = {t.name for t in tools}
if permission_enabled:
available_set = set(available_connectors or [])
synthesized: list[Rule] = []
for tool_def in BUILTIN_TOOLS:
if tool_def.name not in tool_names_in_use:
continue
rc = tool_def.required_connector
if rc and rc not in available_set:
synthesized.append(
Rule(permission=tool_def.name, pattern="*", action="deny")
)
if synthesized:
rulesets.append(
Ruleset(rules=synthesized, origin="connector_synthesized")
)
general_purpose_interrupt_on: dict[str, bool] = {
rule.permission: True
for rs in rulesets
for rule in rs.rules
if rule.action == "ask" and rule.permission in tool_names_in_use
}
deny_rulesets = [
Ruleset(
rules=[r for r in rs.rules if r.action == "deny"],
origin=rs.origin,
)
for rs in rulesets
]
deny_rulesets = [rs for rs in deny_rulesets if rs.rules]
subagent_deny_mw: PermissionMiddleware | None = (
PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None
)
return PermissionContext(
rulesets=rulesets,
general_purpose_interrupt_on=general_purpose_interrupt_on,
subagent_deny_mw=subagent_deny_mw,
)
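
How one set of rules fans out to the three consumers, sketched with a stand-in `Rule` (the real `Rule`/`Ruleset` live in `app.agents.new_chat.permissions`; the connector tool name is invented):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Rule:  # stand-in shape only
    permission: str
    pattern: str
    action: str  # "allow" | "ask" | "deny"


rules = [
    Rule("*", "*", "allow"),             # surfsense_defaults
    Rule("rm", "*", "ask"),              # desktop_safety
    Rule("linear_search", "*", "deny"),  # connector_synthesized (invented name)
]
tool_names_in_use = {"rm", "read_file", "linear_search"}

# ask -> interrupt_on for the general-purpose subagent, since task-run tool
# calls never reach the parent's PermissionMiddleware:
interrupt_on = {
    r.permission: True
    for r in rules
    if r.action == "ask" and r.permission in tool_names_in_use
}
assert interrupt_on == {"rm": True}

# Subagents keep only deny rules; allow/ask stay on the parent stack:
deny_only = [r for r in rules if r.action == "deny"]
assert [r.permission for r in deny_only] == ["linear_search"]
```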

View file

@@ -0,0 +1,10 @@
"""Main-agent permission middleware (full ask/deny/allow rules)."""
from __future__ import annotations
from app.agents.new_chat.middleware import PermissionMiddleware
from app.agents.new_chat.permissions import Ruleset
def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None:
return PermissionMiddleware(rulesets=rulesets) if rulesets else None

View file

@@ -0,0 +1,7 @@
"""Resilience middleware shared as the same instances across parent / general-purpose / registry."""
from __future__ import annotations
from .bundle import ResilienceBundle, build_resilience_bundle
__all__ = ["ResilienceBundle", "build_resilience_bundle"]

View file

@@ -0,0 +1,51 @@
"""Construct each resilience middleware once; same instances flow into every consumer."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from langchain.agents.middleware import (
ModelCallLimitMiddleware,
ToolCallLimitMiddleware,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import RetryAfterMiddleware
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
from .fallback import build_fallback_mw
from .model_call_limit import build_model_call_limit_mw
from .retry import build_retry_mw
from .tool_call_limit import build_tool_call_limit_mw
@dataclass(frozen=True)
class ResilienceBundle:
retry: RetryAfterMiddleware | None
fallback: ScopedModelFallbackMiddleware | None
model_call_limit: ModelCallLimitMiddleware | None
tool_call_limit: ToolCallLimitMiddleware | None
def as_list(self) -> list[Any]:
return [
m
for m in (
self.retry,
self.fallback,
self.model_call_limit,
self.tool_call_limit,
)
if m is not None
]
def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle:
return ResilienceBundle(
retry=build_retry_mw(flags),
fallback=build_fallback_mw(flags),
model_call_limit=build_model_call_limit_mw(flags),
tool_call_limit=build_tool_call_limit_mw(flags),
)
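
What the bundle buys, shown with stand-in values: `as_list()` drops disabled slots, and every consumer receives the same instances by identity rather than fresh copies:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Bundle:  # stand-in mirroring ResilienceBundle's shape
    retry: Any | None
    fallback: Any | None
    model_call_limit: Any | None
    tool_call_limit: Any | None

    def as_list(self) -> list[Any]:
        return [
            m
            for m in (self.retry, self.fallback,
                      self.model_call_limit, self.tool_call_limit)
            if m is not None
        ]


bundle = Bundle(retry=object(), fallback=None,
                model_call_limit=object(), tool_call_limit=None)
assert len(bundle.as_list()) == 2  # disabled (None) slots vanish

# Parent stack and subagent extras share the same objects, not copies:
assert bundle.as_list()[0] is bundle.as_list()[0]
```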

View file

@@ -0,0 +1,27 @@
"""Switch to a fallback model on provider/network errors only."""
from __future__ import annotations
import logging
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
from ..flags import enabled
def build_fallback_mw(
flags: AgentFeatureFlags,
) -> ScopedModelFallbackMiddleware | None:
if not enabled(flags, "enable_model_fallback"):
return None
try:
return ScopedModelFallbackMiddleware(
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
)
except Exception:
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
return None

View file

@@ -0,0 +1,21 @@
"""Cap model calls per thread / per run to prevent runaway cost."""
from __future__ import annotations
from langchain.agents.middleware import ModelCallLimitMiddleware
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from ..flags import enabled
def build_model_call_limit_mw(
flags: AgentFeatureFlags,
) -> ModelCallLimitMiddleware | None:
if not enabled(flags, "enable_model_call_limit"):
return None
return ModelCallLimitMiddleware(
thread_limit=120,
run_limit=80,
exit_behavior="end",
)

View file

@@ -0,0 +1,16 @@
"""Retry on transient model errors (e.g. Retry-After-bearing 429s)."""
from __future__ import annotations
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.middleware import RetryAfterMiddleware
from ..flags import enabled
def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None:
return (
RetryAfterMiddleware(max_retries=3)
if enabled(flags, "enable_retry_after")
else None
)

View file

@@ -0,0 +1,21 @@
"""Cap tool calls per thread / per run to bound infinite-loop blast radius."""
from __future__ import annotations
from langchain.agents.middleware import ToolCallLimitMiddleware
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from ..flags import enabled
def build_tool_call_limit_mw(
flags: AgentFeatureFlags,
) -> ToolCallLimitMiddleware | None:
if not enabled(flags, "enable_tool_call_limit"):
return None
return ToolCallLimitMiddleware(
thread_limit=300,
run_limit=80,
exit_behavior="continue",
)

View file

@@ -0,0 +1,9 @@
"""Todo-list middleware (each consumer needs its own instance)."""
from __future__ import annotations
from langchain.agents.middleware import TodoListMiddleware
def build_todos_mw() -> TodoListMiddleware:
return TodoListMiddleware()

View file

@@ -0,0 +1,216 @@
"""Main-agent middleware list assembly: one line per slot."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents import SubAgent
from deepagents.backends import StateBackend
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from app.agents.multi_agent_chat.subagents import (
build_subagents,
get_subagents_to_exclude,
)
from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import (
build_subagent as build_general_purpose_subagent,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.db import ChatVisibility
from .main_agent.action_log import build_action_log_mw
from .main_agent.anonymous_doc import build_anonymous_doc_mw
from .main_agent.busy_mutex import build_busy_mutex_mw
from .main_agent.checkpointed_subagent_middleware import (
SurfSenseCheckpointedSubAgentMiddleware,
)
from .main_agent.context_editing import build_context_editing_mw
from .main_agent.dedup_hitl import build_dedup_hitl_mw
from .main_agent.doom_loop import build_doom_loop_mw
from .main_agent.kb_persistence import build_kb_persistence_mw
from .main_agent.knowledge_priority import build_knowledge_priority_mw
from .main_agent.knowledge_tree import build_knowledge_tree_mw
from .main_agent.noop_injection import build_noop_injection_mw
from .main_agent.otel import build_otel_mw
from .main_agent.plugins import build_plugin_middlewares
from .main_agent.repair import build_repair_mw
from .main_agent.selector import build_selector_mw
from .main_agent.skills import build_skills_mw
from .shared.anthropic_cache import build_anthropic_cache_mw
from .shared.compaction import build_compaction_mw
from .shared.file_intent import build_file_intent_mw
from .shared.filesystem import build_filesystem_mw
from .shared.memory import build_memory_mw
from .shared.patch_tool_calls import build_patch_tool_calls_mw
from .shared.permissions import (
build_full_permission_mw,
build_permission_context,
)
from .shared.resilience import build_resilience_bundle
from .shared.todos import build_todos_mw
from .subagent.extras import build_subagent_extras
def build_main_agent_deepagent_middleware(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
visibility: ChatVisibility,
anon_session_id: str | None,
available_connectors: list[str] | None,
available_document_types: list[str] | None,
mentioned_document_ids: list[int] | None,
max_input_tokens: int | None,
flags: AgentFeatureFlags,
subagent_dependencies: dict[str, Any],
checkpointer: Checkpointer,
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
disabled_tools: list[str] | None = None,
) -> list[Any]:
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
permissions = build_permission_context(
flags=flags,
filesystem_mode=filesystem_mode,
tools=tools,
available_connectors=available_connectors,
)
resilience = build_resilience_bundle(flags)
# Single instance threaded into both the main-agent stack and the general-purpose subagent.
memory_mw = build_memory_mw(
user_id=user_id,
search_space_id=search_space_id,
visibility=visibility,
)
general_purpose_subagent = build_general_purpose_subagent(
llm=llm,
tools=tools,
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
permissions=permissions,
resilience=resilience,
memory_mw=memory_mw,
)
subagents_registry: list[SubAgent] = []
try:
subagent_extras = build_subagent_extras(
permissions=permissions,
resilience=resilience,
)
subagents_registry = build_subagents(
dependencies=subagent_dependencies,
model=llm,
extra_middleware=subagent_extras,
mcp_tools_by_agent=mcp_tools_by_agent or {},
exclude=get_subagents_to_exclude(available_connectors),
disabled_tools=disabled_tools,
)
logging.debug(
"Subagents registry: %s",
[s["name"] for s in subagents_registry],
)
except Exception:
# Degrade to general-purpose-only rather than aborting the turn:
# one bad subagent dep should not deny the user a response.
logging.exception(
"Subagents registry build failed; falling back to general-purpose only"
)
subagents_registry = []
subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry]
stack: list[Any] = [
build_busy_mutex_mw(flags),
build_otel_mw(flags),
build_todos_mw(),
memory_mw,
build_anonymous_doc_mw(
filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
),
build_knowledge_tree_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
llm=llm,
),
build_knowledge_priority_mw(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
),
build_file_intent_mw(llm),
build_filesystem_mw(
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_kb_persistence_mw(
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_skills_mw(
flags=flags,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
),
SurfSenseCheckpointedSubAgentMiddleware(
checkpointer=checkpointer,
backend=StateBackend,
subagents=subagents,
),
build_selector_mw(flags=flags, tools=tools),
resilience.model_call_limit,
resilience.tool_call_limit,
build_context_editing_mw(
flags=flags,
max_input_tokens=max_input_tokens,
tools=tools,
backend_resolver=backend_resolver,
),
build_compaction_mw(llm),
build_noop_injection_mw(flags),
resilience.retry,
resilience.fallback,
build_repair_mw(flags=flags, tools=tools),
build_full_permission_mw(permissions.rulesets),
build_doom_loop_mw(flags),
build_action_log_mw(
flags=flags,
thread_id=thread_id,
search_space_id=search_space_id,
user_id=user_id,
),
build_patch_tool_calls_mw(),
build_dedup_hitl_mw(tools),
*build_plugin_middlewares(
flags=flags,
search_space_id=search_space_id,
user_id=user_id,
visibility=visibility,
llm=llm,
),
build_anthropic_cache_mw(),
]
return [m for m in stack if m is not None]

View file

@@ -0,0 +1,28 @@
"""Extra middleware threaded into every registry subagent's stack.
Registry subagents are scoped to one domain (deliverables, research, memory,
connectors, MCP) and never read or write the SurfSense filesystem; that
capability belongs to the main agent and is delegated to the general-purpose
subagent as an escape hatch. Keeping FS off the registry stacks avoids
polluting their tool surface with FS tools they never act on.
"""
from __future__ import annotations
from typing import Any
from ..shared.permissions import PermissionContext
from ..shared.resilience import ResilienceBundle
from ..shared.todos import build_todos_mw
def build_subagent_extras(
*,
permissions: PermissionContext,
resilience: ResilienceBundle,
) -> list[Any]:
extras: list[Any] = [build_todos_mw()]
if permissions.subagent_deny_mw is not None:
extras.append(permissions.subagent_deny_mw)
extras.extend(resilience.as_list())
return extras

View file

@@ -0,0 +1,105 @@
"""General-purpose subagent for the multi-agent main agent."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, cast
from deepagents import SubAgent
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import (
build_anthropic_cache_mw,
)
from app.agents.multi_agent_chat.middleware.shared.compaction import (
build_compaction_mw,
)
from app.agents.multi_agent_chat.middleware.shared.file_intent import (
build_file_intent_mw,
)
from app.agents.multi_agent_chat.middleware.shared.filesystem import (
build_filesystem_mw,
)
from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import (
build_patch_tool_calls_mw,
)
from app.agents.multi_agent_chat.middleware.shared.permissions import (
PermissionContext,
)
from app.agents.multi_agent_chat.middleware.shared.resilience import (
ResilienceBundle,
)
from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
NAME = "general-purpose"
def build_subagent(
*,
llm: BaseChatModel,
tools: Sequence[BaseTool],
backend_resolver: Any,
filesystem_mode: FilesystemMode,
search_space_id: int,
user_id: str | None,
thread_id: int | None,
permissions: PermissionContext,
resilience: ResilienceBundle,
memory_mw: MemoryInjectionMiddleware,
) -> SubAgent:
"""Deny + resilience inserts encapsulated here so the orchestrator never mutates the list."""
middleware: list[Any] = [
build_todos_mw(),
memory_mw,
build_file_intent_mw(llm),
build_filesystem_mw(
backend_resolver=backend_resolver,
filesystem_mode=filesystem_mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
),
build_compaction_mw(llm),
build_patch_tool_calls_mw(),
build_anthropic_cache_mw(),
]
if permissions.subagent_deny_mw is not None:
patch_idx = next(
(
i
for i, m in enumerate(middleware)
if isinstance(m, PatchToolCallsMiddleware)
),
len(middleware),
)
middleware.insert(patch_idx, permissions.subagent_deny_mw)
resilience_mws = resilience.as_list()
if resilience_mws:
cache_idx = next(
(
i
for i, m in enumerate(middleware)
if isinstance(m, AnthropicPromptCachingMiddleware)
),
len(middleware),
)
for offset, mw in enumerate(resilience_mws):
middleware.insert(cache_idx + offset, mw)
spec: dict[str, Any] = {
**GENERAL_PURPOSE_SUBAGENT,
"model": llm,
"tools": tools,
"middleware": middleware,
}
if permissions.general_purpose_interrupt_on:
spec["interrupt_on"] = permissions.general_purpose_interrupt_on
return cast(SubAgent, spec)
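
The insertion idiom used twice above, isolated: `next(..., len(seq))` finds the first instance of the anchor class and degrades to append when the anchor is absent, so a reshuffled base list can never raise `StopIteration`:

```python
class PatchToolCalls:  # stand-in anchor class
    pass


class Deny:  # stand-in middleware to insert
    pass


middleware = ["todos", "memory", PatchToolCalls(), "cache"]
patch_idx = next(
    (i for i, m in enumerate(middleware) if isinstance(m, PatchToolCalls)),
    len(middleware),
)
middleware.insert(patch_idx, Deny())
assert isinstance(middleware[2], Deny)  # deny runs before the patch pass

no_anchor = ["todos"]
idx = next(
    (i for i, m in enumerate(no_anchor) if isinstance(m, PatchToolCalls)),
    len(no_anchor),
)
no_anchor.insert(idx, Deny())
assert isinstance(no_anchor[-1], Deny)  # anchor missing: appended instead
```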

View file

@@ -31,7 +31,6 @@ from langchain.agents import create_agent
from langchain.agents.middleware import (
LLMToolSelectorMiddleware,
ModelCallLimitMiddleware,
ModelFallbackMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
@@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
create_surfsense_compaction_middleware,
default_skills_sources,
)
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
from app.agents.new_chat.permissions import Rule, Ruleset
from app.agents.new_chat.plugin_loader import (
PluginContext,
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
# Fallback chain — primary is the agent's own model; we add cheap
# alternatives. Off by default; only the first call site that
# configures the chain via env should enable it.
fallback_mw: ModelFallbackMiddleware | None = None
fallback_mw: ScopedModelFallbackMiddleware | None = None
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
try:
fallback_mw = ModelFallbackMiddleware(
fallback_mw = ScopedModelFallbackMiddleware(
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
)
except Exception:
logging.warning("ModelFallbackMiddleware init failed; skipping.")
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
fallback_mw = None
model_call_limit_mw = (
ModelCallLimitMiddleware(

View file

@ -0,0 +1,91 @@
"""Fallback only on provider/network errors; let programming bugs raise."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import ModelFallbackMiddleware
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
# Matched by class name across the MRO so we don't have to import every
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
{
"RateLimitError",
"APIStatusError",
"InternalServerError",
"ServiceUnavailableError",
"BadGatewayError",
"GatewayTimeoutError",
"APIConnectionError",
"APITimeoutError",
"ConnectError",
"ConnectTimeout",
"ReadTimeout",
"RemoteProtocolError",
"TimeoutError",
"TimeoutException",
}
)
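# For example, openai's RateLimitError subclasses APIStatusError (class layout
# assumed from the openai SDK), so either name in the MRO makes it eligible.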
def _is_fallback_eligible(exc: BaseException) -> bool:
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
def wrap_model_call( # type: ignore[override]
self,
request: ModelRequest[Any],
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
) -> ModelResponse[Any] | AIMessage:
last_exception: Exception
try:
return handler(request)
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
for fallback_model in self.models:
try:
return handler(request.override(model=fallback_model))
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
continue
raise last_exception
async def awrap_model_call( # type: ignore[override]
self,
request: ModelRequest[Any],
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
) -> ModelResponse[Any] | AIMessage:
last_exception: Exception
try:
return await handler(request)
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
for fallback_model in self.models:
try:
return await handler(request.override(model=fallback_model))
except Exception as e:
if not _is_fallback_eligible(e):
raise
last_exception = e
continue
raise last_exception
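Usage mirrors the main-agent call site above; a minimal sketch, assuming a ``primary_llm`` placeholder and illustrative fallback model ids:

    from langchain.agents import create_agent

    fallback_mw = ScopedModelFallbackMiddleware(
        "openai:gpt-4o-mini",                   # tried first on an eligible error
        "anthropic:claude-3-5-haiku-20241022",  # tried next
    )
    agent = create_agent(model=primary_llm, tools=[], middleware=[fallback_mw])

Programming errors (``KeyError``, ``TypeError``, ...) bypass the chain entirely, which the unit tests below pin down.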

View file

@ -28,9 +28,7 @@ from langchain_core.messages import HumanMessage
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.multi_agent_chat import (
create_surfsense_deep_agent as create_registry_deep_agent,
)
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.context import SurfSenseContextSchema
@ -577,6 +575,43 @@ async def _preflight_llm(llm: Any) -> None:
)
async def _build_main_agent_for_thread(
agent_factory: Any,
*,
llm: Any,
search_space_id: int,
db_session: Any,
connector_service: ConnectorService,
checkpointer: Any,
user_id: str | None,
thread_id: int | None,
agent_config: AgentConfig | None,
firecrawl_api_key: str | None,
thread_visibility: ChatVisibility | None,
filesystem_selection: FilesystemSelection | None,
disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None,
) -> Any:
"""Single (re)build path so the agent factory cannot drift across
initial build, preflight repin, and mid-stream 429 recovery for one
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
return await agent_factory(
llm=llm,
search_space_id=search_space_id,
db_session=db_session,
connector_service=connector_service,
checkpointer=checkpointer,
user_id=user_id,
thread_id=thread_id,
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=thread_visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
)
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
"""Wait for a discarded speculative agent build to release shared state.
@ -2767,7 +2802,7 @@ async def stream_new_chat(
_t0 = time.perf_counter()
agent_factory = (
create_registry_deep_agent
create_multi_agent_chat_deep_agent
if use_multi_agent
else create_surfsense_deep_agent
)
@ -2776,7 +2811,8 @@ async def stream_new_chat(
# if preflight reports 429 we will discard this future and rebuild
# against the freshly pinned config below.
agent_build_task = asyncio.create_task(
agent_factory(
_build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -2787,9 +2823,9 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
),
name="agent_build:stream_new_chat",
)
@ -3466,7 +3502,8 @@ async def stream_new_chat(
title_task = None
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -3477,9 +3514,9 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
filesystem_selection=filesystem_selection,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "
@ -4130,12 +4167,13 @@ async def stream_resume_chat(
_t0 = time.perf_counter()
agent_factory = (
create_registry_deep_agent
create_multi_agent_chat_deep_agent
if _app_config.MULTI_AGENT_CHAT_ENABLED
else create_surfsense_deep_agent
)
agent_build_task = asyncio.create_task(
agent_factory(
_build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -4224,7 +4262,8 @@ async def stream_resume_chat(
"fallback_config_id": llm_config_id,
},
)
agent = await agent_factory(
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -4409,7 +4448,8 @@ async def stream_resume_chat(
raise stream_exc
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
agent = await _build_main_agent_for_thread(
agent_factory,
llm=llm,
search_space_id=search_space_id,
db_session=session,
@ -4421,6 +4461,7 @@ async def stream_resume_chat(
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "

View file

@ -0,0 +1,208 @@
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
from __future__ import annotations
import ast
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubagentState(TypedDict, total=False):
messages: list
decision_text: str
def _build_single_interrupt_subagent():
def approve_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
{
"name": "do_thing",
"args": {"x": 1},
"description": "test action",
}
],
"review_configs": [{}],
}
)
return {
"messages": [AIMessage(content="done")],
"decision_text": repr(decision),
}
graph = StateGraph(_SubagentState)
graph.add_node("approve", approve_node)
graph.add_edge(START, "approve")
graph.add_edge("approve", END)
return graph.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=config,
stream_writer=None,
tool_call_id="parent-tcid-1",
store=None,
)
@pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
subagent = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver",
"description": "approves things",
"runnable": subagent,
}
]
)
parent_config: dict = {
"configurable": {"thread_id": "shared-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, (
"fixture broken: subagent should be paused on its interrupt"
)
parent_config["configurable"]["surfsense_resume_value"] = {
"decisions": ["APPROVED"]
}
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime,
)
assert isinstance(result, Command)
update = result.update
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
final = await subagent.aget_state(parent_config)
assert not final.tasks or all(not t.interrupts for t in final.tasks)
@pytest.mark.asyncio
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
"""Bridge must fail loud rather than silently replay the user's interrupt."""
subagent = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver",
"description": "approves things",
"runnable": subagent,
}
]
)
parent_config: dict = {
"configurable": {"thread_id": "guard-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
runtime = _make_runtime(parent_config)
with pytest.raises(RuntimeError, match="resume bridge is broken"):
await task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime,
)
def _build_bundle_subagent():
def bundle_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
{"name": "create_a", "args": {}, "description": ""},
{"name": "create_b", "args": {}, "description": ""},
{"name": "create_c", "args": {}, "description": ""},
],
"review_configs": [{}, {}, {}],
}
)
return {
"messages": [AIMessage(content="bundle-done")],
"decision_text": repr(decision),
}
graph = StateGraph(_SubagentState)
graph.add_node("bundle", bundle_node)
graph.add_edge(START, "bundle")
graph.add_edge("bundle", END)
return graph.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_bundle_three_mixed_decisions_arrive_in_order():
"""Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
subagent = _build_bundle_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "bundler",
"description": "creates a bundle",
"runnable": subagent,
}
]
)
parent_config: dict = {
"configurable": {"thread_id": "bundle-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
decisions_payload = {
"decisions": [
{"type": "approve", "args": {}},
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
{"type": "reject", "args": {"message": "no thanks"}},
]
}
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine(
description="run bundle",
subagent_type="bundler",
runtime=runtime,
)
assert isinstance(result, Command)
received = ast.literal_eval(result.update["decision_text"])
assert received == decisions_payload
assert received["decisions"][0]["type"] == "approve"
assert received["decisions"][1]["type"] == "edit"
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
assert received["decisions"][2]["type"] == "reject"

View file

@ -0,0 +1,55 @@
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
The bridge currently relies on at-most-one pending interrupt per snapshot
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
needs an id-aware lookup; these tests will need to be revisited at that point.
"""
from __future__ import annotations
from types import SimpleNamespace
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
get_first_pending_subagent_interrupt,
)
class TestGetFirstPendingSubagentInterrupt:
def test_returns_first_when_multiple_top_level_interrupts_pending(self):
first = SimpleNamespace(id="i-1", value={"decision": "approve"})
second = SimpleNamespace(id="i-2", value={"decision": "reject"})
state = SimpleNamespace(interrupts=(first, second), tasks=())
assert get_first_pending_subagent_interrupt(state) == (
"i-1",
{"decision": "approve"},
)
def test_returns_first_when_multiple_subtask_interrupts_pending(self):
first = SimpleNamespace(id="i-A", value="approve")
second = SimpleNamespace(id="i-B", value="reject")
sub_task = SimpleNamespace(interrupts=(first, second))
state = SimpleNamespace(interrupts=(), tasks=(sub_task,))
assert get_first_pending_subagent_interrupt(state) == ("i-A", "approve")
def test_returns_none_when_no_interrupts(self):
state = SimpleNamespace(interrupts=(), tasks=())
assert get_first_pending_subagent_interrupt(state) == (None, None)
def test_returns_none_when_state_is_none(self):
assert get_first_pending_subagent_interrupt(None) == (None, None)
def test_skips_interrupts_with_none_value(self):
empty = SimpleNamespace(id="i-empty", value=None)
real = SimpleNamespace(id="i-real", value="approve")
state = SimpleNamespace(interrupts=(empty, real), tasks=())
assert get_first_pending_subagent_interrupt(state) == ("i-real", "approve")
def test_normalizes_non_string_id_to_none(self):
interrupt = SimpleNamespace(id=12345, value="approve")
state = SimpleNamespace(interrupts=(interrupt,), tasks=())
assert get_first_pending_subagent_interrupt(state) == (None, "approve")
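For orientation, an approximation of the lookup these tests pin down (a sketch, not the shipped implementation):

    def _first_pending_interrupt_sketch(state):
        if state is None:
            return (None, None)
        # top-level interrupts first, then each subtask's interrupts
        candidates = list(getattr(state, "interrupts", ()) or ())
        for task in getattr(state, "tasks", ()) or ():
            candidates.extend(getattr(task, "interrupts", ()) or ())
        for intr in candidates:
            if intr.value is None:
                continue  # payload-less interrupts are skipped
            interrupt_id = intr.id if isinstance(intr.id, str) else None
            return (interrupt_id, intr.value)
        return (None, None)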

View file

@ -0,0 +1,71 @@
"""Resume side-channel must be read exactly once per turn."""
from __future__ import annotations
from langchain.tools import ToolRuntime
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
consume_surfsense_resume,
has_surfsense_resume,
)
def _runtime_with_config(config: dict) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config,
stream_writer=None,
tool_call_id="tcid-test",
store=None,
)
class TestConsumeSurfsenseResume:
def test_pops_value_on_first_call(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
)
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
def test_second_call_returns_none(self):
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
runtime = _runtime_with_config({"configurable": configurable})
consume_surfsense_resume(runtime)
assert consume_surfsense_resume(runtime) is None
assert "surfsense_resume_value" not in configurable
def test_returns_none_when_no_payload_queued(self):
runtime = _runtime_with_config({"configurable": {}})
assert consume_surfsense_resume(runtime) is None
def test_returns_none_when_configurable_missing(self):
runtime = _runtime_with_config({})
assert consume_surfsense_resume(runtime) is None
class TestHasSurfsenseResume:
def test_true_when_payload_queued(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": "approve"}}
)
assert has_surfsense_resume(runtime) is True
def test_does_not_consume_payload(self):
configurable = {"surfsense_resume_value": "approve"}
runtime = _runtime_with_config({"configurable": configurable})
has_surfsense_resume(runtime)
assert configurable == {"surfsense_resume_value": "approve"}
def test_false_when_payload_absent(self):
runtime = _runtime_with_config({"configurable": {}})
assert has_surfsense_resume(runtime) is False
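The pinned semantics reduce to a destructive pop versus a non-destructive membership check on the runtime config; a sketch under that assumption:

    def _consume_sketch(runtime):
        # destructive read: removes the key so a second call returns None
        conf = runtime.config.get("configurable") or {}
        return conf.pop("surfsense_resume_value", None)

    def _has_sketch(runtime):
        # non-destructive probe: the payload stays queued
        return "surfsense_resume_value" in (runtime.config.get("configurable") or {})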

View file

@ -0,0 +1,97 @@
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import Any
import pytest
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
pack_subagent,
)
class RateLimitError(Exception):
"""Name matches the scoped-fallback eligibility allowlist."""
class _AlwaysFailingChatModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "always-failing-test-model"
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "primary llm exploded"
raise RateLimitError(msg)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "primary llm exploded"
raise RateLimitError(msg)
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
msg = "primary llm exploded"
raise RateLimitError(msg)
async def _astream(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[ChatGeneration]:
msg = "primary llm exploded"
raise RateLimitError(msg)
yield # pragma: no cover - unreachable, satisfies async generator typing
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
primary = _AlwaysFailingChatModel()
fallback = FakeMessagesListChatModel(
responses=[AIMessage(content="recovered via fallback")]
)
spec = pack_subagent(
name="resilience_test",
description="test subagent",
system_prompt="be helpful",
tools=[],
model=primary,
extra_middleware=[ModelFallbackMiddleware(fallback)],
)
agent = create_agent(
model=spec["model"],
tools=spec["tools"],
middleware=spec["middleware"],
system_prompt=spec["system_prompt"],
)
result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
final = result["messages"][-1]
assert isinstance(final, AIMessage)
assert final.content == "recovered via fallback"

View file

@ -0,0 +1,130 @@
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import Any
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
class _RaisingChatModel(BaseChatModel):
exc_to_raise: Any
@property
def _llm_type(self) -> str:
return "raising-test-model"
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
raise self.exc_to_raise
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
raise self.exc_to_raise
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
raise self.exc_to_raise
async def _astream(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[ChatGeneration]:
raise self.exc_to_raise
yield # pragma: no cover - unreachable
class _RecordingChatModel(BaseChatModel):
response_text: str = "fallback-ok"
call_count: int = 0
@property
def _llm_type(self) -> str:
return "recording-test-model"
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
self.call_count += 1
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=self.response_text))
]
)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
return self._generate(messages, stop, None, **kwargs)
class RateLimitError(Exception):
"""Name matches the scoped-fallback eligibility allowlist."""
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
from langchain.agents import create_agent
from app.agents.new_chat.middleware.scoped_model_fallback import (
ScopedModelFallbackMiddleware,
)
return create_agent(
model=primary,
tools=[],
middleware=[ScopedModelFallbackMiddleware(fallback)],
system_prompt="be helpful",
)
@pytest.mark.asyncio
async def test_provider_errors_trigger_fallback():
"""Eligible exception names must drive the fallback chain."""
primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
fallback = _RecordingChatModel(response_text="recovered")
agent = _build_agent(primary, fallback)
result = await agent.ainvoke({"messages": [("user", "hi")]})
final = result["messages"][-1]
assert isinstance(final, AIMessage)
assert final.content == "recovered"
assert fallback.call_count == 1
@pytest.mark.asyncio
async def test_programming_errors_propagate_without_invoking_fallback():
"""Non-eligible exceptions must propagate; fallback must not be invoked."""
primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
fallback = _RecordingChatModel(response_text="should-never-arrive")
agent = _build_agent(primary, fallback)
with pytest.raises(KeyError, match="missing_state_field"):
await agent.ainvoke({"messages": [("user", "hi")]})
assert fallback.call_count == 0

View file

@ -202,6 +202,15 @@ class FakeBudgetLLM:
class TestKnowledgeBaseSearchMiddlewarePlanner:
@pytest.fixture(autouse=True)
def _disable_planner_runnable(self, monkeypatch):
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
# planner Runnable path is enabled) calls ``.bind()`` on the LLM,
# which the mock does not implement. Pin the flag off so the
# planner falls through to the legacy ``self.llm.ainvoke`` path
# these tests assert against (``llm.calls[0]["config"]``).
monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
messages = [
HumanMessage(content="old user context " * 40),