mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 22:32:39 +02:00
Merge pull request #1351 from CREDO23/feature/multi-agent
[Improvement] Modular middleware stack + agent/prompt caching + subagent resilience + unit tests
This commit is contained in:
commit
a4fc812b85
70 changed files with 2037 additions and 547 deletions
|
|
@ -2,6 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .main_agent import create_surfsense_deep_agent
|
from .main_agent import create_multi_agent_chat_deep_agent
|
||||||
|
|
||||||
__all__ = ["create_surfsense_deep_agent"]
|
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .runtime import create_surfsense_deep_agent
|
from .runtime import create_multi_agent_chat_deep_agent
|
||||||
|
|
||||||
__all__ = ["create_surfsense_deep_agent"]
|
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware import (
|
||||||
|
build_main_agent_deepagent_middleware,
|
||||||
|
)
|
||||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||||
ToolsPermissions,
|
ToolsPermissions,
|
||||||
)
|
)
|
||||||
|
|
@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
from .middleware import build_main_agent_deepagent_middleware
|
|
||||||
|
|
||||||
|
|
||||||
def build_compiled_agent_graph_sync(
|
def build_compiled_agent_graph_sync(
|
||||||
*,
|
*,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .deepagent_stack import build_main_agent_deepagent_middleware
|
|
||||||
|
|
||||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
|
||||||
|
|
@ -1,506 +0,0 @@
|
||||||
"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from deepagents import SubAgent
|
|
||||||
from deepagents.backends import StateBackend
|
|
||||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
|
||||||
from deepagents.middleware.skills import SkillsMiddleware
|
|
||||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
|
||||||
from langchain.agents.middleware import (
|
|
||||||
LLMToolSelectorMiddleware,
|
|
||||||
ModelCallLimitMiddleware,
|
|
||||||
ModelFallbackMiddleware,
|
|
||||||
TodoListMiddleware,
|
|
||||||
ToolCallLimitMiddleware,
|
|
||||||
)
|
|
||||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
from langgraph.types import Checkpointer
|
|
||||||
|
|
||||||
from app.agents.multi_agent_chat.subagents import (
|
|
||||||
build_subagents,
|
|
||||||
get_subagents_to_exclude,
|
|
||||||
)
|
|
||||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
|
||||||
ToolsPermissions,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
|
||||||
from app.agents.new_chat.middleware import (
|
|
||||||
ActionLogMiddleware,
|
|
||||||
AnonymousDocumentMiddleware,
|
|
||||||
BusyMutexMiddleware,
|
|
||||||
ClearToolUsesEdit,
|
|
||||||
DedupHITLToolCallsMiddleware,
|
|
||||||
DoomLoopMiddleware,
|
|
||||||
FileIntentMiddleware,
|
|
||||||
KnowledgeBasePersistenceMiddleware,
|
|
||||||
KnowledgePriorityMiddleware,
|
|
||||||
KnowledgeTreeMiddleware,
|
|
||||||
MemoryInjectionMiddleware,
|
|
||||||
NoopInjectionMiddleware,
|
|
||||||
OtelSpanMiddleware,
|
|
||||||
PermissionMiddleware,
|
|
||||||
RetryAfterMiddleware,
|
|
||||||
SpillingContextEditingMiddleware,
|
|
||||||
SpillToBackendEdit,
|
|
||||||
SurfSenseFilesystemMiddleware,
|
|
||||||
ToolCallNameRepairMiddleware,
|
|
||||||
build_skills_backend_factory,
|
|
||||||
create_surfsense_compaction_middleware,
|
|
||||||
default_skills_sources,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
|
||||||
from app.agents.new_chat.plugin_loader import (
|
|
||||||
PluginContext,
|
|
||||||
load_allowed_plugin_names_from_env,
|
|
||||||
load_plugin_middlewares,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
|
||||||
from app.db import ChatVisibility
|
|
||||||
|
|
||||||
from ...context_prune.prune_tool_names import safe_exclude_tools
|
|
||||||
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
|
|
||||||
|
|
||||||
|
|
||||||
def build_main_agent_deepagent_middleware(
|
|
||||||
*,
|
|
||||||
llm: BaseChatModel,
|
|
||||||
tools: Sequence[BaseTool],
|
|
||||||
backend_resolver: Any,
|
|
||||||
filesystem_mode: FilesystemMode,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str | None,
|
|
||||||
thread_id: int | None,
|
|
||||||
visibility: ChatVisibility,
|
|
||||||
anon_session_id: str | None,
|
|
||||||
available_connectors: list[str] | None,
|
|
||||||
available_document_types: list[str] | None,
|
|
||||||
mentioned_document_ids: list[int] | None,
|
|
||||||
max_input_tokens: int | None,
|
|
||||||
flags: AgentFeatureFlags,
|
|
||||||
subagent_dependencies: dict[str, Any],
|
|
||||||
checkpointer: Checkpointer,
|
|
||||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
|
||||||
disabled_tools: list[str] | None = None,
|
|
||||||
) -> list[Any]:
|
|
||||||
"""Build ordered middleware for ``create_agent`` (Nones already stripped)."""
|
|
||||||
_memory_middleware = MemoryInjectionMiddleware(
|
|
||||||
user_id=user_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
thread_visibility=visibility,
|
|
||||||
)
|
|
||||||
|
|
||||||
gp_middleware = [
|
|
||||||
TodoListMiddleware(),
|
|
||||||
_memory_middleware,
|
|
||||||
FileIntentMiddleware(llm=llm),
|
|
||||||
SurfSenseFilesystemMiddleware(
|
|
||||||
backend=backend_resolver,
|
|
||||||
filesystem_mode=filesystem_mode,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
created_by_id=user_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
),
|
|
||||||
create_surfsense_compaction_middleware(llm, StateBackend),
|
|
||||||
PatchToolCallsMiddleware(),
|
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Build permission rulesets up front so the GP subagent can mirror ``ask``
|
|
||||||
# rules into ``interrupt_on``: tool calls emitted from within ``task`` runs
|
|
||||||
# never reach the parent's ``PermissionMiddleware``.
|
|
||||||
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
|
||||||
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
|
|
||||||
permission_rulesets: list[Ruleset] = []
|
|
||||||
if permission_enabled or is_desktop_fs:
|
|
||||||
permission_rulesets.append(
|
|
||||||
Ruleset(
|
|
||||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
|
||||||
origin="surfsense_defaults",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if is_desktop_fs:
|
|
||||||
permission_rulesets.append(
|
|
||||||
Ruleset(
|
|
||||||
rules=[
|
|
||||||
Rule(permission="rm", pattern="*", action="ask"),
|
|
||||||
Rule(permission="rmdir", pattern="*", action="ask"),
|
|
||||||
Rule(permission="move_file", pattern="*", action="ask"),
|
|
||||||
Rule(permission="edit_file", pattern="*", action="ask"),
|
|
||||||
Rule(permission="write_file", pattern="*", action="ask"),
|
|
||||||
],
|
|
||||||
origin="desktop_safety",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tools that self-prompt via ``request_approval`` must not also appear
|
|
||||||
# as ``ask`` rules — that would double-prompt the user for one call.
|
|
||||||
_tool_names_in_use = {t.name for t in tools}
|
|
||||||
|
|
||||||
# Deny parent-bound tools whose ``required_connector`` is missing.
|
|
||||||
# No-op today (connector subagents are pruned upstream); guards future
|
|
||||||
# additions to the parent's tool list.
|
|
||||||
if permission_enabled:
|
|
||||||
_available_set = set(available_connectors or [])
|
|
||||||
_synthesized: list[Rule] = []
|
|
||||||
for tool_def in BUILTIN_TOOLS:
|
|
||||||
if tool_def.name not in _tool_names_in_use:
|
|
||||||
continue
|
|
||||||
rc = tool_def.required_connector
|
|
||||||
if rc and rc not in _available_set:
|
|
||||||
_synthesized.append(
|
|
||||||
Rule(permission=tool_def.name, pattern="*", action="deny")
|
|
||||||
)
|
|
||||||
if _synthesized:
|
|
||||||
permission_rulesets.append(
|
|
||||||
Ruleset(rules=_synthesized, origin="connector_synthesized")
|
|
||||||
)
|
|
||||||
gp_interrupt_on: dict[str, bool] = {
|
|
||||||
rule.permission: True
|
|
||||||
for rs in permission_rulesets
|
|
||||||
for rule in rs.rules
|
|
||||||
if rule.action == "ask" and rule.permission in _tool_names_in_use
|
|
||||||
}
|
|
||||||
|
|
||||||
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
|
||||||
**GENERAL_PURPOSE_SUBAGENT,
|
|
||||||
"model": llm,
|
|
||||||
"tools": tools,
|
|
||||||
"middleware": gp_middleware,
|
|
||||||
}
|
|
||||||
if gp_interrupt_on:
|
|
||||||
general_purpose_spec["interrupt_on"] = gp_interrupt_on
|
|
||||||
|
|
||||||
# Deny-only on subagents: ``task`` runs bypass the parent's
|
|
||||||
# PermissionMiddleware, while bucket-based ask gates own the ask path.
|
|
||||||
subagent_deny_rulesets: list[Ruleset] = [
|
|
||||||
Ruleset(
|
|
||||||
rules=[r for r in rs.rules if r.action == "deny"],
|
|
||||||
origin=rs.origin,
|
|
||||||
)
|
|
||||||
for rs in permission_rulesets
|
|
||||||
]
|
|
||||||
subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules]
|
|
||||||
|
|
||||||
subagent_deny_permission_mw: PermissionMiddleware | None = (
|
|
||||||
PermissionMiddleware(rulesets=subagent_deny_rulesets)
|
|
||||||
if subagent_deny_rulesets
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if subagent_deny_permission_mw is not None:
|
|
||||||
# Run deny check on already-repaired tool calls; insert before
|
|
||||||
# PatchToolCallsMiddleware (append if the slot moves).
|
|
||||||
_patch_idx = next(
|
|
||||||
(
|
|
||||||
i
|
|
||||||
for i, m in enumerate(gp_middleware)
|
|
||||||
if isinstance(m, PatchToolCallsMiddleware)
|
|
||||||
),
|
|
||||||
len(gp_middleware),
|
|
||||||
)
|
|
||||||
gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)
|
|
||||||
|
|
||||||
registry_subagents: list[SubAgent] = []
|
|
||||||
try:
|
|
||||||
subagent_extra_middleware: list[Any] = [
|
|
||||||
TodoListMiddleware(),
|
|
||||||
SurfSenseFilesystemMiddleware(
|
|
||||||
backend=backend_resolver,
|
|
||||||
filesystem_mode=filesystem_mode,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
created_by_id=user_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
if subagent_deny_permission_mw is not None:
|
|
||||||
subagent_extra_middleware.append(subagent_deny_permission_mw)
|
|
||||||
registry_subagents = build_subagents(
|
|
||||||
dependencies=subagent_dependencies,
|
|
||||||
model=llm,
|
|
||||||
extra_middleware=subagent_extra_middleware,
|
|
||||||
mcp_tools_by_agent=mcp_tools_by_agent or {},
|
|
||||||
exclude=get_subagents_to_exclude(available_connectors),
|
|
||||||
disabled_tools=disabled_tools,
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
"Registry subagents: %s",
|
|
||||||
[s["name"] for s in registry_subagents],
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logging.exception("Registry subagent build failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents]
|
|
||||||
|
|
||||||
summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
|
|
||||||
|
|
||||||
context_edit_mw = None
|
|
||||||
if (
|
|
||||||
flags.enable_context_editing
|
|
||||||
and not flags.disable_new_agent_stack
|
|
||||||
and max_input_tokens
|
|
||||||
):
|
|
||||||
spill_edit = SpillToBackendEdit(
|
|
||||||
trigger=int(max_input_tokens * 0.55),
|
|
||||||
clear_at_least=int(max_input_tokens * 0.15),
|
|
||||||
keep=5,
|
|
||||||
exclude_tools=safe_exclude_tools(tools),
|
|
||||||
clear_tool_inputs=True,
|
|
||||||
)
|
|
||||||
clear_edit = ClearToolUsesEdit(
|
|
||||||
trigger=int(max_input_tokens * 0.55),
|
|
||||||
clear_at_least=int(max_input_tokens * 0.15),
|
|
||||||
keep=5,
|
|
||||||
exclude_tools=safe_exclude_tools(tools),
|
|
||||||
clear_tool_inputs=True,
|
|
||||||
placeholder="[cleared - older tool output trimmed for context]",
|
|
||||||
)
|
|
||||||
context_edit_mw = SpillingContextEditingMiddleware(
|
|
||||||
edits=[spill_edit, clear_edit],
|
|
||||||
backend_resolver=backend_resolver,
|
|
||||||
)
|
|
||||||
|
|
||||||
retry_mw = (
|
|
||||||
RetryAfterMiddleware(max_retries=3)
|
|
||||||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
fallback_mw: ModelFallbackMiddleware | None = None
|
|
||||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
|
||||||
try:
|
|
||||||
fallback_mw = ModelFallbackMiddleware(
|
|
||||||
"openai:gpt-4o-mini",
|
|
||||||
"anthropic:claude-3-5-haiku-20241022",
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
|
||||||
fallback_mw = None
|
|
||||||
model_call_limit_mw = (
|
|
||||||
ModelCallLimitMiddleware(
|
|
||||||
thread_limit=120,
|
|
||||||
run_limit=80,
|
|
||||||
exit_behavior="end",
|
|
||||||
)
|
|
||||||
if flags.enable_model_call_limit and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
tool_call_limit_mw = (
|
|
||||||
ToolCallLimitMiddleware(
|
|
||||||
thread_limit=300, run_limit=80, exit_behavior="continue"
|
|
||||||
)
|
|
||||||
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
noop_mw = (
|
|
||||||
NoopInjectionMiddleware()
|
|
||||||
if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
repair_mw = None
|
|
||||||
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
|
||||||
registered_names: set[str] = {t.name for t in tools}
|
|
||||||
registered_names |= {
|
|
||||||
"write_todos",
|
|
||||||
"ls",
|
|
||||||
"read_file",
|
|
||||||
"write_file",
|
|
||||||
"edit_file",
|
|
||||||
"glob",
|
|
||||||
"grep",
|
|
||||||
"execute",
|
|
||||||
"task",
|
|
||||||
"mkdir",
|
|
||||||
"cd",
|
|
||||||
"pwd",
|
|
||||||
"move_file",
|
|
||||||
"rm",
|
|
||||||
"rmdir",
|
|
||||||
"list_tree",
|
|
||||||
"execute_code",
|
|
||||||
}
|
|
||||||
repair_mw = ToolCallNameRepairMiddleware(
|
|
||||||
registered_tool_names=registered_names,
|
|
||||||
fuzzy_match_threshold=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
doom_loop_mw = (
|
|
||||||
DoomLoopMiddleware(threshold=3)
|
|
||||||
if flags.enable_doom_loop and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
permission_mw: PermissionMiddleware | None = (
|
|
||||||
PermissionMiddleware(rulesets=permission_rulesets)
|
|
||||||
if permission_rulesets
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
action_log_mw: ActionLogMiddleware | None = None
|
|
||||||
if (
|
|
||||||
flags.enable_action_log
|
|
||||||
and not flags.disable_new_agent_stack
|
|
||||||
and thread_id is not None
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
|
||||||
action_log_mw = ActionLogMiddleware(
|
|
||||||
thread_id=thread_id,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
tool_definitions=tool_defs_by_name,
|
|
||||||
)
|
|
||||||
except Exception: # pragma: no cover - defensive
|
|
||||||
logging.warning(
|
|
||||||
"ActionLogMiddleware init failed; running without it.",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
action_log_mw = None
|
|
||||||
|
|
||||||
busy_mutex_mw: BusyMutexMiddleware | None = (
|
|
||||||
BusyMutexMiddleware()
|
|
||||||
if flags.enable_busy_mutex and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
otel_mw: OtelSpanMiddleware | None = (
|
|
||||||
OtelSpanMiddleware()
|
|
||||||
if flags.enable_otel and not flags.disable_new_agent_stack
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin_middlewares: list[Any] = []
|
|
||||||
if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
|
|
||||||
try:
|
|
||||||
allowed_names = load_allowed_plugin_names_from_env()
|
|
||||||
if allowed_names:
|
|
||||||
plugin_middlewares = load_plugin_middlewares(
|
|
||||||
PluginContext.build(
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
user_id=user_id,
|
|
||||||
thread_visibility=visibility,
|
|
||||||
llm=llm,
|
|
||||||
),
|
|
||||||
allowed_plugin_names=allowed_names,
|
|
||||||
)
|
|
||||||
except Exception: # pragma: no cover - defensive
|
|
||||||
logging.warning(
|
|
||||||
"Plugin loader failed; continuing without plugins.",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
plugin_middlewares = []
|
|
||||||
|
|
||||||
skills_mw: SkillsMiddleware | None = None
|
|
||||||
if flags.enable_skills and not flags.disable_new_agent_stack:
|
|
||||||
try:
|
|
||||||
skills_factory = build_skills_backend_factory(
|
|
||||||
search_space_id=search_space_id
|
|
||||||
if filesystem_mode == FilesystemMode.CLOUD
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
skills_mw = SkillsMiddleware(
|
|
||||||
backend=skills_factory,
|
|
||||||
sources=default_skills_sources(),
|
|
||||||
)
|
|
||||||
except Exception as exc: # pragma: no cover - defensive
|
|
||||||
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
|
||||||
skills_mw = None
|
|
||||||
|
|
||||||
selector_mw: LLMToolSelectorMiddleware | None = None
|
|
||||||
if (
|
|
||||||
flags.enable_llm_tool_selector
|
|
||||||
and not flags.disable_new_agent_stack
|
|
||||||
and len(tools) > 30
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
selector_mw = LLMToolSelectorMiddleware(
|
|
||||||
model="openai:gpt-4o-mini",
|
|
||||||
max_tools=12,
|
|
||||||
always_include=[
|
|
||||||
name
|
|
||||||
for name in (
|
|
||||||
"update_memory",
|
|
||||||
"get_connected_accounts",
|
|
||||||
"scrape_webpage",
|
|
||||||
)
|
|
||||||
if name in {t.name for t in tools}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
|
||||||
selector_mw = None
|
|
||||||
|
|
||||||
deepagent_middleware = [
|
|
||||||
busy_mutex_mw,
|
|
||||||
otel_mw,
|
|
||||||
TodoListMiddleware(),
|
|
||||||
_memory_middleware,
|
|
||||||
AnonymousDocumentMiddleware(
|
|
||||||
anon_session_id=anon_session_id,
|
|
||||||
)
|
|
||||||
if filesystem_mode == FilesystemMode.CLOUD
|
|
||||||
else None,
|
|
||||||
KnowledgeTreeMiddleware(
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
filesystem_mode=filesystem_mode,
|
|
||||||
llm=llm,
|
|
||||||
)
|
|
||||||
if filesystem_mode == FilesystemMode.CLOUD
|
|
||||||
else None,
|
|
||||||
KnowledgePriorityMiddleware(
|
|
||||||
llm=llm,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
filesystem_mode=filesystem_mode,
|
|
||||||
available_connectors=available_connectors,
|
|
||||||
available_document_types=available_document_types,
|
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
|
||||||
),
|
|
||||||
FileIntentMiddleware(llm=llm),
|
|
||||||
SurfSenseFilesystemMiddleware(
|
|
||||||
backend=backend_resolver,
|
|
||||||
filesystem_mode=filesystem_mode,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
created_by_id=user_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
),
|
|
||||||
KnowledgeBasePersistenceMiddleware(
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
created_by_id=user_id,
|
|
||||||
filesystem_mode=filesystem_mode,
|
|
||||||
thread_id=thread_id,
|
|
||||||
)
|
|
||||||
if filesystem_mode == FilesystemMode.CLOUD
|
|
||||||
else None,
|
|
||||||
skills_mw,
|
|
||||||
SurfSenseCheckpointedSubAgentMiddleware(
|
|
||||||
checkpointer=checkpointer,
|
|
||||||
backend=StateBackend,
|
|
||||||
subagents=subagent_specs,
|
|
||||||
),
|
|
||||||
selector_mw,
|
|
||||||
model_call_limit_mw,
|
|
||||||
tool_call_limit_mw,
|
|
||||||
context_edit_mw,
|
|
||||||
summarization_mw,
|
|
||||||
noop_mw,
|
|
||||||
retry_mw,
|
|
||||||
fallback_mw,
|
|
||||||
repair_mw,
|
|
||||||
permission_mw,
|
|
||||||
doom_loop_mw,
|
|
||||||
action_log_mw,
|
|
||||||
PatchToolCallsMiddleware(),
|
|
||||||
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
|
||||||
*plugin_middlewares,
|
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
|
||||||
]
|
|
||||||
return [m for m in deepagent_middleware if m is not None]
|
|
||||||
|
|
@ -2,6 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .factory import create_surfsense_deep_agent
|
from .factory import create_multi_agent_chat_deep_agent
|
||||||
|
|
||||||
__all__ = ["create_surfsense_deep_agent"]
|
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""Compiled agent graph caching for the multi-agent path."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||||
|
from app.agents.new_chat.agent_cache import (
|
||||||
|
flags_signature,
|
||||||
|
get_cache,
|
||||||
|
stable_hash,
|
||||||
|
system_prompt_hash,
|
||||||
|
tools_signature,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||||
|
|
||||||
|
|
||||||
|
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
|
||||||
|
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
|
||||||
|
rows = []
|
||||||
|
for agent_name in sorted(mcp_tools_by_agent.keys()):
|
||||||
|
perms = mcp_tools_by_agent[agent_name]
|
||||||
|
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
|
||||||
|
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
|
||||||
|
rows.append((agent_name, allow_names, ask_names))
|
||||||
|
return stable_hash(rows)
|
||||||
|
|
||||||
|
|
||||||
|
async def build_agent_with_cache(
|
||||||
|
*,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
final_system_prompt: str,
|
||||||
|
backend_resolver: Any,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
visibility: ChatVisibility,
|
||||||
|
anon_session_id: str | None,
|
||||||
|
available_connectors: list[str],
|
||||||
|
available_document_types: list[str],
|
||||||
|
mentioned_document_ids: list[int] | None,
|
||||||
|
max_input_tokens: int | None,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
checkpointer: Checkpointer,
|
||||||
|
subagent_dependencies: dict[str, Any],
|
||||||
|
mcp_tools_by_agent: dict[str, ToolsPermissions],
|
||||||
|
disabled_tools: list[str] | None,
|
||||||
|
config_id: str | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Compile the multi-agent graph, serving from cache when key components are stable."""
|
||||||
|
|
||||||
|
async def _build() -> Any:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
build_compiled_agent_graph_sync,
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
final_system_prompt=final_system_prompt,
|
||||||
|
backend_resolver=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
visibility=visibility,
|
||||||
|
anon_session_id=anon_session_id,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
max_input_tokens=max_input_tokens,
|
||||||
|
flags=flags,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
subagent_dependencies=subagent_dependencies,
|
||||||
|
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
|
||||||
|
return await _build()
|
||||||
|
|
||||||
|
# Every per-request value any middleware closes over at __init__ must be in
|
||||||
|
# the key, otherwise a hit will leak state across threads. Bump the schema
|
||||||
|
# version when the component list changes shape.
|
||||||
|
cache_key = stable_hash(
|
||||||
|
"multi-agent-v1",
|
||||||
|
config_id,
|
||||||
|
thread_id,
|
||||||
|
user_id,
|
||||||
|
search_space_id,
|
||||||
|
visibility,
|
||||||
|
filesystem_mode,
|
||||||
|
anon_session_id,
|
||||||
|
tools_signature(
|
||||||
|
tools,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
),
|
||||||
|
mcp_signature(mcp_tools_by_agent),
|
||||||
|
flags_signature(flags),
|
||||||
|
system_prompt_hash(final_system_prompt),
|
||||||
|
max_input_tokens,
|
||||||
|
sorted(disabled_tools) if disabled_tools else None,
|
||||||
|
)
|
||||||
|
return await get_cache().get_or_build(cache_key, builder=_build)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_agent_with_cache", "mcp_signature"]
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
|
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||||
from app.agents.new_chat.tools.registry import build_tools_async
|
from app.agents.new_chat.tools.registry import build_tools_async
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
|
||||||
from ..system_prompt import build_main_agent_system_prompt
|
from ..system_prompt import build_main_agent_system_prompt
|
||||||
from ..tools import (
|
from ..tools import (
|
||||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||||
)
|
)
|
||||||
|
from .agent_cache import build_agent_with_cache
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
async def create_surfsense_deep_agent(
|
async def create_multi_agent_chat_deep_agent(
|
||||||
llm: BaseChatModel,
|
llm: BaseChatModel,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
|
|
@ -62,6 +62,9 @@ async def create_surfsense_deep_agent(
|
||||||
):
|
):
|
||||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
|
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
|
||||||
_t_agent_total = time.perf_counter()
|
_t_agent_total = time.perf_counter()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
|
||||||
|
|
||||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||||
backend_resolver = build_backend_resolver(
|
backend_resolver = build_backend_resolver(
|
||||||
filesystem_selection,
|
filesystem_selection,
|
||||||
|
|
@ -85,7 +88,18 @@ async def create_surfsense_deep_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to discover available connectors/document types: %s", e)
|
logging.warning(
|
||||||
|
"Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
|
||||||
|
# nothing", which would silently advertise every connector specialist on a flaky
|
||||||
|
# discovery call. Empty list excludes connector-gated subagents while keeping builtins.
|
||||||
|
if available_connectors is None:
|
||||||
|
available_connectors = []
|
||||||
|
if available_document_types is None:
|
||||||
|
available_document_types = []
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
|
|
@ -115,7 +129,16 @@ async def create_surfsense_deep_agent(
|
||||||
}
|
}
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
|
try:
|
||||||
|
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
|
||||||
|
except Exception as e:
|
||||||
|
# Degrade to builtins-only rather than aborting the turn: a transient
|
||||||
|
# DB or MCP-server hiccup should not deny the user a response.
|
||||||
|
logging.warning(
|
||||||
|
"MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
mcp_tools_by_agent = {}
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
|
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
|
|
@ -195,9 +218,10 @@ async def create_surfsense_deep_agent(
|
||||||
|
|
||||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||||
|
|
||||||
|
config_id = agent_config.config_id if agent_config is not None else None
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent = await asyncio.to_thread(
|
agent = await build_agent_with_cache(
|
||||||
build_compiled_agent_graph_sync,
|
|
||||||
llm=llm,
|
llm=llm,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
final_system_prompt=final_system_prompt,
|
final_system_prompt=final_system_prompt,
|
||||||
|
|
@ -217,6 +241,7 @@ async def create_surfsense_deep_agent(
|
||||||
subagent_dependencies=dependencies,
|
subagent_dependencies=dependencies,
|
||||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
"""Multi-agent middleware stack assembly."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .stack import build_main_agent_deepagent_middleware
|
||||||
|
|
||||||
|
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""Audit row per tool call (reversibility metadata)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import ActionLogMiddleware
|
||||||
|
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_log_mw(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
thread_id: int | None,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> ActionLogMiddleware | None:
|
||||||
|
if not enabled(flags, "enable_action_log") or thread_id is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
||||||
|
return ActionLogMiddleware(
|
||||||
|
thread_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tool_definitions=tool_defs_by_name,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logging.warning(
|
||||||
|
"ActionLogMiddleware init failed; running without it.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Anonymous document hydration from Redis (cloud only)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_anonymous_doc_mw(
|
||||||
|
*,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
anon_session_id: str | None,
|
||||||
|
) -> AnonymousDocumentMiddleware | None:
|
||||||
|
if filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""Per-thread cooperative lock around the whole turn."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import BusyMutexMiddleware
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
|
||||||
|
return BusyMutexMiddleware() if enabled(flags, "enable_busy_mutex") else None
|
||||||
|
|
@ -69,9 +69,16 @@ def build_task_tool_with_parent_config(
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
|
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
|
||||||
message_text = (
|
messages = result["messages"]
|
||||||
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
|
if not messages:
|
||||||
)
|
msg = (
|
||||||
|
"CompiledSubAgent returned an empty 'messages' list. "
|
||||||
|
"Subagents must produce at least one message so the parent has "
|
||||||
|
"output to forward back to the user."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
last_text = getattr(messages[-1], "text", None) or ""
|
||||||
|
message_text = last_text.rstrip()
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
**state_update,
|
**state_update,
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
"""Spill + clear-tool-uses passes to keep payloads under budget."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||||
|
safe_exclude_tools,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import (
|
||||||
|
ClearToolUsesEdit,
|
||||||
|
SpillingContextEditingMiddleware,
|
||||||
|
SpillToBackendEdit,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_context_editing_mw(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
max_input_tokens: int | None,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
backend_resolver: Any,
|
||||||
|
) -> SpillingContextEditingMiddleware | None:
|
||||||
|
if not enabled(flags, "enable_context_editing") or not max_input_tokens:
|
||||||
|
return None
|
||||||
|
spill_edit = SpillToBackendEdit(
|
||||||
|
trigger=int(max_input_tokens * 0.55),
|
||||||
|
clear_at_least=int(max_input_tokens * 0.15),
|
||||||
|
keep=5,
|
||||||
|
exclude_tools=safe_exclude_tools(tools),
|
||||||
|
clear_tool_inputs=True,
|
||||||
|
)
|
||||||
|
clear_edit = ClearToolUsesEdit(
|
||||||
|
trigger=int(max_input_tokens * 0.55),
|
||||||
|
clear_at_least=int(max_input_tokens * 0.15),
|
||||||
|
keep=5,
|
||||||
|
exclude_tools=safe_exclude_tools(tools),
|
||||||
|
clear_tool_inputs=True,
|
||||||
|
placeholder="[cleared - older tool output trimmed for context]",
|
||||||
|
)
|
||||||
|
return SpillingContextEditingMiddleware(
|
||||||
|
edits=[spill_edit, clear_edit],
|
||||||
|
backend_resolver=backend_resolver,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
"""Drop duplicate HITL tool calls before execution."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
|
||||||
|
return DedupHITLToolCallsMiddleware(agent_tools=list(tools))
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""Stop N identical tool calls in a row via interrupt."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import DoomLoopMiddleware
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
|
||||||
|
return DoomLoopMiddleware(threshold=3) if enabled(flags, "enable_doom_loop") else None
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_kb_persistence_mw(
|
||||||
|
*,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
) -> KnowledgeBasePersistenceMiddleware | None:
|
||||||
|
if filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
return KnowledgeBasePersistenceMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""KB priority planner: <priority_documents> injection."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_knowledge_priority_mw(
|
||||||
|
*,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
search_space_id: int,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
available_connectors: list[str] | None,
|
||||||
|
available_document_types: list[str] | None,
|
||||||
|
mentioned_document_ids: list[int] | None,
|
||||||
|
) -> KnowledgePriorityMiddleware:
|
||||||
|
return KnowledgePriorityMiddleware(
|
||||||
|
llm=llm,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""<workspace_tree> injection (cloud only)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_knowledge_tree_mw(
|
||||||
|
*,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
) -> KnowledgeTreeMiddleware | None:
|
||||||
|
if filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
return KnowledgeTreeMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import NoopInjectionMiddleware
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
|
||||||
|
return NoopInjectionMiddleware() if enabled(flags, "enable_compaction_v2") else None
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""OTel spans on model and tool calls."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import OtelSpanMiddleware
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
|
||||||
|
return OtelSpanMiddleware() if enabled(flags, "enable_otel") else None
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
"""Tail-of-stack plugin slot driven by env allowlist."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.plugin_loader import (
|
||||||
|
PluginContext,
|
||||||
|
load_allowed_plugin_names_from_env,
|
||||||
|
load_plugin_middlewares,
|
||||||
|
)
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_plugin_middlewares(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
visibility: ChatVisibility,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
) -> list[Any]:
|
||||||
|
if not enabled(flags, "enable_plugin_loader"):
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
allowed_names = load_allowed_plugin_names_from_env()
|
||||||
|
if not allowed_names:
|
||||||
|
return []
|
||||||
|
return load_plugin_middlewares(
|
||||||
|
PluginContext.build(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
llm=llm,
|
||||||
|
),
|
||||||
|
allowed_plugin_names=allowed_names,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
logging.warning(
|
||||||
|
"Plugin loader failed; continuing without plugins.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
# deepagents-built-in tool names the repair pass treats as known.
|
||||||
|
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"write_todos",
|
||||||
|
"ls",
|
||||||
|
"read_file",
|
||||||
|
"write_file",
|
||||||
|
"edit_file",
|
||||||
|
"glob",
|
||||||
|
"grep",
|
||||||
|
"execute",
|
||||||
|
"task",
|
||||||
|
"mkdir",
|
||||||
|
"cd",
|
||||||
|
"pwd",
|
||||||
|
"move_file",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
|
"list_tree",
|
||||||
|
"execute_code",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_repair_mw(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
) -> ToolCallNameRepairMiddleware | None:
|
||||||
|
if not enabled(flags, "enable_tool_call_repair"):
|
||||||
|
return None
|
||||||
|
registered_names: set[str] = {t.name for t in tools}
|
||||||
|
registered_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
|
||||||
|
return ToolCallNameRepairMiddleware(
|
||||||
|
registered_tool_names=registered_names,
|
||||||
|
fuzzy_match_threshold=None,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""LLM-based tool subset selection (only when >30 tools)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_selector_mw(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
) -> LLMToolSelectorMiddleware | None:
|
||||||
|
if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return LLMToolSelectorMiddleware(
|
||||||
|
model="openai:gpt-4o-mini",
|
||||||
|
max_tools=12,
|
||||||
|
always_include=[
|
||||||
|
name
|
||||||
|
for name in (
|
||||||
|
"update_memory",
|
||||||
|
"get_connected_accounts",
|
||||||
|
"scrape_webpage",
|
||||||
|
)
|
||||||
|
if name in {t.name for t in tools}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""Skill discovery + injection."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from deepagents.middleware.skills import SkillsMiddleware
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import (
|
||||||
|
build_skills_backend_factory,
|
||||||
|
default_skills_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..shared.flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_skills_mw(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> SkillsMiddleware | None:
|
||||||
|
if not enabled(flags, "enable_skills"):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
skills_factory = build_skills_backend_factory(
|
||||||
|
search_space_id=search_space_id
|
||||||
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
return SkillsMiddleware(
|
||||||
|
backend=skills_factory,
|
||||||
|
sources=default_skills_sources(),
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Anthropic prompt caching annotations on system/tool/message blocks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware:
|
||||||
|
return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
"""Context-window summarization with SurfSense protected sections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deepagents.backends import StateBackend
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import create_surfsense_compaction_middleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_compaction_mw(llm: BaseChatModel) -> Any:
|
||||||
|
return create_surfsense_compaction_middleware(llm, StateBackend)
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""File-intent classifier that gates strict write contracts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import FileIntentMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware:
|
||||||
|
return FileIntentMiddleware(llm=llm)
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""SurfSense filesystem tools/middleware."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_filesystem_mw(
|
||||||
|
*,
|
||||||
|
backend_resolver: Any,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
) -> SurfSenseFilesystemMiddleware:
|
||||||
|
return SurfSenseFilesystemMiddleware(
|
||||||
|
backend=backend_resolver,
|
||||||
|
filesystem_mode=filesystem_mode,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
"""Single source of truth for the feature-flag predicate."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
|
||||||
|
|
||||||
|
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
|
||||||
|
"""``flags.<attr>`` is on AND the new-agent-stack kill switch is off."""
|
||||||
|
return getattr(flags, attr) and not flags.disable_new_agent_stack
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""User/team memory injection prepended to the conversation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
|
||||||
|
def build_memory_mw(
|
||||||
|
*,
|
||||||
|
user_id: str | None,
|
||||||
|
search_space_id: int,
|
||||||
|
visibility: ChatVisibility,
|
||||||
|
) -> MemoryInjectionMiddleware:
|
||||||
|
return MemoryInjectionMiddleware(
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
thread_visibility=visibility,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Repair dangling tool-call sequences before each agent turn."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware:
|
||||||
|
return PatchToolCallsMiddleware()
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""Permission rulesets fanned out to parent / general-purpose / subagent stacks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .context import PermissionContext, build_permission_context
|
||||||
|
from .middleware import build_full_permission_mw
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PermissionContext",
|
||||||
|
"build_full_permission_mw",
|
||||||
|
"build_permission_context",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""Derive shared permission context once; fan out to all three stack layers.
|
||||||
|
|
||||||
|
The context carries:
|
||||||
|
- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware.
|
||||||
|
- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents
|
||||||
|
``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents
|
||||||
|
bypass the main-agent permission middleware).
|
||||||
|
- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared
|
||||||
|
across the general-purpose and registry subagent stacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||||
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
|
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
from ..flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PermissionContext:
|
||||||
|
rulesets: list[Ruleset]
|
||||||
|
general_purpose_interrupt_on: dict[str, bool]
|
||||||
|
subagent_deny_mw: PermissionMiddleware | None
|
||||||
|
|
||||||
|
|
||||||
|
def build_permission_context(
|
||||||
|
*,
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
available_connectors: list[str] | None,
|
||||||
|
) -> PermissionContext:
|
||||||
|
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
|
permission_enabled = enabled(flags, "enable_permission")
|
||||||
|
|
||||||
|
rulesets: list[Ruleset] = []
|
||||||
|
if permission_enabled or is_desktop_fs:
|
||||||
|
rulesets.append(
|
||||||
|
Ruleset(
|
||||||
|
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||||
|
origin="surfsense_defaults",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_desktop_fs:
|
||||||
|
rulesets.append(
|
||||||
|
Ruleset(
|
||||||
|
rules=[
|
||||||
|
Rule(permission="rm", pattern="*", action="ask"),
|
||||||
|
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||||
|
Rule(permission="move_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="write_file", pattern="*", action="ask"),
|
||||||
|
],
|
||||||
|
origin="desktop_safety",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_names_in_use = {t.name for t in tools}
|
||||||
|
|
||||||
|
if permission_enabled:
|
||||||
|
available_set = set(available_connectors or [])
|
||||||
|
synthesized: list[Rule] = []
|
||||||
|
for tool_def in BUILTIN_TOOLS:
|
||||||
|
if tool_def.name not in tool_names_in_use:
|
||||||
|
continue
|
||||||
|
rc = tool_def.required_connector
|
||||||
|
if rc and rc not in available_set:
|
||||||
|
synthesized.append(
|
||||||
|
Rule(permission=tool_def.name, pattern="*", action="deny")
|
||||||
|
)
|
||||||
|
if synthesized:
|
||||||
|
rulesets.append(
|
||||||
|
Ruleset(rules=synthesized, origin="connector_synthesized")
|
||||||
|
)
|
||||||
|
|
||||||
|
general_purpose_interrupt_on: dict[str, bool] = {
|
||||||
|
rule.permission: True
|
||||||
|
for rs in rulesets
|
||||||
|
for rule in rs.rules
|
||||||
|
if rule.action == "ask" and rule.permission in tool_names_in_use
|
||||||
|
}
|
||||||
|
|
||||||
|
deny_rulesets = [
|
||||||
|
Ruleset(
|
||||||
|
rules=[r for r in rs.rules if r.action == "deny"],
|
||||||
|
origin=rs.origin,
|
||||||
|
)
|
||||||
|
for rs in rulesets
|
||||||
|
]
|
||||||
|
deny_rulesets = [rs for rs in deny_rulesets if rs.rules]
|
||||||
|
|
||||||
|
subagent_deny_mw: PermissionMiddleware | None = (
|
||||||
|
PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return PermissionContext(
|
||||||
|
rulesets=rulesets,
|
||||||
|
general_purpose_interrupt_on=general_purpose_interrupt_on,
|
||||||
|
subagent_deny_mw=subagent_deny_mw,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
"""Main-agent permission middleware (full ask/deny/allow rules)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||||
|
from app.agents.new_chat.permissions import Ruleset
|
||||||
|
|
||||||
|
|
||||||
|
def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None:
|
||||||
|
return PermissionMiddleware(rulesets=rulesets) if rulesets else None
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
"""Resilience middleware shared as the same instances across parent / general-purpose / registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .bundle import ResilienceBundle, build_resilience_bundle
|
||||||
|
|
||||||
|
__all__ = ["ResilienceBundle", "build_resilience_bundle"]
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""Construct each resilience middleware once; same instances flow into every consumer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import (
|
||||||
|
ModelCallLimitMiddleware,
|
||||||
|
ToolCallLimitMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .fallback import build_fallback_mw
|
||||||
|
from .model_call_limit import build_model_call_limit_mw
|
||||||
|
from .retry import build_retry_mw
|
||||||
|
from .tool_call_limit import build_tool_call_limit_mw
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ResilienceBundle:
|
||||||
|
retry: RetryAfterMiddleware | None
|
||||||
|
fallback: ScopedModelFallbackMiddleware | None
|
||||||
|
model_call_limit: ModelCallLimitMiddleware | None
|
||||||
|
tool_call_limit: ToolCallLimitMiddleware | None
|
||||||
|
|
||||||
|
def as_list(self) -> list[Any]:
|
||||||
|
return [
|
||||||
|
m
|
||||||
|
for m in (
|
||||||
|
self.retry,
|
||||||
|
self.fallback,
|
||||||
|
self.model_call_limit,
|
||||||
|
self.tool_call_limit,
|
||||||
|
)
|
||||||
|
if m is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle:
|
||||||
|
return ResilienceBundle(
|
||||||
|
retry=build_retry_mw(flags),
|
||||||
|
fallback=build_fallback_mw(flags),
|
||||||
|
model_call_limit=build_model_call_limit_mw(flags),
|
||||||
|
tool_call_limit=build_tool_call_limit_mw(flags),
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Switch to a fallback model on provider/network errors only."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_fallback_mw(
|
||||||
|
flags: AgentFeatureFlags,
|
||||||
|
) -> ScopedModelFallbackMiddleware | None:
|
||||||
|
if not enabled(flags, "enable_model_fallback"):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return ScopedModelFallbackMiddleware(
|
||||||
|
"openai:gpt-4o-mini",
|
||||||
|
"anthropic:claude-3-5-haiku-20241022",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Cap model calls per thread / per run to prevent runaway cost."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain.agents.middleware import ModelCallLimitMiddleware
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
|
||||||
|
from ..flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_call_limit_mw(
    flags: AgentFeatureFlags,
) -> ModelCallLimitMiddleware | None:
    """Build the per-thread / per-run model-call cap, or ``None`` when off."""
    if enabled(flags, "enable_model_call_limit"):
        # "end" ends the run gracefully once either cap is reached.
        return ModelCallLimitMiddleware(
            thread_limit=120,
            run_limit=80,
            exit_behavior="end",
        )
    return None
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Retry on transient model errors (e.g. Retry-After-bearing 429s)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||||
|
|
||||||
|
from ..flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None:
    """Build the Retry-After-aware retry middleware, or ``None`` when off."""
    if not enabled(flags, "enable_retry_after"):
        return None
    return RetryAfterMiddleware(max_retries=3)
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Cap tool calls per thread / per run to bound infinite-loop blast radius."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain.agents.middleware import ToolCallLimitMiddleware
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
|
||||||
|
from ..flags import enabled
|
||||||
|
|
||||||
|
|
||||||
|
def build_tool_call_limit_mw(
    flags: AgentFeatureFlags,
) -> ToolCallLimitMiddleware | None:
    """Build the per-thread / per-run tool-call cap, or ``None`` when off."""
    if enabled(flags, "enable_tool_call_limit"):
        # "continue" keeps the run alive after the cap is hit instead of
        # aborting it, bounding infinite-loop blast radius without ending
        # the user's turn.
        return ToolCallLimitMiddleware(
            thread_limit=300,
            run_limit=80,
            exit_behavior="continue",
        )
    return None
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Todo-list middleware (each consumer needs its own instance)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def build_todos_mw() -> TodoListMiddleware:
    """Return a fresh ``TodoListMiddleware`` instance.

    Per the module contract, each consumer must get its own instance —
    never cache or share the return value between agents.
    """
    return TodoListMiddleware()
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
"""Main-agent middleware list assembly: one line per slot."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deepagents import SubAgent
|
||||||
|
from deepagents.backends import StateBackend
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.subagents import (
|
||||||
|
build_subagents,
|
||||||
|
get_subagents_to_exclude,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import (
|
||||||
|
build_subagent as build_general_purpose_subagent,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
from .main_agent.action_log import build_action_log_mw
|
||||||
|
from .main_agent.anonymous_doc import build_anonymous_doc_mw
|
||||||
|
from .main_agent.busy_mutex import build_busy_mutex_mw
|
||||||
|
from .main_agent.checkpointed_subagent_middleware import (
|
||||||
|
SurfSenseCheckpointedSubAgentMiddleware,
|
||||||
|
)
|
||||||
|
from .main_agent.context_editing import build_context_editing_mw
|
||||||
|
from .main_agent.dedup_hitl import build_dedup_hitl_mw
|
||||||
|
from .main_agent.doom_loop import build_doom_loop_mw
|
||||||
|
from .main_agent.kb_persistence import build_kb_persistence_mw
|
||||||
|
from .main_agent.knowledge_priority import build_knowledge_priority_mw
|
||||||
|
from .main_agent.knowledge_tree import build_knowledge_tree_mw
|
||||||
|
from .main_agent.noop_injection import build_noop_injection_mw
|
||||||
|
from .main_agent.otel import build_otel_mw
|
||||||
|
from .main_agent.plugins import build_plugin_middlewares
|
||||||
|
from .main_agent.repair import build_repair_mw
|
||||||
|
from .main_agent.selector import build_selector_mw
|
||||||
|
from .main_agent.skills import build_skills_mw
|
||||||
|
from .shared.anthropic_cache import build_anthropic_cache_mw
|
||||||
|
from .shared.compaction import build_compaction_mw
|
||||||
|
from .shared.file_intent import build_file_intent_mw
|
||||||
|
from .shared.filesystem import build_filesystem_mw
|
||||||
|
from .shared.memory import build_memory_mw
|
||||||
|
from .shared.patch_tool_calls import build_patch_tool_calls_mw
|
||||||
|
from .shared.permissions import (
|
||||||
|
build_full_permission_mw,
|
||||||
|
build_permission_context,
|
||||||
|
)
|
||||||
|
from .shared.resilience import build_resilience_bundle
|
||||||
|
from .shared.todos import build_todos_mw
|
||||||
|
from .subagent.extras import build_subagent_extras
|
||||||
|
|
||||||
|
|
||||||
|
def build_main_agent_deepagent_middleware(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    subagent_dependencies: dict[str, Any],
    checkpointer: Checkpointer,
    mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
    disabled_tools: list[str] | None = None,
) -> list[Any]:
    """Ordered middleware for ``create_agent`` (None entries already stripped).

    Assembles, in order: permission context, the resilience bundle, the
    general-purpose subagent, the registry subagents, and finally the full
    middleware stack. Builders that are flag-gated return ``None`` and are
    filtered out at the end, so the stack literal below reads as one line
    per slot. NOTE(review): the stack order is load-bearing — do not reorder
    entries without checking each middleware's assumptions.
    """
    permissions = build_permission_context(
        flags=flags,
        filesystem_mode=filesystem_mode,
        tools=tools,
        available_connectors=available_connectors,
    )
    resilience = build_resilience_bundle(flags)

    # Single instance threaded into both the main-agent stack and the general-purpose subagent.
    memory_mw = build_memory_mw(
        user_id=user_id,
        search_space_id=search_space_id,
        visibility=visibility,
    )

    # The general-purpose subagent is always available, even when the
    # registry build below fails.
    general_purpose_subagent = build_general_purpose_subagent(
        llm=llm,
        tools=tools,
        backend_resolver=backend_resolver,
        filesystem_mode=filesystem_mode,
        search_space_id=search_space_id,
        user_id=user_id,
        thread_id=thread_id,
        permissions=permissions,
        resilience=resilience,
        memory_mw=memory_mw,
    )

    subagents_registry: list[SubAgent] = []
    try:
        subagent_extras = build_subagent_extras(
            permissions=permissions,
            resilience=resilience,
        )
        subagents_registry = build_subagents(
            dependencies=subagent_dependencies,
            model=llm,
            extra_middleware=subagent_extras,
            mcp_tools_by_agent=mcp_tools_by_agent or {},
            exclude=get_subagents_to_exclude(available_connectors),
            disabled_tools=disabled_tools,
        )
        logging.debug(
            "Subagents registry: %s",
            [s["name"] for s in subagents_registry],
        )
    except Exception:
        # Degrade to general-purpose-only rather than aborting the turn:
        # one bad subagent dep should not deny the user a response.
        logging.exception(
            "Subagents registry build failed; falling back to general-purpose only"
        )
        subagents_registry = []

    # General-purpose first so it is always present in the subagent list.
    subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry]

    stack: list[Any] = [
        build_busy_mutex_mw(flags),
        build_otel_mw(flags),
        build_todos_mw(),
        memory_mw,
        build_anonymous_doc_mw(
            filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
        ),
        build_knowledge_tree_mw(
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            llm=llm,
        ),
        build_knowledge_priority_mw(
            llm=llm,
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
        ),
        build_file_intent_mw(llm),
        build_filesystem_mw(
            backend_resolver=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_kb_persistence_mw(
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_skills_mw(
            flags=flags,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
        ),
        # Provides the `task` tool with checkpointed subagent runs.
        SurfSenseCheckpointedSubAgentMiddleware(
            checkpointer=checkpointer,
            backend=StateBackend,
            subagents=subagents,
        ),
        build_selector_mw(flags=flags, tools=tools),
        resilience.model_call_limit,
        resilience.tool_call_limit,
        build_context_editing_mw(
            flags=flags,
            max_input_tokens=max_input_tokens,
            tools=tools,
            backend_resolver=backend_resolver,
        ),
        build_compaction_mw(llm),
        build_noop_injection_mw(flags),
        resilience.retry,
        resilience.fallback,
        build_repair_mw(flags=flags, tools=tools),
        build_full_permission_mw(permissions.rulesets),
        build_doom_loop_mw(flags),
        build_action_log_mw(
            flags=flags,
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
        ),
        build_patch_tool_calls_mw(),
        build_dedup_hitl_mw(tools),
        *build_plugin_middlewares(
            flags=flags,
            search_space_id=search_space_id,
            user_id=user_id,
            visibility=visibility,
            llm=llm,
        ),
        # Prompt caching last so it sees the fully assembled request.
        build_anthropic_cache_mw(),
    ]
    # Flag-gated builders return None; strip them before handing to create_agent.
    return [m for m in stack if m is not None]
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Extra middleware threaded into every registry subagent's stack.
|
||||||
|
|
||||||
|
Registry subagents are scoped to one domain (deliverables, research, memory,
|
||||||
|
connectors, MCP) and never read or write the SurfSense filesystem — that
|
||||||
|
capability belongs to the main agent and is delegated to the general-purpose
|
||||||
|
subagent as an escape hatch. Keeping FS off the registry stacks avoids
|
||||||
|
polluting their tool surface with FS tools they never act on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..shared.permissions import PermissionContext
|
||||||
|
from ..shared.resilience import ResilienceBundle
|
||||||
|
from ..shared.todos import build_todos_mw
|
||||||
|
|
||||||
|
|
||||||
|
def build_subagent_extras(
    *,
    permissions: PermissionContext,
    resilience: ResilienceBundle,
) -> list[Any]:
    """Extra middleware threaded into every registry subagent's stack.

    Always includes a fresh todo-list instance; appends the deny
    middleware when permissions define one, then whatever resilience
    middlewares are enabled.
    """
    deny_mw = permissions.subagent_deny_mw
    stack: list[Any] = [build_todos_mw()]
    if deny_mw is not None:
        stack.append(deny_mw)
    stack.extend(resilience.as_list())
    return stack
|
||||||
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""General-purpose subagent for the multi-agent main agent."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from deepagents import SubAgent
|
||||||
|
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||||
|
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||||
|
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import (
|
||||||
|
build_anthropic_cache_mw,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.compaction import (
|
||||||
|
build_compaction_mw,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.file_intent import (
|
||||||
|
build_file_intent_mw,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.filesystem import (
|
||||||
|
build_filesystem_mw,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import (
|
||||||
|
build_patch_tool_calls_mw,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||||
|
PermissionContext,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.resilience import (
|
||||||
|
ResilienceBundle,
|
||||||
|
)
|
||||||
|
from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||||
|
|
||||||
|
NAME = "general-purpose"
|
||||||
|
|
||||||
|
|
||||||
|
def build_subagent(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    permissions: PermissionContext,
    resilience: ResilienceBundle,
    memory_mw: MemoryInjectionMiddleware,
) -> SubAgent:
    """Deny + resilience inserts encapsulated here so the orchestrator never mutates the list."""
    # Base stack; later inserts position themselves relative to the
    # patch-tool-calls and anthropic-cache entries below.
    middleware: list[Any] = [
        build_todos_mw(),
        memory_mw,
        build_file_intent_mw(llm),
        build_filesystem_mw(
            backend_resolver=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_compaction_mw(llm),
        build_patch_tool_calls_mw(),
        build_anthropic_cache_mw(),
    ]

    if permissions.subagent_deny_mw is not None:
        # Insert the deny middleware just before PatchToolCallsMiddleware
        # (falling back to the end of the list if it is absent).
        patch_idx = next(
            (
                i
                for i, m in enumerate(middleware)
                if isinstance(m, PatchToolCallsMiddleware)
            ),
            len(middleware),
        )
        middleware.insert(patch_idx, permissions.subagent_deny_mw)

    resilience_mws = resilience.as_list()
    if resilience_mws:
        # Insert resilience middlewares before the Anthropic cache entry,
        # preserving their bundle order via the enumerate offset.
        cache_idx = next(
            (
                i
                for i, m in enumerate(middleware)
                if isinstance(m, AnthropicPromptCachingMiddleware)
            ),
            len(middleware),
        )
        for offset, mw in enumerate(resilience_mws):
            middleware.insert(cache_idx + offset, mw)

    # Start from the stock general-purpose spec and override model/tools/
    # middleware with this deployment's instances.
    spec: dict[str, Any] = {
        **GENERAL_PURPOSE_SUBAGENT,
        "model": llm,
        "tools": tools,
        "middleware": middleware,
    }
    if permissions.general_purpose_interrupt_on:
        spec["interrupt_on"] = permissions.general_purpose_interrupt_on
    return cast(SubAgent, spec)
|
||||||
|
|
@ -31,7 +31,6 @@ from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import (
|
from langchain.agents.middleware import (
|
||||||
LLMToolSelectorMiddleware,
|
LLMToolSelectorMiddleware,
|
||||||
ModelCallLimitMiddleware,
|
ModelCallLimitMiddleware,
|
||||||
ModelFallbackMiddleware,
|
|
||||||
TodoListMiddleware,
|
TodoListMiddleware,
|
||||||
ToolCallLimitMiddleware,
|
ToolCallLimitMiddleware,
|
||||||
)
|
)
|
||||||
|
|
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
|
||||||
create_surfsense_compaction_middleware,
|
create_surfsense_compaction_middleware,
|
||||||
default_skills_sources,
|
default_skills_sources,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||||
from app.agents.new_chat.plugin_loader import (
|
from app.agents.new_chat.plugin_loader import (
|
||||||
PluginContext,
|
PluginContext,
|
||||||
|
|
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
|
||||||
# Fallback chain — primary is the agent's own model; we add cheap
|
# Fallback chain — primary is the agent's own model; we add cheap
|
||||||
# alternatives. Off by default; only the first call site that
|
# alternatives. Off by default; only the first call site that
|
||||||
# configures the chain via env should enable it.
|
# configures the chain via env should enable it.
|
||||||
fallback_mw: ModelFallbackMiddleware | None = None
|
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||||
try:
|
try:
|
||||||
fallback_mw = ModelFallbackMiddleware(
|
fallback_mw = ScopedModelFallbackMiddleware(
|
||||||
"openai:gpt-4o-mini",
|
"openai:gpt-4o-mini",
|
||||||
"anthropic:claude-3-5-haiku-20241022",
|
"anthropic:claude-3-5-haiku-20241022",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||||
fallback_mw = None
|
fallback_mw = None
|
||||||
model_call_limit_mw = (
|
model_call_limit_mw = (
|
||||||
ModelCallLimitMiddleware(
|
ModelCallLimitMiddleware(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
# Matched by class name across the MRO so we don't have to import every
|
||||||
|
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||||
|
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"RateLimitError",
|
||||||
|
"APIStatusError",
|
||||||
|
"InternalServerError",
|
||||||
|
"ServiceUnavailableError",
|
||||||
|
"BadGatewayError",
|
||||||
|
"GatewayTimeoutError",
|
||||||
|
"APIConnectionError",
|
||||||
|
"APITimeoutError",
|
||||||
|
"ConnectError",
|
||||||
|
"ConnectTimeout",
|
||||||
|
"ReadTimeout",
|
||||||
|
"RemoteProtocolError",
|
||||||
|
"TimeoutError",
|
||||||
|
"TimeoutException",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||||
|
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||||
|
|
||||||
|
|
||||||
|
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
    """Re-raise non-provider exceptions instead of walking the fallback chain.

    Gates every fallback attempt on ``_is_fallback_eligible``: only
    provider/network-style errors trigger a switch to the next model in
    ``self.models``; anything else (presumably a programming bug — see the
    module docstring) propagates immediately so it is not masked by retries
    against cheaper models.
    """

    def wrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
    ) -> ModelResponse[Any] | AIMessage:
        """Sync path: try the primary model, then each fallback in order."""
        last_exception: Exception
        try:
            # Primary attempt with the request unchanged.
            return handler(request)
        except Exception as e:
            if not _is_fallback_eligible(e):
                raise  # non-provider failure: surface it, no fallback
            last_exception = e

        # Walk the configured fallback chain; each model gets one attempt.
        for fallback_model in self.models:
            try:
                return handler(request.override(model=fallback_model))
            except Exception as e:
                if not _is_fallback_eligible(e):
                    raise
                last_exception = e
                continue

        # Every model failed with an eligible error; re-raise the last one.
        raise last_exception

    async def awrap_model_call(  # type: ignore[override]
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
    ) -> ModelResponse[Any] | AIMessage:
        """Async mirror of ``wrap_model_call`` with the same eligibility gating."""
        last_exception: Exception
        try:
            return await handler(request)
        except Exception as e:
            if not _is_fallback_eligible(e):
                raise
            last_exception = e

        for fallback_model in self.models:
            try:
                return await handler(request.override(model=fallback_model))
            except Exception as e:
                if not _is_fallback_eligible(e):
                    raise
                last_exception = e
                continue

        raise last_exception
|
||||||
|
|
@ -28,9 +28,7 @@ from langchain_core.messages import HumanMessage
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.agents.multi_agent_chat import (
|
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||||
create_surfsense_deep_agent as create_registry_deep_agent,
|
|
||||||
)
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
|
|
@ -577,6 +575,43 @@ async def _preflight_llm(llm: Any) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_main_agent_for_thread(
|
||||||
|
agent_factory: Any,
|
||||||
|
*,
|
||||||
|
llm: Any,
|
||||||
|
search_space_id: int,
|
||||||
|
db_session: Any,
|
||||||
|
connector_service: ConnectorService,
|
||||||
|
checkpointer: Any,
|
||||||
|
user_id: str | None,
|
||||||
|
thread_id: int | None,
|
||||||
|
agent_config: AgentConfig | None,
|
||||||
|
firecrawl_api_key: str | None,
|
||||||
|
thread_visibility: ChatVisibility | None,
|
||||||
|
filesystem_selection: FilesystemSelection | None,
|
||||||
|
disabled_tools: list[str] | None = None,
|
||||||
|
mentioned_document_ids: list[int] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Single (re)build path so the agent factory cannot drift across
|
||||||
|
initial build, preflight repin, and mid-stream 429 recovery for one
|
||||||
|
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
|
||||||
|
return await agent_factory(
|
||||||
|
llm=llm,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
db_session=db_session,
|
||||||
|
connector_service=connector_service,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
user_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
agent_config=agent_config,
|
||||||
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
|
thread_visibility=thread_visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
||||||
"""Wait for a discarded speculative agent build to release shared state.
|
"""Wait for a discarded speculative agent build to release shared state.
|
||||||
|
|
||||||
|
|
@ -2767,7 +2802,7 @@ async def stream_new_chat(
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent_factory = (
|
agent_factory = (
|
||||||
create_registry_deep_agent
|
create_multi_agent_chat_deep_agent
|
||||||
if use_multi_agent
|
if use_multi_agent
|
||||||
else create_surfsense_deep_agent
|
else create_surfsense_deep_agent
|
||||||
)
|
)
|
||||||
|
|
@ -2776,7 +2811,8 @@ async def stream_new_chat(
|
||||||
# if preflight reports 429 we will discard this future and rebuild
|
# if preflight reports 429 we will discard this future and rebuild
|
||||||
# against the freshly pinned config below.
|
# against the freshly pinned config below.
|
||||||
agent_build_task = asyncio.create_task(
|
agent_build_task = asyncio.create_task(
|
||||||
agent_factory(
|
_build_main_agent_for_thread(
|
||||||
|
agent_factory,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
|
|
@ -2787,9 +2823,9 @@ async def stream_new_chat(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
filesystem_selection=filesystem_selection,
|
|
||||||
),
|
),
|
||||||
name="agent_build:stream_new_chat",
|
name="agent_build:stream_new_chat",
|
||||||
)
|
)
|
||||||
|
|
@ -3466,7 +3502,8 @@ async def stream_new_chat(
|
||||||
title_task = None
|
title_task = None
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await _build_main_agent_for_thread(
|
||||||
|
agent_factory,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
|
|
@ -3477,9 +3514,9 @@ async def stream_new_chat(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
|
filesystem_selection=filesystem_selection,
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
filesystem_selection=filesystem_selection,
|
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||||
|
|
@ -4130,12 +4167,13 @@ async def stream_resume_chat(
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent_factory = (
|
agent_factory = (
|
||||||
create_registry_deep_agent
|
create_multi_agent_chat_deep_agent
|
||||||
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
||||||
else create_surfsense_deep_agent
|
else create_surfsense_deep_agent
|
||||||
)
|
)
|
||||||
agent_build_task = asyncio.create_task(
|
agent_build_task = asyncio.create_task(
|
||||||
agent_factory(
|
_build_main_agent_for_thread(
|
||||||
|
agent_factory,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
|
|
@ -4224,7 +4262,8 @@ async def stream_resume_chat(
|
||||||
"fallback_config_id": llm_config_id,
|
"fallback_config_id": llm_config_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
agent = await agent_factory(
|
agent = await _build_main_agent_for_thread(
|
||||||
|
agent_factory,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
|
|
@ -4409,7 +4448,8 @@ async def stream_resume_chat(
|
||||||
raise stream_exc
|
raise stream_exc
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await _build_main_agent_for_thread(
|
||||||
|
agent_factory,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
|
|
@ -4421,6 +4461,7 @@ async def stream_resume_chat(
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
filesystem_selection=filesystem_selection,
|
filesystem_selection=filesystem_selection,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,208 @@
|
||||||
|
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.types import Command, interrupt
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||||
|
build_task_tool_with_parent_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _SubagentState(TypedDict, total=False):
    # Conversation transcript accumulated by the one-node test graph.
    messages: list
    # repr() of the resume value delivered to the pending interrupt —
    # asserted by the resume-bridge test to prove verbatim delivery.
    decision_text: str
|
||||||
|
|
||||||
|
|
||||||
|
def _build_single_interrupt_subagent():
    """Compile a one-node LangGraph that pauses on a single HITL interrupt.

    The node raises ``interrupt(...)`` with one action request; when the
    graph is resumed, the node records ``repr`` of the resume value in
    ``decision_text`` so tests can assert the decision arrived verbatim.
    """

    def approve_node(state):
        # Local import keeps the fixture self-contained inside the node.
        from langchain_core.messages import AIMessage

        # Pauses the graph here until a resume value is supplied.
        decision = interrupt(
            {
                "action_requests": [
                    {
                        "name": "do_thing",
                        "args": {"x": 1},
                        "description": "test action",
                    }
                ],
                "review_configs": [{}],
            }
        )
        return {
            "messages": [AIMessage(content="done")],
            "decision_text": repr(decision),
        }

    graph = StateGraph(_SubagentState)
    graph.add_node("approve", approve_node)
    graph.add_edge(START, "approve")
    graph.add_edge("approve", END)
    # In-memory checkpointer so the interrupt/resume cycle persists state.
    return graph.compile(checkpointer=InMemorySaver())
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runtime(config: dict) -> ToolRuntime:
|
||||||
|
return ToolRuntime(
|
||||||
|
state={"messages": [HumanMessage(content="seed")]},
|
||||||
|
context=None,
|
||||||
|
config=config,
|
||||||
|
stream_writer=None,
|
||||||
|
tool_call_id="parent-tcid-1",
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||||
|
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
|
||||||
|
subagent = _build_single_interrupt_subagent()
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "approver",
|
||||||
|
"description": "approves things",
|
||||||
|
"runnable": subagent,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_config: dict = {
|
||||||
|
"configurable": {"thread_id": "shared-thread"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||||
|
snap = await subagent.aget_state(parent_config)
|
||||||
|
assert snap.tasks and snap.tasks[0].interrupts, (
|
||||||
|
"fixture broken: subagent should be paused on its interrupt"
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||||
|
"decisions": ["APPROVED"]
|
||||||
|
}
|
||||||
|
runtime = _make_runtime(parent_config)
|
||||||
|
|
||||||
|
result = await task_tool.coroutine(
|
||||||
|
description="please approve",
|
||||||
|
subagent_type="approver",
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Command)
|
||||||
|
update = result.update
|
||||||
|
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
|
||||||
|
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||||
|
|
||||||
|
final = await subagent.aget_state(parent_config)
|
||||||
|
assert not final.tasks or all(not t.interrupts for t in final.tasks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
||||||
|
"""Bridge must fail loud rather than silently replay the user's interrupt."""
|
||||||
|
subagent = _build_single_interrupt_subagent()
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "approver",
|
||||||
|
"description": "approves things",
|
||||||
|
"runnable": subagent,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_config: dict = {
|
||||||
|
"configurable": {"thread_id": "guard-thread"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||||
|
snap = await subagent.aget_state(parent_config)
|
||||||
|
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
|
||||||
|
|
||||||
|
runtime = _make_runtime(parent_config)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="resume bridge is broken"):
|
||||||
|
await task_tool.coroutine(
|
||||||
|
description="please approve",
|
||||||
|
subagent_type="approver",
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bundle_subagent():
|
||||||
|
def bundle_node(state):
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
decision = interrupt(
|
||||||
|
{
|
||||||
|
"action_requests": [
|
||||||
|
{"name": "create_a", "args": {}, "description": ""},
|
||||||
|
{"name": "create_b", "args": {}, "description": ""},
|
||||||
|
{"name": "create_c", "args": {}, "description": ""},
|
||||||
|
],
|
||||||
|
"review_configs": [{}, {}, {}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content="bundle-done")],
|
||||||
|
"decision_text": repr(decision),
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = StateGraph(_SubagentState)
|
||||||
|
graph.add_node("bundle", bundle_node)
|
||||||
|
graph.add_edge(START, "bundle")
|
||||||
|
graph.add_edge("bundle", END)
|
||||||
|
return graph.compile(checkpointer=InMemorySaver())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bundle_three_mixed_decisions_arrive_in_order():
|
||||||
|
"""Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
|
||||||
|
subagent = _build_bundle_subagent()
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "bundler",
|
||||||
|
"description": "creates a bundle",
|
||||||
|
"runnable": subagent,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_config: dict = {
|
||||||
|
"configurable": {"thread_id": "bundle-thread"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||||
|
|
||||||
|
decisions_payload = {
|
||||||
|
"decisions": [
|
||||||
|
{"type": "approve", "args": {}},
|
||||||
|
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
|
||||||
|
{"type": "reject", "args": {"message": "no thanks"}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
|
||||||
|
runtime = _make_runtime(parent_config)
|
||||||
|
|
||||||
|
result = await task_tool.coroutine(
|
||||||
|
description="run bundle",
|
||||||
|
subagent_type="bundler",
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Command)
|
||||||
|
received = ast.literal_eval(result.update["decision_text"])
|
||||||
|
assert received == decisions_payload
|
||||||
|
assert received["decisions"][0]["type"] == "approve"
|
||||||
|
assert received["decisions"][1]["type"] == "edit"
|
||||||
|
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
|
||||||
|
assert received["decisions"][2]["type"] == "reject"
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
|
||||||
|
|
||||||
|
The bridge currently relies on at-most-one pending interrupt per snapshot
|
||||||
|
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
|
||||||
|
needs an id-aware lookup; these tests will need to be revisited at that point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
|
||||||
|
get_first_pending_subagent_interrupt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetFirstPendingSubagentInterrupt:
|
||||||
|
def test_returns_first_when_multiple_top_level_interrupts_pending(self):
|
||||||
|
first = SimpleNamespace(id="i-1", value={"decision": "approve"})
|
||||||
|
second = SimpleNamespace(id="i-2", value={"decision": "reject"})
|
||||||
|
state = SimpleNamespace(interrupts=(first, second), tasks=())
|
||||||
|
|
||||||
|
assert get_first_pending_subagent_interrupt(state) == (
|
||||||
|
"i-1",
|
||||||
|
{"decision": "approve"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_returns_first_when_multiple_subtask_interrupts_pending(self):
|
||||||
|
first = SimpleNamespace(id="i-A", value="approve")
|
||||||
|
second = SimpleNamespace(id="i-B", value="reject")
|
||||||
|
sub_task = SimpleNamespace(interrupts=(first, second))
|
||||||
|
state = SimpleNamespace(interrupts=(), tasks=(sub_task,))
|
||||||
|
|
||||||
|
assert get_first_pending_subagent_interrupt(state) == ("i-A", "approve")
|
||||||
|
|
||||||
|
def test_returns_none_when_no_interrupts(self):
|
||||||
|
state = SimpleNamespace(interrupts=(), tasks=())
|
||||||
|
|
||||||
|
assert get_first_pending_subagent_interrupt(state) == (None, None)
|
||||||
|
|
||||||
|
def test_returns_none_when_state_is_none(self):
|
||||||
|
assert get_first_pending_subagent_interrupt(None) == (None, None)
|
||||||
|
|
||||||
|
def test_skips_interrupts_with_none_value(self):
|
||||||
|
empty = SimpleNamespace(id="i-empty", value=None)
|
||||||
|
real = SimpleNamespace(id="i-real", value="approve")
|
||||||
|
state = SimpleNamespace(interrupts=(empty, real), tasks=())
|
||||||
|
|
||||||
|
assert get_first_pending_subagent_interrupt(state) == ("i-real", "approve")
|
||||||
|
|
||||||
|
def test_normalizes_non_string_id_to_none(self):
|
||||||
|
interrupt = SimpleNamespace(id=12345, value="approve")
|
||||||
|
state = SimpleNamespace(interrupts=(interrupt,), tasks=())
|
||||||
|
|
||||||
|
assert get_first_pending_subagent_interrupt(state) == (None, "approve")
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
"""Resume side-channel must be read exactly once per turn."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||||
|
consume_surfsense_resume,
|
||||||
|
has_surfsense_resume,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_with_config(config: dict) -> ToolRuntime:
|
||||||
|
return ToolRuntime(
|
||||||
|
state=None,
|
||||||
|
context=None,
|
||||||
|
config=config,
|
||||||
|
stream_writer=None,
|
||||||
|
tool_call_id="tcid-test",
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsumeSurfsenseResume:
|
||||||
|
def test_pops_value_on_first_call(self):
|
||||||
|
runtime = _runtime_with_config(
|
||||||
|
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||||
|
|
||||||
|
def test_second_call_returns_none(self):
|
||||||
|
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
||||||
|
runtime = _runtime_with_config({"configurable": configurable})
|
||||||
|
|
||||||
|
consume_surfsense_resume(runtime)
|
||||||
|
|
||||||
|
assert consume_surfsense_resume(runtime) is None
|
||||||
|
assert "surfsense_resume_value" not in configurable
|
||||||
|
|
||||||
|
def test_returns_none_when_no_payload_queued(self):
|
||||||
|
runtime = _runtime_with_config({"configurable": {}})
|
||||||
|
|
||||||
|
assert consume_surfsense_resume(runtime) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_configurable_missing(self):
|
||||||
|
runtime = _runtime_with_config({})
|
||||||
|
|
||||||
|
assert consume_surfsense_resume(runtime) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHasSurfsenseResume:
|
||||||
|
def test_true_when_payload_queued(self):
|
||||||
|
runtime = _runtime_with_config(
|
||||||
|
{"configurable": {"surfsense_resume_value": "approve"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert has_surfsense_resume(runtime) is True
|
||||||
|
|
||||||
|
def test_does_not_consume_payload(self):
|
||||||
|
configurable = {"surfsense_resume_value": "approve"}
|
||||||
|
runtime = _runtime_with_config({"configurable": configurable})
|
||||||
|
|
||||||
|
has_surfsense_resume(runtime)
|
||||||
|
|
||||||
|
assert configurable == {"surfsense_resume_value": "approve"}
|
||||||
|
|
||||||
|
def test_false_when_payload_absent(self):
|
||||||
|
runtime = _runtime_with_config({"configurable": {}})
|
||||||
|
|
||||||
|
assert has_surfsense_resume(runtime) is False
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.language_models.fake_chat_models import (
|
||||||
|
FakeMessagesListChatModel,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||||
|
pack_subagent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitError(Exception):
|
||||||
|
"""Name matches the scoped-fallback eligibility allowlist."""
|
||||||
|
|
||||||
|
|
||||||
|
class _AlwaysFailingChatModel(BaseChatModel):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "always-failing-test-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
msg = "primary llm exploded"
|
||||||
|
raise RateLimitError(msg)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
msg = "primary llm exploded"
|
||||||
|
raise RateLimitError(msg)
|
||||||
|
|
||||||
|
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
||||||
|
msg = "primary llm exploded"
|
||||||
|
raise RateLimitError(msg)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> AsyncIterator[ChatGeneration]:
|
||||||
|
msg = "primary llm exploded"
|
||||||
|
raise RateLimitError(msg)
|
||||||
|
yield # pragma: no cover - unreachable, satisfies async generator typing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subagent_recovers_when_primary_llm_fails():
|
||||||
|
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
|
||||||
|
primary = _AlwaysFailingChatModel()
|
||||||
|
fallback = FakeMessagesListChatModel(
|
||||||
|
responses=[AIMessage(content="recovered via fallback")]
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = pack_subagent(
|
||||||
|
name="resilience_test",
|
||||||
|
description="test subagent",
|
||||||
|
system_prompt="be helpful",
|
||||||
|
tools=[],
|
||||||
|
model=primary,
|
||||||
|
extra_middleware=[ModelFallbackMiddleware(fallback)],
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = create_agent(
|
||||||
|
model=spec["model"],
|
||||||
|
tools=spec["tools"],
|
||||||
|
middleware=spec["middleware"],
|
||||||
|
system_prompt=spec["system_prompt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
|
||||||
|
|
||||||
|
final = result["messages"][-1]
|
||||||
|
assert isinstance(final, AIMessage)
|
||||||
|
assert final.content == "recovered via fallback"
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
class _RaisingChatModel(BaseChatModel):
|
||||||
|
exc_to_raise: Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "raising-test-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
|
||||||
|
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> AsyncIterator[ChatGeneration]:
|
||||||
|
raise self.exc_to_raise
|
||||||
|
yield # pragma: no cover - unreachable
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingChatModel(BaseChatModel):
|
||||||
|
response_text: str = "fallback-ok"
|
||||||
|
call_count: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "recording-test-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
self.call_count += 1
|
||||||
|
return ChatResult(
|
||||||
|
generations=[
|
||||||
|
ChatGeneration(message=AIMessage(content=self.response_text))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
return self._generate(messages, stop, None, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitError(Exception):
|
||||||
|
"""Name matches the scoped-fallback eligibility allowlist."""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||||
|
ScopedModelFallbackMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_agent(
|
||||||
|
model=primary,
|
||||||
|
tools=[],
|
||||||
|
middleware=[ScopedModelFallbackMiddleware(fallback)],
|
||||||
|
system_prompt="be helpful",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_errors_trigger_fallback():
|
||||||
|
"""Eligible exception names must drive the fallback chain."""
|
||||||
|
primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
|
||||||
|
fallback = _RecordingChatModel(response_text="recovered")
|
||||||
|
|
||||||
|
agent = _build_agent(primary, fallback)
|
||||||
|
result = await agent.ainvoke({"messages": [("user", "hi")]})
|
||||||
|
|
||||||
|
final = result["messages"][-1]
|
||||||
|
assert isinstance(final, AIMessage)
|
||||||
|
assert final.content == "recovered"
|
||||||
|
assert fallback.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_programming_errors_propagate_without_invoking_fallback():
|
||||||
|
"""Non-eligible exceptions must propagate; fallback must not be invoked."""
|
||||||
|
primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
|
||||||
|
fallback = _RecordingChatModel(response_text="should-never-arrive")
|
||||||
|
|
||||||
|
agent = _build_agent(primary, fallback)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="missing_state_field"):
|
||||||
|
await agent.ainvoke({"messages": [("user", "hi")]})
|
||||||
|
|
||||||
|
assert fallback.call_count == 0
|
||||||
|
|
@ -202,6 +202,15 @@ class FakeBudgetLLM:
|
||||||
|
|
||||||
|
|
||||||
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _disable_planner_runnable(self, monkeypatch):
|
||||||
|
# ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
|
||||||
|
# planner Runnable path is enabled) calls ``.bind()`` on the LLM,
|
||||||
|
# which the mock does not implement. Pin the flag off so the
|
||||||
|
# planner falls through to the legacy ``self.llm.ainvoke`` path
|
||||||
|
# these tests assert against (``llm.calls[0]["config"]``).
|
||||||
|
monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
|
||||||
|
|
||||||
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
||||||
messages = [
|
messages = [
|
||||||
HumanMessage(content="old user context " * 40),
|
HumanMessage(content="old user context " * 40),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue