mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-06 14:22:47 +02:00
Merge pull request #1351 from CREDO23/feature/multi-agent
[Improvement] Modular middleware stack + agent/prompt caching + subagent resilience + unit tests
This commit is contained in:
commit
a4fc812b85
70 changed files with 2037 additions and 547 deletions
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .main_agent import create_surfsense_deep_agent
|
||||
from .main_agent import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .runtime import create_surfsense_deep_agent
|
||||
from .runtime import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.middleware import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
|
|
@ -19,8 +22,6 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
|||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .middleware import build_main_agent_deepagent_middleware
|
||||
|
||||
|
||||
def build_compiled_agent_graph_sync(
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
"""Main-agent graph middleware assembly (SurfSense + LangChain + deepagents)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .deepagent_stack import build_main_agent_deepagent_middleware
|
||||
|
||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
"""Assemble the main-agent deep-agent middleware list (LangChain + SurfSense + deepagents)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
ActionLogMiddleware,
|
||||
AnonymousDocumentMiddleware,
|
||||
BusyMutexMiddleware,
|
||||
ClearToolUsesEdit,
|
||||
DedupHITLToolCallsMiddleware,
|
||||
DoomLoopMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
KnowledgeTreeMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
NoopInjectionMiddleware,
|
||||
OtelSpanMiddleware,
|
||||
PermissionMiddleware,
|
||||
RetryAfterMiddleware,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
ToolCallNameRepairMiddleware,
|
||||
build_skills_backend_factory,
|
||||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ...context_prune.prune_tool_names import safe_exclude_tools
|
||||
from .checkpointed_subagent_middleware import SurfSenseCheckpointedSubAgentMiddleware
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    subagent_dependencies: dict[str, Any],
    checkpointer: Checkpointer,
    mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
    disabled_tools: list[str] | None = None,
) -> list[Any]:
    """Build ordered middleware for ``create_agent`` (Nones already stripped).

    The returned list is position-sensitive: outer middlewares (mutex, otel)
    come first, tool-call shaping (repair, permission, dedup) near the end,
    and prompt caching last. Feature-flag-gated entries are built as ``None``
    when disabled and filtered out at the bottom.

    NOTE(review): most constructor arguments are project types; their exact
    contracts are defined in ``app.agents.new_chat.middleware`` — this
    docstring only describes what is visible here.
    """
    # Shared between the general-purpose subagent stack and the parent stack
    # so both see the same memory-injection instance.
    _memory_middleware = MemoryInjectionMiddleware(
        user_id=user_id,
        search_space_id=search_space_id,
        thread_visibility=visibility,
    )

    # Minimal stack for the general-purpose subagent spawned via ``task``.
    gp_middleware = [
        TodoListMiddleware(),
        _memory_middleware,
        FileIntentMiddleware(llm=llm),
        SurfSenseFilesystemMiddleware(
            backend=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            created_by_id=user_id,
            thread_id=thread_id,
        ),
        create_surfsense_compaction_middleware(llm, StateBackend),
        PatchToolCallsMiddleware(),
        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
    ]

    # Build permission rulesets up front so the GP subagent can mirror ``ask``
    # rules into ``interrupt_on``: tool calls emitted from within ``task`` runs
    # never reach the parent's ``PermissionMiddleware``.
    is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
    permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
    permission_rulesets: list[Ruleset] = []
    if permission_enabled or is_desktop_fs:
        # Base allow-all ruleset; later rulesets narrow it with ask/deny rules.
        permission_rulesets.append(
            Ruleset(
                rules=[Rule(permission="*", pattern="*", action="allow")],
                origin="surfsense_defaults",
            )
        )
    if is_desktop_fs:
        # Destructive local-filesystem operations always prompt the user.
        permission_rulesets.append(
            Ruleset(
                rules=[
                    Rule(permission="rm", pattern="*", action="ask"),
                    Rule(permission="rmdir", pattern="*", action="ask"),
                    Rule(permission="move_file", pattern="*", action="ask"),
                    Rule(permission="edit_file", pattern="*", action="ask"),
                    Rule(permission="write_file", pattern="*", action="ask"),
                ],
                origin="desktop_safety",
            )
        )

    # Tools that self-prompt via ``request_approval`` must not also appear
    # as ``ask`` rules — that would double-prompt the user for one call.
    _tool_names_in_use = {t.name for t in tools}

    # Deny parent-bound tools whose ``required_connector`` is missing.
    # No-op today (connector subagents are pruned upstream); guards future
    # additions to the parent's tool list.
    if permission_enabled:
        _available_set = set(available_connectors or [])
        _synthesized: list[Rule] = []
        for tool_def in BUILTIN_TOOLS:
            if tool_def.name not in _tool_names_in_use:
                continue
            rc = tool_def.required_connector
            if rc and rc not in _available_set:
                _synthesized.append(
                    Rule(permission=tool_def.name, pattern="*", action="deny")
                )
        if _synthesized:
            permission_rulesets.append(
                Ruleset(rules=_synthesized, origin="connector_synthesized")
            )
    # ``ask`` rules become HITL interrupts on the GP subagent (restricted to
    # tools it can actually call).
    gp_interrupt_on: dict[str, bool] = {
        rule.permission: True
        for rs in permission_rulesets
        for rule in rs.rules
        if rule.action == "ask" and rule.permission in _tool_names_in_use
    }

    general_purpose_spec: SubAgent = {  # type: ignore[typeddict-unknown-key]
        **GENERAL_PURPOSE_SUBAGENT,
        "model": llm,
        "tools": tools,
        "middleware": gp_middleware,
    }
    if gp_interrupt_on:
        general_purpose_spec["interrupt_on"] = gp_interrupt_on

    # Deny-only on subagents: ``task`` runs bypass the parent's
    # PermissionMiddleware, while bucket-based ask gates own the ask path.
    subagent_deny_rulesets: list[Ruleset] = [
        Ruleset(
            rules=[r for r in rs.rules if r.action == "deny"],
            origin=rs.origin,
        )
        for rs in permission_rulesets
    ]
    subagent_deny_rulesets = [rs for rs in subagent_deny_rulesets if rs.rules]

    subagent_deny_permission_mw: PermissionMiddleware | None = (
        PermissionMiddleware(rulesets=subagent_deny_rulesets)
        if subagent_deny_rulesets
        else None
    )

    if subagent_deny_permission_mw is not None:
        # Run deny check on already-repaired tool calls; insert before
        # PatchToolCallsMiddleware (append if the slot moves).
        _patch_idx = next(
            (
                i
                for i, m in enumerate(gp_middleware)
                if isinstance(m, PatchToolCallsMiddleware)
            ),
            len(gp_middleware),
        )
        gp_middleware.insert(_patch_idx, subagent_deny_permission_mw)

    # Connector/registry subagents. Failure here is fatal by design: a broken
    # registry means the advertised ``task`` surface would be wrong.
    registry_subagents: list[SubAgent] = []
    try:
        subagent_extra_middleware: list[Any] = [
            TodoListMiddleware(),
            SurfSenseFilesystemMiddleware(
                backend=backend_resolver,
                filesystem_mode=filesystem_mode,
                search_space_id=search_space_id,
                created_by_id=user_id,
                thread_id=thread_id,
            ),
        ]
        if subagent_deny_permission_mw is not None:
            subagent_extra_middleware.append(subagent_deny_permission_mw)
        registry_subagents = build_subagents(
            dependencies=subagent_dependencies,
            model=llm,
            extra_middleware=subagent_extra_middleware,
            mcp_tools_by_agent=mcp_tools_by_agent or {},
            exclude=get_subagents_to_exclude(available_connectors),
            disabled_tools=disabled_tools,
        )
        logging.info(
            "Registry subagents: %s",
            [s["name"] for s in registry_subagents],
        )
    except Exception:
        logging.exception("Registry subagent build failed")
        raise

    subagent_specs: list[SubAgent] = [general_purpose_spec, *registry_subagents]

    summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)

    # Context-window management: spill older tool outputs to the backend and
    # clear them from the transcript once the token budget is threatened.
    context_edit_mw = None
    if (
        flags.enable_context_editing
        and not flags.disable_new_agent_stack
        and max_input_tokens
    ):
        spill_edit = SpillToBackendEdit(
            trigger=int(max_input_tokens * 0.55),
            clear_at_least=int(max_input_tokens * 0.15),
            keep=5,
            exclude_tools=safe_exclude_tools(tools),
            clear_tool_inputs=True,
        )
        clear_edit = ClearToolUsesEdit(
            trigger=int(max_input_tokens * 0.55),
            clear_at_least=int(max_input_tokens * 0.15),
            keep=5,
            exclude_tools=safe_exclude_tools(tools),
            clear_tool_inputs=True,
            placeholder="[cleared - older tool output trimmed for context]",
        )
        context_edit_mw = SpillingContextEditingMiddleware(
            edits=[spill_edit, clear_edit],
            backend_resolver=backend_resolver,
        )

    retry_mw = (
        RetryAfterMiddleware(max_retries=3)
        if flags.enable_retry_after and not flags.disable_new_agent_stack
        else None
    )
    # Fallback models are best-effort: a bad provider config must not break
    # the primary model's stack.
    fallback_mw: ModelFallbackMiddleware | None = None
    if flags.enable_model_fallback and not flags.disable_new_agent_stack:
        try:
            fallback_mw = ModelFallbackMiddleware(
                "openai:gpt-4o-mini",
                "anthropic:claude-3-5-haiku-20241022",
            )
        except Exception:
            logging.warning("ModelFallbackMiddleware init failed; skipping.")
            fallback_mw = None
    model_call_limit_mw = (
        ModelCallLimitMiddleware(
            thread_limit=120,
            run_limit=80,
            exit_behavior="end",
        )
        if flags.enable_model_call_limit and not flags.disable_new_agent_stack
        else None
    )
    tool_call_limit_mw = (
        ToolCallLimitMiddleware(
            thread_limit=300, run_limit=80, exit_behavior="continue"
        )
        if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
        else None
    )

    noop_mw = (
        NoopInjectionMiddleware()
        if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
        else None
    )

    # Tool-call name repair: fuzzy-match hallucinated tool names against the
    # registered set, including deepagents builtins not in ``tools``.
    repair_mw = None
    if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
        registered_names: set[str] = {t.name for t in tools}
        registered_names |= {
            "write_todos",
            "ls",
            "read_file",
            "write_file",
            "edit_file",
            "glob",
            "grep",
            "execute",
            "task",
            "mkdir",
            "cd",
            "pwd",
            "move_file",
            "rm",
            "rmdir",
            "list_tree",
            "execute_code",
        }
        repair_mw = ToolCallNameRepairMiddleware(
            registered_tool_names=registered_names,
            fuzzy_match_threshold=None,
        )

    doom_loop_mw = (
        DoomLoopMiddleware(threshold=3)
        if flags.enable_doom_loop and not flags.disable_new_agent_stack
        else None
    )

    permission_mw: PermissionMiddleware | None = (
        PermissionMiddleware(rulesets=permission_rulesets)
        if permission_rulesets
        else None
    )

    # Audit log per tool call; requires a persisted thread row, hence the
    # ``thread_id is not None`` guard.
    action_log_mw: ActionLogMiddleware | None = None
    if (
        flags.enable_action_log
        and not flags.disable_new_agent_stack
        and thread_id is not None
    ):
        try:
            tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
            action_log_mw = ActionLogMiddleware(
                thread_id=thread_id,
                search_space_id=search_space_id,
                user_id=user_id,
                tool_definitions=tool_defs_by_name,
            )
        except Exception:  # pragma: no cover - defensive
            logging.warning(
                "ActionLogMiddleware init failed; running without it.",
                exc_info=True,
            )
            action_log_mw = None

    busy_mutex_mw: BusyMutexMiddleware | None = (
        BusyMutexMiddleware()
        if flags.enable_busy_mutex and not flags.disable_new_agent_stack
        else None
    )

    otel_mw: OtelSpanMiddleware | None = (
        OtelSpanMiddleware()
        if flags.enable_otel and not flags.disable_new_agent_stack
        else None
    )

    # Third-party plugin middlewares are allow-listed via env and best-effort.
    plugin_middlewares: list[Any] = []
    if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
        try:
            allowed_names = load_allowed_plugin_names_from_env()
            if allowed_names:
                plugin_middlewares = load_plugin_middlewares(
                    PluginContext.build(
                        search_space_id=search_space_id,
                        user_id=user_id,
                        thread_visibility=visibility,
                        llm=llm,
                    ),
                    allowed_plugin_names=allowed_names,
                )
        except Exception:  # pragma: no cover - defensive
            logging.warning(
                "Plugin loader failed; continuing without plugins.",
                exc_info=True,
            )
            plugin_middlewares = []

    # Skills: backend is search-space-scoped only on the cloud filesystem.
    skills_mw: SkillsMiddleware | None = None
    if flags.enable_skills and not flags.disable_new_agent_stack:
        try:
            skills_factory = build_skills_backend_factory(
                search_space_id=search_space_id
                if filesystem_mode == FilesystemMode.CLOUD
                else None,
            )
            skills_mw = SkillsMiddleware(
                backend=skills_factory,
                sources=default_skills_sources(),
            )
        except Exception as exc:  # pragma: no cover - defensive
            logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
            skills_mw = None

    # LLM-based tool pre-selection only pays off on large tool surfaces (>30).
    selector_mw: LLMToolSelectorMiddleware | None = None
    if (
        flags.enable_llm_tool_selector
        and not flags.disable_new_agent_stack
        and len(tools) > 30
    ):
        try:
            selector_mw = LLMToolSelectorMiddleware(
                model="openai:gpt-4o-mini",
                max_tools=12,
                always_include=[
                    name
                    for name in (
                        "update_memory",
                        "get_connected_accounts",
                        "scrape_webpage",
                    )
                    if name in {t.name for t in tools}
                ],
            )
        except Exception:
            logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
            selector_mw = None

    # Final ordered stack; ``None`` placeholders (disabled features) are
    # stripped below.
    deepagent_middleware = [
        busy_mutex_mw,
        otel_mw,
        TodoListMiddleware(),
        _memory_middleware,
        AnonymousDocumentMiddleware(
            anon_session_id=anon_session_id,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        KnowledgeTreeMiddleware(
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            llm=llm,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        KnowledgePriorityMiddleware(
            llm=llm,
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
        ),
        FileIntentMiddleware(llm=llm),
        SurfSenseFilesystemMiddleware(
            backend=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            created_by_id=user_id,
            thread_id=thread_id,
        ),
        KnowledgeBasePersistenceMiddleware(
            search_space_id=search_space_id,
            created_by_id=user_id,
            filesystem_mode=filesystem_mode,
            thread_id=thread_id,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        skills_mw,
        SurfSenseCheckpointedSubAgentMiddleware(
            checkpointer=checkpointer,
            backend=StateBackend,
            subagents=subagent_specs,
        ),
        selector_mw,
        model_call_limit_mw,
        tool_call_limit_mw,
        context_edit_mw,
        summarization_mw,
        noop_mw,
        retry_mw,
        fallback_mw,
        repair_mw,
        permission_mw,
        doom_loop_mw,
        action_log_mw,
        PatchToolCallsMiddleware(),
        DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
        *plugin_middlewares,
        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
    ]
    return [m for m in deepagent_middleware if m is not None]
|
||||
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .factory import create_surfsense_deep_agent
|
||||
from .factory import create_multi_agent_chat_deep_agent
|
||||
|
||||
__all__ = ["create_surfsense_deep_agent"]
|
||||
__all__ = ["create_multi_agent_chat_deep_agent"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
"""Compiled agent graph caching for the multi-agent path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
||||
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
    """Hash the per-agent MCP tool surface so a change rotates the cache key.

    The signature covers, per agent (in key order), the sorted tool names of
    the ``allow`` and ``ask`` buckets — any addition, removal, or rename of
    an MCP tool therefore produces a different hash.
    """
    rows = [
        (
            agent_name,
            sorted(entry.get("name", "") for entry in perms.get("allow", [])),
            sorted(entry.get("name", "") for entry in perms.get("ask", [])),
        )
        for agent_name, perms in sorted(mcp_tools_by_agent.items())
    ]
    return stable_hash(rows)
|
||||
|
||||
|
||||
async def build_agent_with_cache(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    final_system_prompt: str,
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str],
    available_document_types: list[str],
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    checkpointer: Checkpointer,
    subagent_dependencies: dict[str, Any],
    mcp_tools_by_agent: dict[str, ToolsPermissions],
    disabled_tools: list[str] | None,
    config_id: str | None,
) -> Any:
    """Compile the multi-agent graph, serving from cache when key components are stable.

    Returns the compiled graph object produced by
    ``build_compiled_agent_graph_sync`` (run on a worker thread so the event
    loop is not blocked by graph compilation).
    """

    # Cache-miss builder: offload the synchronous compile to a thread.
    async def _build() -> Any:
        return await asyncio.to_thread(
            build_compiled_agent_graph_sync,
            llm=llm,
            tools=tools,
            final_system_prompt=final_system_prompt,
            backend_resolver=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
            visibility=visibility,
            anon_session_id=anon_session_id,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
            max_input_tokens=max_input_tokens,
            flags=flags,
            checkpointer=checkpointer,
            subagent_dependencies=subagent_dependencies,
            mcp_tools_by_agent=mcp_tools_by_agent,
            disabled_tools=disabled_tools,
        )

    # Caching is opt-in; when off (or the new stack is disabled) always
    # compile fresh.
    if not (flags.enable_agent_cache and not flags.disable_new_agent_stack):
        return await _build()

    # Every per-request value any middleware closes over at __init__ must be in
    # the key, otherwise a hit will leak state across threads. Bump the schema
    # version when the component list changes shape.
    cache_key = stable_hash(
        "multi-agent-v1",
        config_id,
        thread_id,
        user_id,
        search_space_id,
        visibility,
        filesystem_mode,
        anon_session_id,
        tools_signature(
            tools,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
        ),
        mcp_signature(mcp_tools_by_agent),
        flags_signature(flags),
        system_prompt_hash(final_system_prompt),
        max_input_tokens,
        sorted(disabled_tools) if disabled_tools else None,
    )
    return await get_cache().get_or_build(cache_key, builder=_build)
|
||||
|
||||
|
||||
__all__ = ["build_agent_with_cache", "mcp_signature"]
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -26,23 +25,24 @@ from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
|||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
from ..system_prompt import build_main_agent_system_prompt
|
||||
from ..tools import (
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES,
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED,
|
||||
)
|
||||
from .agent_cache import build_agent_with_cache
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
async def create_surfsense_deep_agent(
|
||||
async def create_multi_agent_chat_deep_agent(
|
||||
llm: BaseChatModel,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
|
|
@ -62,6 +62,9 @@ async def create_surfsense_deep_agent(
|
|||
):
|
||||
"""Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled."""
|
||||
_t_agent_total = time.perf_counter()
|
||||
|
||||
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
|
||||
|
||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||
backend_resolver = build_backend_resolver(
|
||||
filesystem_selection,
|
||||
|
|
@ -85,7 +88,18 @@ async def create_surfsense_deep_agent(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning("Failed to discover available connectors/document types: %s", e)
|
||||
logging.warning(
|
||||
"Connector/doc-type discovery failed; excluding connector subagents this turn: %s",
|
||||
e,
|
||||
)
|
||||
|
||||
# Fail closed: a None list short-circuits ``get_subagents_to_exclude`` to "exclude
|
||||
# nothing", which would silently advertise every connector specialist on a flaky
|
||||
# discovery call. Empty list excludes connector-gated subagents while keeping builtins.
|
||||
if available_connectors is None:
|
||||
available_connectors = []
|
||||
if available_document_types is None:
|
||||
available_document_types = []
|
||||
_perf_log.info(
|
||||
"[create_agent] Connector/doc-type discovery in %.3fs",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -115,7 +129,16 @@ async def create_surfsense_deep_agent(
|
|||
}
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
|
||||
try:
|
||||
mcp_tools_by_agent = await load_mcp_tools_by_connector(db_session, search_space_id)
|
||||
except Exception as e:
|
||||
# Degrade to builtins-only rather than aborting the turn: a transient
|
||||
# DB or MCP-server hiccup should not deny the user a response.
|
||||
logging.warning(
|
||||
"MCP tool discovery failed; subagents will run without MCP tools this turn: %s",
|
||||
e,
|
||||
)
|
||||
mcp_tools_by_agent = {}
|
||||
_perf_log.info(
|
||||
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -195,9 +218,10 @@ async def create_surfsense_deep_agent(
|
|||
|
||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||
|
||||
config_id = agent_config.config_id if agent_config is not None else None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await asyncio.to_thread(
|
||||
build_compiled_agent_graph_sync,
|
||||
agent = await build_agent_with_cache(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
final_system_prompt=final_system_prompt,
|
||||
|
|
@ -217,6 +241,7 @@ async def create_surfsense_deep_agent(
|
|||
subagent_dependencies=dependencies,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
config_id=config_id,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[create_agent] Middleware stack + graph compiled in %.3fs",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
"""Multi-agent middleware stack assembly."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .stack import build_main_agent_deepagent_middleware
|
||||
|
||||
__all__ = ["build_main_agent_deepagent_middleware"]
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""Audit row per tool call (reversibility metadata)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ActionLogMiddleware
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_action_log_mw(
    *,
    flags: AgentFeatureFlags,
    thread_id: int | None,
    search_space_id: int,
    user_id: str | None,
) -> ActionLogMiddleware | None:
    """Create the per-tool-call audit middleware.

    Returns ``None`` when the feature flag is off or no thread row exists to
    attach audit rows to; construction failures degrade to ``None`` as well.
    """
    if thread_id is None or not enabled(flags, "enable_action_log"):
        return None
    try:
        definitions = {definition.name: definition for definition in BUILTIN_TOOLS}
        return ActionLogMiddleware(
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
            tool_definitions=definitions,
        )
    except Exception:  # pragma: no cover - defensive
        logging.warning(
            "ActionLogMiddleware init failed; running without it.",
            exc_info=True,
        )
        return None
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Anonymous document hydration from Redis (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import AnonymousDocumentMiddleware
|
||||
|
||||
|
||||
def build_anonymous_doc_mw(
    *,
    filesystem_mode: FilesystemMode,
    anon_session_id: str | None,
) -> AnonymousDocumentMiddleware | None:
    """Create the anonymous-document hydration middleware (cloud only).

    Non-cloud filesystem modes have no Redis-backed anonymous documents, so
    the middleware is skipped entirely there.
    """
    if filesystem_mode == FilesystemMode.CLOUD:
        return AnonymousDocumentMiddleware(anon_session_id=anon_session_id)
    return None
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Per-thread cooperative lock around the whole turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import BusyMutexMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_busy_mutex_mw(flags: AgentFeatureFlags) -> BusyMutexMiddleware | None:
    """Create the per-thread turn lock, or ``None`` when the flag is off."""
    if enabled(flags, "enable_busy_mutex"):
        return BusyMutexMiddleware()
    return None
|
||||
|
|
@ -69,9 +69,16 @@ def build_task_tool_with_parent_config(
|
|||
raise ValueError(msg)
|
||||
|
||||
state_update = {k: v for k, v in result.items() if k not in EXCLUDED_STATE_KEYS}
|
||||
message_text = (
|
||||
result["messages"][-1].text.rstrip() if result["messages"][-1].text else ""
|
||||
)
|
||||
messages = result["messages"]
|
||||
if not messages:
|
||||
msg = (
|
||||
"CompiledSubAgent returned an empty 'messages' list. "
|
||||
"Subagents must produce at least one message so the parent has "
|
||||
"output to forward back to the user."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
last_text = getattr(messages[-1], "text", None) or ""
|
||||
message_text = last_text.rstrip()
|
||||
return Command(
|
||||
update={
|
||||
**state_update,
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Spill + clear-tool-uses passes to keep payloads under budget."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_context_editing_mw(
    *,
    flags: AgentFeatureFlags,
    max_input_tokens: int | None,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
) -> SpillingContextEditingMiddleware | None:
    """Build the spill + clear-tool-uses passes that keep payloads under budget.

    Returns ``None`` when the flag is off or no input-token budget is known.
    """
    if not max_input_tokens or not enabled(flags, "enable_context_editing"):
        return None

    # Both passes share identical trigger/budget geometry.
    trigger = int(max_input_tokens * 0.55)
    clear_at_least = int(max_input_tokens * 0.15)
    excluded = safe_exclude_tools(tools)

    edits = [
        SpillToBackendEdit(
            trigger=trigger,
            clear_at_least=clear_at_least,
            keep=5,
            exclude_tools=excluded,
            clear_tool_inputs=True,
        ),
        ClearToolUsesEdit(
            trigger=trigger,
            clear_at_least=clear_at_least,
            keep=5,
            exclude_tools=excluded,
            clear_tool_inputs=True,
            placeholder="[cleared - older tool output trimmed for context]",
        ),
    ]
    return SpillingContextEditingMiddleware(
        edits=edits,
        backend_resolver=backend_resolver,
    )
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Drop duplicate HITL tool calls before execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
|
||||
|
||||
|
||||
def build_dedup_hitl_mw(tools: Sequence[BaseTool]) -> DedupHITLToolCallsMiddleware:
    """Drop duplicate HITL tool calls before they execute."""
    agent_tools = list(tools)
    return DedupHITLToolCallsMiddleware(agent_tools=agent_tools)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Stop N identical tool calls in a row via interrupt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import DoomLoopMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_doom_loop_mw(flags: AgentFeatureFlags) -> DoomLoopMiddleware | None:
    """Interrupt after three identical tool calls in a row, when enabled."""
    if not enabled(flags, "enable_doom_loop"):
        return None
    return DoomLoopMiddleware(threshold=3)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Commit staged cloud filesystem mutations to Postgres at end of turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeBasePersistenceMiddleware
|
||||
|
||||
|
||||
def build_kb_persistence_mw(
    *,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
) -> KnowledgeBasePersistenceMiddleware | None:
    """Commit staged cloud filesystem mutations to Postgres at end of turn."""
    if filesystem_mode == FilesystemMode.CLOUD:
        return KnowledgeBasePersistenceMiddleware(
            search_space_id=search_space_id,
            created_by_id=user_id,
            filesystem_mode=filesystem_mode,
            thread_id=thread_id,
        )
    # Only the cloud filesystem stages mutations that need persisting.
    return None
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""KB priority planner: <priority_documents> injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgePriorityMiddleware
|
||||
|
||||
|
||||
def build_knowledge_priority_mw(
    *,
    llm: BaseChatModel,
    search_space_id: int,
    filesystem_mode: FilesystemMode,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
) -> KnowledgePriorityMiddleware:
    """KB priority planner that injects the <priority_documents> section."""
    config: dict[str, Any] = {
        "llm": llm,
        "search_space_id": search_space_id,
        "filesystem_mode": filesystem_mode,
        "available_connectors": available_connectors,
        "available_document_types": available_document_types,
        "mentioned_document_ids": mentioned_document_ids,
    }
    return KnowledgePriorityMiddleware(**config)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""<workspace_tree> injection (cloud only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import KnowledgeTreeMiddleware
|
||||
|
||||
|
||||
def build_knowledge_tree_mw(
    *,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    llm: BaseChatModel,
) -> KnowledgeTreeMiddleware | None:
    """Inject <workspace_tree>; only built for the cloud filesystem."""
    if filesystem_mode == FilesystemMode.CLOUD:
        return KnowledgeTreeMiddleware(
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            llm=llm,
        )
    return None
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Provider-compat: append a `_noop` tool when tools=[] but history has tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import NoopInjectionMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_noop_injection_mw(flags: AgentFeatureFlags) -> NoopInjectionMiddleware | None:
    """Provider-compat `_noop` tool injection, gated by the compaction-v2 flag."""
    if not enabled(flags, "enable_compaction_v2"):
        return None
    return NoopInjectionMiddleware()
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""OTel spans on model and tool calls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import OtelSpanMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_otel_mw(flags: AgentFeatureFlags) -> OtelSpanMiddleware | None:
    """OTel spans on model and tool calls, when the otel flag is on."""
    if not enabled(flags, "enable_otel"):
        return None
    return OtelSpanMiddleware()
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Tail-of-stack plugin slot driven by env allowlist."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_plugin_middlewares(
    *,
    flags: AgentFeatureFlags,
    search_space_id: int,
    user_id: str | None,
    visibility: ChatVisibility,
    llm: BaseChatModel,
) -> list[Any]:
    """Load env-allowlisted plugin middlewares; empty list on any failure."""
    if not enabled(flags, "enable_plugin_loader"):
        return []
    try:
        names = load_allowed_plugin_names_from_env()
        if not names:
            return []
        context = PluginContext.build(
            search_space_id=search_space_id,
            user_id=user_id,
            thread_visibility=visibility,
            llm=llm,
        )
        return load_plugin_middlewares(context, allowed_plugin_names=names)
    except Exception:  # pragma: no cover - defensive
        # A broken plugin must never take the whole turn down.
        logging.warning(
            "Plugin loader failed; continuing without plugins.",
            exc_info=True,
        )
        return []
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Repair miscased / unknown tool names to the registered set or invalid_tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
# deepagents-built-in tool names the repair pass treats as known.
_DEEPAGENT_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
    {
        "write_todos",
        "ls",
        "read_file",
        "write_file",
        "edit_file",
        "glob",
        "grep",
        "execute",
        "task",
        "mkdir",
        "cd",
        "pwd",
        "move_file",
        "rm",
        "rmdir",
        "list_tree",
        "execute_code",
    }
)


def build_repair_mw(
    *,
    flags: AgentFeatureFlags,
    tools: Sequence[BaseTool],
) -> ToolCallNameRepairMiddleware | None:
    """Repair miscased/unknown tool-call names against the registered set."""
    if not enabled(flags, "enable_tool_call_repair"):
        return None
    # Known names = tools actually bound this turn + deepagents built-ins.
    known_names: set[str] = {tool.name for tool in tools}
    known_names |= _DEEPAGENT_BUILTIN_TOOL_NAMES
    return ToolCallNameRepairMiddleware(
        registered_tool_names=known_names,
        fuzzy_match_threshold=None,
    )
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""LLM-based tool subset selection (only when >30 tools)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_selector_mw(
    *,
    flags: AgentFeatureFlags,
    tools: Sequence[BaseTool],
) -> LLMToolSelectorMiddleware | None:
    """LLM-based tool subset selection; only engages above 30 tools.

    Returns ``None`` when the flag is off, the tool count is small enough
    that selection is unnecessary, or middleware construction fails.
    """
    if not enabled(flags, "enable_llm_tool_selector") or len(tools) <= 30:
        return None
    # Hoisted: the original rebuilt {t.name for t in tools} once per
    # candidate name inside the always_include comprehension.
    tool_names = {t.name for t in tools}
    always_include = [
        name
        for name in (
            "update_memory",
            "get_connected_accounts",
            "scrape_webpage",
        )
        if name in tool_names
    ]
    try:
        return LLMToolSelectorMiddleware(
            model="openai:gpt-4o-mini",
            max_tools=12,
            always_include=always_include,
        )
    except Exception:
        # Selection is an optimization; degrade to the full tool set.
        logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
        return None
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Skill discovery + injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import (
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
|
||||
from ..shared.flags import enabled
|
||||
|
||||
|
||||
def build_skills_mw(
    *,
    flags: AgentFeatureFlags,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
) -> SkillsMiddleware | None:
    """Skill discovery + injection; best-effort, never fails the turn."""
    if not enabled(flags, "enable_skills"):
        return None
    # Cloud mode scopes skills to the search space; other modes are unscoped.
    scope_id = search_space_id if filesystem_mode == FilesystemMode.CLOUD else None
    try:
        backend = build_skills_backend_factory(search_space_id=scope_id)
        return SkillsMiddleware(
            backend=backend,
            sources=default_skills_sources(),
        )
    except Exception as exc:  # pragma: no cover - defensive
        logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
        return None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Anthropic prompt caching annotations on system/tool/message blocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
|
||||
|
||||
def build_anthropic_cache_mw() -> AnthropicPromptCachingMiddleware:
    """Anthropic prompt-caching annotations on system/tool/message blocks."""
    # Non-Anthropic models simply skip the cache annotations.
    return AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
"""Context-window summarization with SurfSense protected sections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.middleware import create_surfsense_compaction_middleware
|
||||
|
||||
|
||||
def build_compaction_mw(llm: BaseChatModel) -> Any:
    """Context-window summarization with SurfSense protected sections."""
    # Compacted segments are stored via the in-graph StateBackend.
    return create_surfsense_compaction_middleware(llm, StateBackend)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""File-intent classifier that gates strict write contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.new_chat.middleware import FileIntentMiddleware
|
||||
|
||||
|
||||
def build_file_intent_mw(llm: BaseChatModel) -> FileIntentMiddleware:
    """File-intent classifier that gates strict write contracts."""
    return FileIntentMiddleware(llm=llm)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
"""SurfSense filesystem tools/middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
||||
def build_filesystem_mw(
    *,
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
) -> SurfSenseFilesystemMiddleware:
    """SurfSense filesystem tools/middleware for the current thread."""
    config: dict[str, Any] = {
        "backend": backend_resolver,
        "filesystem_mode": filesystem_mode,
        "search_space_id": search_space_id,
        "created_by_id": user_id,
        "thread_id": thread_id,
    }
    return SurfSenseFilesystemMiddleware(**config)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Single source of truth for the feature-flag predicate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
|
||||
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
    """``flags.<attr>`` is on AND the new-agent-stack kill switch is off.

    The raw ``getattr`` result is coerced to ``bool`` so falsy non-bool
    values (e.g. ``0``) cannot leak through the annotated return type.
    """
    return bool(getattr(flags, attr)) and not flags.disable_new_agent_stack
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""User/team memory injection prepended to the conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
def build_memory_mw(
    *,
    user_id: str | None,
    search_space_id: int,
    visibility: ChatVisibility,
) -> MemoryInjectionMiddleware:
    """User/team memory injection prepended to the conversation."""
    config: dict[str, Any] = {
        "user_id": user_id,
        "search_space_id": search_space_id,
        "thread_visibility": visibility,
    }
    return MemoryInjectionMiddleware(**config)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Repair dangling tool-call sequences before each agent turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
|
||||
|
||||
def build_patch_tool_calls_mw() -> PatchToolCallsMiddleware:
    """Repair dangling tool-call sequences before each agent turn."""
    # No configuration: the middleware's defaults are what we want.
    return PatchToolCallsMiddleware()
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Permission rulesets fanned out to parent / general-purpose / subagent stacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .context import PermissionContext, build_permission_context
|
||||
from .middleware import build_full_permission_mw
|
||||
|
||||
__all__ = [
|
||||
"PermissionContext",
|
||||
"build_full_permission_mw",
|
||||
"build_permission_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""Derive shared permission context once; fan out to all three stack layers.
|
||||
|
||||
The context carries:
|
||||
- ``rulesets``: full ask/deny/allow rules for the main-agent permission middleware.
|
||||
- ``general_purpose_interrupt_on``: ``ask`` rules mirrored as deepagents
|
||||
``interrupt_on`` so HITL still triggers from inside ``task`` runs (subagents
|
||||
bypass the main-agent permission middleware).
|
||||
- ``subagent_deny_mw``: a deny-only ``PermissionMiddleware`` instance shared
|
||||
across the general-purpose and registry subagent stacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PermissionContext:
    """Shared permission context fanned out to all three stack layers."""

    # Full ask/deny/allow rules for the main-agent permission middleware.
    rulesets: list[Ruleset]
    # ``ask`` rules mirrored as deepagents ``interrupt_on`` so HITL still
    # triggers inside ``task`` runs (subagents bypass the main permission
    # middleware).
    general_purpose_interrupt_on: dict[str, bool]
    # Deny-only middleware instance shared across the general-purpose and
    # registry subagent stacks; None when no deny rules exist.
    subagent_deny_mw: PermissionMiddleware | None
|
||||
|
||||
|
||||
def build_permission_context(
    *,
    flags: AgentFeatureFlags,
    filesystem_mode: FilesystemMode,
    tools: Sequence[BaseTool],
    available_connectors: list[str] | None,
) -> PermissionContext:
    """Derive the shared permission context once for all three stack layers.

    Builds (1) the full ask/deny/allow rulesets for the main agent,
    (2) the ``ask`` rules mirrored as deepagents ``interrupt_on`` for the
    general-purpose subagent, and (3) a deny-only middleware shared by
    every subagent stack.
    """
    is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
    permission_enabled = enabled(flags, "enable_permission")

    rulesets: list[Ruleset] = []
    if permission_enabled or is_desktop_fs:
        # Baseline: everything allowed unless a later ruleset says otherwise.
        rulesets.append(
            Ruleset(
                rules=[Rule(permission="*", pattern="*", action="allow")],
                origin="surfsense_defaults",
            )
        )
    if is_desktop_fs:
        # Destructive local-folder operations always require confirmation.
        rulesets.append(
            Ruleset(
                rules=[
                    Rule(permission="rm", pattern="*", action="ask"),
                    Rule(permission="rmdir", pattern="*", action="ask"),
                    Rule(permission="move_file", pattern="*", action="ask"),
                    Rule(permission="edit_file", pattern="*", action="ask"),
                    Rule(permission="write_file", pattern="*", action="ask"),
                ],
                origin="desktop_safety",
            )
        )

    tool_names_in_use = {t.name for t in tools}

    if permission_enabled:
        # Deny tools whose required connector is not actually connected.
        available_set = set(available_connectors or [])
        synthesized: list[Rule] = []
        for tool_def in BUILTIN_TOOLS:
            if tool_def.name not in tool_names_in_use:
                continue
            rc = tool_def.required_connector
            if rc and rc not in available_set:
                synthesized.append(
                    Rule(permission=tool_def.name, pattern="*", action="deny")
                )
        if synthesized:
            rulesets.append(
                Ruleset(rules=synthesized, origin="connector_synthesized")
            )

    # Mirror ask rules as interrupt_on for the general-purpose subagent,
    # restricted to tools that are actually bound this turn.
    general_purpose_interrupt_on: dict[str, bool] = {
        rule.permission: True
        for rs in rulesets
        for rule in rs.rules
        if rule.action == "ask" and rule.permission in tool_names_in_use
    }

    # Deny-only projection for subagent stacks. Build the filtered list
    # directly instead of constructing throwaway empty Ruleset objects
    # and discarding them in a second pass.
    deny_rulesets: list[Ruleset] = []
    for rs in rulesets:
        deny_rules = [r for r in rs.rules if r.action == "deny"]
        if deny_rules:
            deny_rulesets.append(Ruleset(rules=deny_rules, origin=rs.origin))

    subagent_deny_mw: PermissionMiddleware | None = (
        PermissionMiddleware(rulesets=deny_rulesets) if deny_rulesets else None
    )

    return PermissionContext(
        rulesets=rulesets,
        general_purpose_interrupt_on=general_purpose_interrupt_on,
        subagent_deny_mw=subagent_deny_mw,
    )
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
"""Main-agent permission middleware (full ask/deny/allow rules)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.middleware import PermissionMiddleware
|
||||
from app.agents.new_chat.permissions import Ruleset
|
||||
|
||||
|
||||
def build_full_permission_mw(rulesets: list[Ruleset]) -> PermissionMiddleware | None:
    """Main-agent permission middleware; None when no rules apply."""
    if not rulesets:
        return None
    return PermissionMiddleware(rulesets=rulesets)
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Resilience middleware shared as the same instances across parent / general-purpose / registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .bundle import ResilienceBundle, build_resilience_bundle
|
||||
|
||||
__all__ = ["ResilienceBundle", "build_resilience_bundle"]
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""Construct each resilience middleware once; same instances flow into every consumer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import (
|
||||
ModelCallLimitMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
|
||||
from .fallback import build_fallback_mw
|
||||
from .model_call_limit import build_model_call_limit_mw
|
||||
from .retry import build_retry_mw
|
||||
from .tool_call_limit import build_tool_call_limit_mw
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResilienceBundle:
    """Resilience middlewares built once and shared across every consumer."""

    # Each slot is None when the corresponding feature flag is off.
    retry: RetryAfterMiddleware | None
    fallback: ScopedModelFallbackMiddleware | None
    model_call_limit: ModelCallLimitMiddleware | None
    tool_call_limit: ToolCallLimitMiddleware | None

    def as_list(self) -> list[Any]:
        """Return the configured middlewares in stack order, skipping unset slots."""
        candidates = (
            self.retry,
            self.fallback,
            self.model_call_limit,
            self.tool_call_limit,
        )
        return [mw for mw in candidates if mw is not None]
|
||||
|
||||
|
||||
def build_resilience_bundle(flags: AgentFeatureFlags) -> ResilienceBundle:
    """Construct each resilience middleware exactly once for all consumers."""
    retry = build_retry_mw(flags)
    fallback = build_fallback_mw(flags)
    model_limit = build_model_call_limit_mw(flags)
    tool_limit = build_tool_call_limit_mw(flags)
    return ResilienceBundle(
        retry=retry,
        fallback=fallback,
        model_call_limit=model_limit,
        tool_call_limit=tool_limit,
    )
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Switch to a fallback model on provider/network errors only."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_fallback_mw(
    flags: AgentFeatureFlags,
) -> ScopedModelFallbackMiddleware | None:
    """Fallback-model middleware for provider/network errors only."""
    if not enabled(flags, "enable_model_fallback"):
        return None
    try:
        return ScopedModelFallbackMiddleware(
            "openai:gpt-4o-mini",
            "anthropic:claude-3-5-haiku-20241022",
        )
    except Exception:
        # Fallback is optional hardening; never abort the build over it.
        logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
        return None
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Cap model calls per thread / per run to prevent runaway cost."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import ModelCallLimitMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_model_call_limit_mw(
    flags: AgentFeatureFlags,
) -> ModelCallLimitMiddleware | None:
    """Cap model calls per thread / per run to prevent runaway cost."""
    if enabled(flags, "enable_model_call_limit"):
        return ModelCallLimitMiddleware(
            thread_limit=120,
            run_limit=80,
            exit_behavior="end",
        )
    return None
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Retry on transient model errors (e.g. Retry-After-bearing 429s)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.middleware import RetryAfterMiddleware
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_retry_mw(flags: AgentFeatureFlags) -> RetryAfterMiddleware | None:
    """Retry transient model errors (e.g. Retry-After-bearing 429s)."""
    if not enabled(flags, "enable_retry_after"):
        return None
    return RetryAfterMiddleware(max_retries=3)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Cap tool calls per thread / per run to bound infinite-loop blast radius."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import ToolCallLimitMiddleware
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
||||
def build_tool_call_limit_mw(
    flags: AgentFeatureFlags,
) -> ToolCallLimitMiddleware | None:
    """Cap tool calls per thread / per run to bound infinite-loop blast radius."""
    if enabled(flags, "enable_tool_call_limit"):
        return ToolCallLimitMiddleware(
            thread_limit=300,
            run_limit=80,
            exit_behavior="continue",
        )
    return None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Todo-list middleware (each consumer needs its own instance)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
|
||||
|
||||
def build_todos_mw() -> TodoListMiddleware:
    """Fresh todo-list middleware (each consumer needs its own instance)."""
    return TodoListMiddleware()
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
"""Main-agent middleware list assembly: one line per slot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
build_subagents,
|
||||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.builtins.general_purpose.agent import (
|
||||
build_subagent as build_general_purpose_subagent,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.permissions import ToolsPermissions
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .main_agent.action_log import build_action_log_mw
|
||||
from .main_agent.anonymous_doc import build_anonymous_doc_mw
|
||||
from .main_agent.busy_mutex import build_busy_mutex_mw
|
||||
from .main_agent.checkpointed_subagent_middleware import (
|
||||
SurfSenseCheckpointedSubAgentMiddleware,
|
||||
)
|
||||
from .main_agent.context_editing import build_context_editing_mw
|
||||
from .main_agent.dedup_hitl import build_dedup_hitl_mw
|
||||
from .main_agent.doom_loop import build_doom_loop_mw
|
||||
from .main_agent.kb_persistence import build_kb_persistence_mw
|
||||
from .main_agent.knowledge_priority import build_knowledge_priority_mw
|
||||
from .main_agent.knowledge_tree import build_knowledge_tree_mw
|
||||
from .main_agent.noop_injection import build_noop_injection_mw
|
||||
from .main_agent.otel import build_otel_mw
|
||||
from .main_agent.plugins import build_plugin_middlewares
|
||||
from .main_agent.repair import build_repair_mw
|
||||
from .main_agent.selector import build_selector_mw
|
||||
from .main_agent.skills import build_skills_mw
|
||||
from .shared.anthropic_cache import build_anthropic_cache_mw
|
||||
from .shared.compaction import build_compaction_mw
|
||||
from .shared.file_intent import build_file_intent_mw
|
||||
from .shared.filesystem import build_filesystem_mw
|
||||
from .shared.memory import build_memory_mw
|
||||
from .shared.patch_tool_calls import build_patch_tool_calls_mw
|
||||
from .shared.permissions import (
|
||||
build_full_permission_mw,
|
||||
build_permission_context,
|
||||
)
|
||||
from .shared.resilience import build_resilience_bundle
|
||||
from .shared.todos import build_todos_mw
|
||||
from .subagent.extras import build_subagent_extras
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags,
    subagent_dependencies: dict[str, Any],
    checkpointer: Checkpointer,
    mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
    disabled_tools: list[str] | None = None,
) -> list[Any]:
    """Ordered middleware for ``create_agent`` (None entries already stripped).

    Assembles the full main-agent stack one builder call per slot. Each
    builder may return ``None`` (flag off / mode mismatch); the final list
    comprehension strips those. The resulting order is load-bearing:
    earlier entries see the turn before later ones do.
    """
    # Permission rules are derived once and fanned out to the main agent,
    # the general-purpose subagent, and the registry subagent stacks.
    permissions = build_permission_context(
        flags=flags,
        filesystem_mode=filesystem_mode,
        tools=tools,
        available_connectors=available_connectors,
    )
    resilience = build_resilience_bundle(flags)

    # Single instance threaded into both the main-agent stack and the general-purpose subagent.
    memory_mw = build_memory_mw(
        user_id=user_id,
        search_space_id=search_space_id,
        visibility=visibility,
    )

    general_purpose_subagent = build_general_purpose_subagent(
        llm=llm,
        tools=tools,
        backend_resolver=backend_resolver,
        filesystem_mode=filesystem_mode,
        search_space_id=search_space_id,
        user_id=user_id,
        thread_id=thread_id,
        permissions=permissions,
        resilience=resilience,
        memory_mw=memory_mw,
    )

    subagents_registry: list[SubAgent] = []
    try:
        subagent_extras = build_subagent_extras(
            permissions=permissions,
            resilience=resilience,
        )
        subagents_registry = build_subagents(
            dependencies=subagent_dependencies,
            model=llm,
            extra_middleware=subagent_extras,
            mcp_tools_by_agent=mcp_tools_by_agent or {},
            exclude=get_subagents_to_exclude(available_connectors),
            disabled_tools=disabled_tools,
        )
        logging.debug(
            "Subagents registry: %s",
            [s["name"] for s in subagents_registry],
        )
    except Exception:
        # Degrade to general-purpose-only rather than aborting the turn:
        # one bad subagent dep should not deny the user a response.
        logging.exception(
            "Subagents registry build failed; falling back to general-purpose only"
        )
        subagents_registry = []

    # The general-purpose subagent always comes first in the fan-out list.
    subagents: list[SubAgent] = [general_purpose_subagent, *subagents_registry]

    stack: list[Any] = [
        # Turn-level gating and observability first.
        build_busy_mutex_mw(flags),
        build_otel_mw(flags),
        build_todos_mw(),
        memory_mw,
        build_anonymous_doc_mw(
            filesystem_mode=filesystem_mode, anon_session_id=anon_session_id
        ),
        # Knowledge-base context injection (cloud-mode builders return None elsewhere).
        build_knowledge_tree_mw(
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            llm=llm,
        ),
        build_knowledge_priority_mw(
            llm=llm,
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
        ),
        build_file_intent_mw(llm),
        build_filesystem_mw(
            backend_resolver=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_kb_persistence_mw(
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            user_id=user_id,
            thread_id=thread_id,
        ),
        build_skills_mw(
            flags=flags,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
        ),
        # Subagent fan-out (the `task` tool) with checkpointed resumption.
        SurfSenseCheckpointedSubAgentMiddleware(
            checkpointer=checkpointer,
            backend=StateBackend,
            subagents=subagents,
        ),
        build_selector_mw(flags=flags, tools=tools),
        # Cost/loop caps, then context-size management.
        resilience.model_call_limit,
        resilience.tool_call_limit,
        build_context_editing_mw(
            flags=flags,
            max_input_tokens=max_input_tokens,
            tools=tools,
            backend_resolver=backend_resolver,
        ),
        build_compaction_mw(llm),
        build_noop_injection_mw(flags),
        # Model-call resilience, then tool-call hygiene and permissions.
        resilience.retry,
        resilience.fallback,
        build_repair_mw(flags=flags, tools=tools),
        build_full_permission_mw(permissions.rulesets),
        build_doom_loop_mw(flags),
        build_action_log_mw(
            flags=flags,
            thread_id=thread_id,
            search_space_id=search_space_id,
            user_id=user_id,
        ),
        build_patch_tool_calls_mw(),
        build_dedup_hitl_mw(tools),
        # Env-allowlisted plugins occupy the tail slot, before caching.
        *build_plugin_middlewares(
            flags=flags,
            search_space_id=search_space_id,
            user_id=user_id,
            visibility=visibility,
            llm=llm,
        ),
        build_anthropic_cache_mw(),
    ]
    # Builders signal "not applicable" with None; strip those entries.
    return [m for m in stack if m is not None]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""Extra middleware threaded into every registry subagent's stack.
|
||||
|
||||
Registry subagents are scoped to one domain (deliverables, research, memory,
|
||||
connectors, MCP) and never read or write the SurfSense filesystem — that
|
||||
capability belongs to the main agent and is delegated to the general-purpose
|
||||
subagent as an escape hatch. Keeping FS off the registry stacks avoids
|
||||
polluting their tool surface with FS tools they never act on.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..shared.permissions import PermissionContext
|
||||
from ..shared.resilience import ResilienceBundle
|
||||
from ..shared.todos import build_todos_mw
|
||||
|
||||
|
||||
def build_subagent_extras(
|
||||
*,
|
||||
permissions: PermissionContext,
|
||||
resilience: ResilienceBundle,
|
||||
) -> list[Any]:
|
||||
extras: list[Any] = [build_todos_mw()]
|
||||
if permissions.subagent_deny_mw is not None:
|
||||
extras.append(permissions.subagent_deny_mw)
|
||||
extras.extend(resilience.as_list())
|
||||
return extras
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
"""General-purpose subagent for the multi-agent main agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.shared.anthropic_cache import (
|
||||
build_anthropic_cache_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.compaction import (
|
||||
build_compaction_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.file_intent import (
|
||||
build_file_intent_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.filesystem import (
|
||||
build_filesystem_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.patch_tool_calls import (
|
||||
build_patch_tool_calls_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.permissions import (
|
||||
PermissionContext,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.resilience import (
|
||||
ResilienceBundle,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.shared.todos import build_todos_mw
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.middleware import MemoryInjectionMiddleware
|
||||
|
||||
NAME = "general-purpose"
|
||||
|
||||
|
||||
def build_subagent(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
tools: Sequence[BaseTool],
|
||||
backend_resolver: Any,
|
||||
filesystem_mode: FilesystemMode,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
permissions: PermissionContext,
|
||||
resilience: ResilienceBundle,
|
||||
memory_mw: MemoryInjectionMiddleware,
|
||||
) -> SubAgent:
|
||||
"""Deny + resilience inserts encapsulated here so the orchestrator never mutates the list."""
|
||||
middleware: list[Any] = [
|
||||
build_todos_mw(),
|
||||
memory_mw,
|
||||
build_file_intent_mw(llm),
|
||||
build_filesystem_mw(
|
||||
backend_resolver=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
build_compaction_mw(llm),
|
||||
build_patch_tool_calls_mw(),
|
||||
build_anthropic_cache_mw(),
|
||||
]
|
||||
|
||||
if permissions.subagent_deny_mw is not None:
|
||||
patch_idx = next(
|
||||
(
|
||||
i
|
||||
for i, m in enumerate(middleware)
|
||||
if isinstance(m, PatchToolCallsMiddleware)
|
||||
),
|
||||
len(middleware),
|
||||
)
|
||||
middleware.insert(patch_idx, permissions.subagent_deny_mw)
|
||||
|
||||
resilience_mws = resilience.as_list()
|
||||
if resilience_mws:
|
||||
cache_idx = next(
|
||||
(
|
||||
i
|
||||
for i, m in enumerate(middleware)
|
||||
if isinstance(m, AnthropicPromptCachingMiddleware)
|
||||
),
|
||||
len(middleware),
|
||||
)
|
||||
for offset, mw in enumerate(resilience_mws):
|
||||
middleware.insert(cache_idx + offset, mw)
|
||||
|
||||
spec: dict[str, Any] = {
|
||||
**GENERAL_PURPOSE_SUBAGENT,
|
||||
"model": llm,
|
||||
"tools": tools,
|
||||
"middleware": middleware,
|
||||
}
|
||||
if permissions.general_purpose_interrupt_on:
|
||||
spec["interrupt_on"] = permissions.general_purpose_interrupt_on
|
||||
return cast(SubAgent, spec)
|
||||
|
|
@ -31,7 +31,6 @@ from langchain.agents import create_agent
|
|||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
|
@ -77,6 +76,9 @@ from app.agents.new_chat.middleware import (
|
|||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.scoped_model_fallback import (
|
||||
ScopedModelFallbackMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
|
|
@ -792,15 +794,15 @@ def _build_compiled_agent_blocking(
|
|||
# Fallback chain — primary is the agent's own model; we add cheap
|
||||
# alternatives. Off by default; only the first call site that
|
||||
# configures the chain via env should enable it.
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
fallback_mw: ScopedModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
fallback_mw = ScopedModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
logging.warning("ScopedModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
model_call_limit_mw = (
|
||||
ModelCallLimitMiddleware(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
"""Fallback only on provider/network errors; let programming bugs raise."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
# Matched by class name across the MRO so we don't have to import every
|
||||
# provider SDK (openai/anthropic/google/...). Extend as new providers ship.
|
||||
_FALLBACK_ELIGIBLE_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"RateLimitError",
|
||||
"APIStatusError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"BadGatewayError",
|
||||
"GatewayTimeoutError",
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"RemoteProtocolError",
|
||||
"TimeoutError",
|
||||
"TimeoutException",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_fallback_eligible(exc: BaseException) -> bool:
|
||||
return any(cls.__name__ in _FALLBACK_ELIGIBLE_NAMES for cls in type(exc).__mro__)
|
||||
|
||||
|
||||
class ScopedModelFallbackMiddleware(ModelFallbackMiddleware):
|
||||
"""Re-raise non-provider exceptions instead of walking the fallback chain."""
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for fallback_model in self.models:
|
||||
try:
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[Any],
|
||||
handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
last_exception: Exception
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
|
||||
for fallback_model in self.models:
|
||||
try:
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e:
|
||||
if not _is_fallback_eligible(e):
|
||||
raise
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
|
@ -28,9 +28,7 @@ from langchain_core.messages import HumanMessage
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.multi_agent_chat import (
|
||||
create_surfsense_deep_agent as create_registry_deep_agent,
|
||||
)
|
||||
from app.agents.multi_agent_chat import create_multi_agent_chat_deep_agent
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
|
@ -577,6 +575,43 @@ async def _preflight_llm(llm: Any) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def _build_main_agent_for_thread(
|
||||
agent_factory: Any,
|
||||
*,
|
||||
llm: Any,
|
||||
search_space_id: int,
|
||||
db_session: Any,
|
||||
connector_service: ConnectorService,
|
||||
checkpointer: Any,
|
||||
user_id: str | None,
|
||||
thread_id: int | None,
|
||||
agent_config: AgentConfig | None,
|
||||
firecrawl_api_key: str | None,
|
||||
thread_visibility: ChatVisibility | None,
|
||||
filesystem_selection: FilesystemSelection | None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
) -> Any:
|
||||
"""Single (re)build path so the agent factory cannot drift across
|
||||
initial build, preflight repin, and mid-stream 429 recovery for one
|
||||
``thread_id``: a graph swap mid-turn would corrupt checkpointer state."""
|
||||
return await agent_factory(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=db_session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=thread_visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
|
||||
|
||||
async def _settle_speculative_agent_build(task: asyncio.Task[Any]) -> None:
|
||||
"""Wait for a discarded speculative agent build to release shared state.
|
||||
|
||||
|
|
@ -2767,7 +2802,7 @@ async def stream_new_chat(
|
|||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_registry_deep_agent
|
||||
create_multi_agent_chat_deep_agent
|
||||
if use_multi_agent
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
|
|
@ -2776,7 +2811,8 @@ async def stream_new_chat(
|
|||
# if preflight reports 429 we will discard this future and rebuild
|
||||
# against the freshly pinned config below.
|
||||
agent_build_task = asyncio.create_task(
|
||||
agent_factory(
|
||||
_build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -2787,9 +2823,9 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
),
|
||||
name="agent_build:stream_new_chat",
|
||||
)
|
||||
|
|
@ -3466,7 +3502,8 @@ async def stream_new_chat(
|
|||
title_task = None
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
agent = await _build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -3477,9 +3514,9 @@ async def stream_new_chat(
|
|||
agent_config=agent_config,
|
||||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
filesystem_selection=filesystem_selection,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Runtime rate-limit recovery repinned "
|
||||
|
|
@ -4130,12 +4167,13 @@ async def stream_resume_chat(
|
|||
|
||||
_t0 = time.perf_counter()
|
||||
agent_factory = (
|
||||
create_registry_deep_agent
|
||||
create_multi_agent_chat_deep_agent
|
||||
if _app_config.MULTI_AGENT_CHAT_ENABLED
|
||||
else create_surfsense_deep_agent
|
||||
)
|
||||
agent_build_task = asyncio.create_task(
|
||||
agent_factory(
|
||||
_build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -4224,7 +4262,8 @@ async def stream_resume_chat(
|
|||
"fallback_config_id": llm_config_id,
|
||||
},
|
||||
)
|
||||
agent = await agent_factory(
|
||||
agent = await _build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -4409,7 +4448,8 @@ async def stream_resume_chat(
|
|||
raise stream_exc
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
agent = await create_surfsense_deep_agent(
|
||||
agent = await _build_main_agent_for_thread(
|
||||
agent_factory,
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
@ -4421,6 +4461,7 @@ async def stream_resume_chat(
|
|||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
filesystem_selection=filesystem_selection,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_resume] Runtime rate-limit recovery repinned "
|
||||
|
|
|
|||
|
|
@ -0,0 +1,208 @@
|
|||
"""End-to-end resume-bridge tests against a real LangGraph subagent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Command, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubagentState(TypedDict, total=False):
|
||||
messages: list
|
||||
decision_text: str
|
||||
|
||||
|
||||
def _build_single_interrupt_subagent():
|
||||
def approve_node(state):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{
|
||||
"name": "do_thing",
|
||||
"args": {"x": 1},
|
||||
"description": "test action",
|
||||
}
|
||||
],
|
||||
"review_configs": [{}],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"messages": [AIMessage(content="done")],
|
||||
"decision_text": repr(decision),
|
||||
}
|
||||
|
||||
graph = StateGraph(_SubagentState)
|
||||
graph.add_node("approve", approve_node)
|
||||
graph.add_edge(START, "approve")
|
||||
graph.add_edge("approve", END)
|
||||
return graph.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
def _make_runtime(config: dict) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state={"messages": [HumanMessage(content="seed")]},
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="parent-tcid-1",
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
|
||||
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
|
||||
subagent = _build_single_interrupt_subagent()
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "approver",
|
||||
"description": "approves things",
|
||||
"runnable": subagent,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "shared-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
snap = await subagent.aget_state(parent_config)
|
||||
assert snap.tasks and snap.tasks[0].interrupts, (
|
||||
"fixture broken: subagent should be paused on its interrupt"
|
||||
)
|
||||
|
||||
parent_config["configurable"]["surfsense_resume_value"] = {
|
||||
"decisions": ["APPROVED"]
|
||||
}
|
||||
runtime = _make_runtime(parent_config)
|
||||
|
||||
result = await task_tool.coroutine(
|
||||
description="please approve",
|
||||
subagent_type="approver",
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
update = result.update
|
||||
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
|
||||
assert "surfsense_resume_value" not in parent_config["configurable"]
|
||||
|
||||
final = await subagent.aget_state(parent_config)
|
||||
assert not final.tasks or all(not t.interrupts for t in final.tasks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_interrupt_without_resume_value_raises_runtime_error():
|
||||
"""Bridge must fail loud rather than silently replay the user's interrupt."""
|
||||
subagent = _build_single_interrupt_subagent()
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "approver",
|
||||
"description": "approves things",
|
||||
"runnable": subagent,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "guard-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
snap = await subagent.aget_state(parent_config)
|
||||
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
|
||||
|
||||
runtime = _make_runtime(parent_config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="resume bridge is broken"):
|
||||
await task_tool.coroutine(
|
||||
description="please approve",
|
||||
subagent_type="approver",
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
|
||||
def _build_bundle_subagent():
|
||||
def bundle_node(state):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": "create_a", "args": {}, "description": ""},
|
||||
{"name": "create_b", "args": {}, "description": ""},
|
||||
{"name": "create_c", "args": {}, "description": ""},
|
||||
],
|
||||
"review_configs": [{}, {}, {}],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"messages": [AIMessage(content="bundle-done")],
|
||||
"decision_text": repr(decision),
|
||||
}
|
||||
|
||||
graph = StateGraph(_SubagentState)
|
||||
graph.add_node("bundle", bundle_node)
|
||||
graph.add_edge(START, "bundle")
|
||||
graph.add_edge("bundle", END)
|
||||
return graph.compile(checkpointer=InMemorySaver())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bundle_three_mixed_decisions_arrive_in_order():
|
||||
"""Approve / edit / reject for a 3-action bundle must land at ordinals 0/1/2."""
|
||||
subagent = _build_bundle_subagent()
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{
|
||||
"name": "bundler",
|
||||
"description": "creates a bundle",
|
||||
"runnable": subagent,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
parent_config: dict = {
|
||||
"configurable": {"thread_id": "bundle-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||
|
||||
decisions_payload = {
|
||||
"decisions": [
|
||||
{"type": "approve", "args": {}},
|
||||
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
|
||||
{"type": "reject", "args": {"message": "no thanks"}},
|
||||
]
|
||||
}
|
||||
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
|
||||
runtime = _make_runtime(parent_config)
|
||||
|
||||
result = await task_tool.coroutine(
|
||||
description="run bundle",
|
||||
subagent_type="bundler",
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
received = ast.literal_eval(result.update["decision_text"])
|
||||
assert received == decisions_payload
|
||||
assert received["decisions"][0]["type"] == "approve"
|
||||
assert received["decisions"][1]["type"] == "edit"
|
||||
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
|
||||
assert received["decisions"][2]["type"] == "reject"
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
"""Pins the first-wins assumption of ``get_first_pending_subagent_interrupt``.
|
||||
|
||||
The bridge currently relies on at-most-one pending interrupt per snapshot
|
||||
(sequential tool nodes). If parallel tool calls are ever enabled, the bridge
|
||||
needs an id-aware lookup; these tests will need to be revisited at that point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume import (
|
||||
get_first_pending_subagent_interrupt,
|
||||
)
|
||||
|
||||
|
||||
class TestGetFirstPendingSubagentInterrupt:
|
||||
def test_returns_first_when_multiple_top_level_interrupts_pending(self):
|
||||
first = SimpleNamespace(id="i-1", value={"decision": "approve"})
|
||||
second = SimpleNamespace(id="i-2", value={"decision": "reject"})
|
||||
state = SimpleNamespace(interrupts=(first, second), tasks=())
|
||||
|
||||
assert get_first_pending_subagent_interrupt(state) == (
|
||||
"i-1",
|
||||
{"decision": "approve"},
|
||||
)
|
||||
|
||||
def test_returns_first_when_multiple_subtask_interrupts_pending(self):
|
||||
first = SimpleNamespace(id="i-A", value="approve")
|
||||
second = SimpleNamespace(id="i-B", value="reject")
|
||||
sub_task = SimpleNamespace(interrupts=(first, second))
|
||||
state = SimpleNamespace(interrupts=(), tasks=(sub_task,))
|
||||
|
||||
assert get_first_pending_subagent_interrupt(state) == ("i-A", "approve")
|
||||
|
||||
def test_returns_none_when_no_interrupts(self):
|
||||
state = SimpleNamespace(interrupts=(), tasks=())
|
||||
|
||||
assert get_first_pending_subagent_interrupt(state) == (None, None)
|
||||
|
||||
def test_returns_none_when_state_is_none(self):
|
||||
assert get_first_pending_subagent_interrupt(None) == (None, None)
|
||||
|
||||
def test_skips_interrupts_with_none_value(self):
|
||||
empty = SimpleNamespace(id="i-empty", value=None)
|
||||
real = SimpleNamespace(id="i-real", value="approve")
|
||||
state = SimpleNamespace(interrupts=(empty, real), tasks=())
|
||||
|
||||
assert get_first_pending_subagent_interrupt(state) == ("i-real", "approve")
|
||||
|
||||
def test_normalizes_non_string_id_to_none(self):
|
||||
interrupt = SimpleNamespace(id=12345, value="approve")
|
||||
state = SimpleNamespace(interrupts=(interrupt,), tasks=())
|
||||
|
||||
assert get_first_pending_subagent_interrupt(state) == (None, "approve")
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
"""Resume side-channel must be read exactly once per turn."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
|
||||
consume_surfsense_resume,
|
||||
has_surfsense_resume,
|
||||
)
|
||||
|
||||
|
||||
def _runtime_with_config(config: dict) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state=None,
|
||||
context=None,
|
||||
config=config,
|
||||
stream_writer=None,
|
||||
tool_call_id="tcid-test",
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
class TestConsumeSurfsenseResume:
|
||||
def test_pops_value_on_first_call(self):
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
|
||||
)
|
||||
|
||||
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
|
||||
|
||||
def test_second_call_returns_none(self):
|
||||
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
|
||||
consume_surfsense_resume(runtime)
|
||||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
assert "surfsense_resume_value" not in configurable
|
||||
|
||||
def test_returns_none_when_no_payload_queued(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
|
||||
def test_returns_none_when_configurable_missing(self):
|
||||
runtime = _runtime_with_config({})
|
||||
|
||||
assert consume_surfsense_resume(runtime) is None
|
||||
|
||||
|
||||
class TestHasSurfsenseResume:
|
||||
def test_true_when_payload_queued(self):
|
||||
runtime = _runtime_with_config(
|
||||
{"configurable": {"surfsense_resume_value": "approve"}}
|
||||
)
|
||||
|
||||
assert has_surfsense_resume(runtime) is True
|
||||
|
||||
def test_does_not_consume_payload(self):
|
||||
configurable = {"surfsense_resume_value": "approve"}
|
||||
runtime = _runtime_with_config({"configurable": configurable})
|
||||
|
||||
has_surfsense_resume(runtime)
|
||||
|
||||
assert configurable == {"surfsense_resume_value": "approve"}
|
||||
|
||||
def test_false_when_payload_absent(self):
|
||||
runtime = _runtime_with_config({"configurable": {}})
|
||||
|
||||
assert has_surfsense_resume(runtime) is False
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import ModelFallbackMiddleware
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeMessagesListChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
|
||||
pack_subagent,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
"""Name matches the scoped-fallback eligibility allowlist."""
|
||||
|
||||
|
||||
class _AlwaysFailingChatModel(BaseChatModel):
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "always-failing-test-model"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "primary llm exploded"
|
||||
raise RateLimitError(msg)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "primary llm exploded"
|
||||
raise RateLimitError(msg)
|
||||
|
||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
||||
msg = "primary llm exploded"
|
||||
raise RateLimitError(msg)
|
||||
|
||||
async def _astream(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> AsyncIterator[ChatGeneration]:
|
||||
msg = "primary llm exploded"
|
||||
raise RateLimitError(msg)
|
||||
yield # pragma: no cover - unreachable, satisfies async generator typing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_recovers_when_primary_llm_fails():
|
||||
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
|
||||
primary = _AlwaysFailingChatModel()
|
||||
fallback = FakeMessagesListChatModel(
|
||||
responses=[AIMessage(content="recovered via fallback")]
|
||||
)
|
||||
|
||||
spec = pack_subagent(
|
||||
name="resilience_test",
|
||||
description="test subagent",
|
||||
system_prompt="be helpful",
|
||||
tools=[],
|
||||
model=primary,
|
||||
extra_middleware=[ModelFallbackMiddleware(fallback)],
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=spec["model"],
|
||||
tools=spec["tools"],
|
||||
middleware=spec["middleware"],
|
||||
system_prompt=spec["system_prompt"],
|
||||
)
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]})
|
||||
|
||||
final = result["messages"][-1]
|
||||
assert isinstance(final, AIMessage)
|
||||
assert final.content == "recovered via fallback"
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
"""``ScopedModelFallbackMiddleware`` triggers fallback only on provider errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
|
||||
class _RaisingChatModel(BaseChatModel):
|
||||
exc_to_raise: Any
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "raising-test-model"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise self.exc_to_raise
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise self.exc_to_raise
|
||||
|
||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGeneration]:
|
||||
raise self.exc_to_raise
|
||||
|
||||
async def _astream(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> AsyncIterator[ChatGeneration]:
|
||||
raise self.exc_to_raise
|
||||
yield # pragma: no cover - unreachable
|
||||
|
||||
|
||||
class _RecordingChatModel(BaseChatModel):
    """Test double that counts invocations and replies with a fixed string."""

    # Text returned as the AI reply on every call.
    response_text: str = "fallback-ok"
    # Number of times a generate path has been exercised.
    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        """Model identifier reported to LangChain callbacks/serialization."""
        return "recording-test-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Record the call and return one canned AI message."""
        self.call_count += 1
        reply = AIMessage(content=self.response_text)
        return ChatResult(generations=[ChatGeneration(message=reply)])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async path delegates to the synchronous implementation."""
        return self._generate(messages, stop, None, **kwargs)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Synthetic provider error whose class name is on the scoped-fallback allowlist."""
|
||||
|
||||
|
||||
def _build_agent(primary: BaseChatModel, fallback: BaseChatModel):
    """Compile a tool-less agent wiring *fallback* behind the scoped middleware.

    Imports are deferred to call time so importing this test module stays
    cheap and side-effect free.
    """
    from langchain.agents import create_agent

    from app.agents.new_chat.middleware.scoped_model_fallback import (
        ScopedModelFallbackMiddleware,
    )

    middleware_stack = [ScopedModelFallbackMiddleware(fallback)]
    return create_agent(
        model=primary,
        tools=[],
        middleware=middleware_stack,
        system_prompt="be helpful",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_provider_errors_trigger_fallback():
    """An allowlisted provider exception must reroute the turn to the fallback model."""
    failing_primary = _RaisingChatModel(exc_to_raise=RateLimitError("429 from provider"))
    recording_fallback = _RecordingChatModel(response_text="recovered")

    agent = _build_agent(failing_primary, recording_fallback)
    result = await agent.ainvoke({"messages": [("user", "hi")]})

    last_message = result["messages"][-1]
    assert isinstance(last_message, AIMessage)
    assert last_message.content == "recovered"
    # Exactly one recovery call should have reached the fallback model.
    assert recording_fallback.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_programming_errors_propagate_without_invoking_fallback():
    """A non-allowlisted exception must bubble up untouched, leaving the fallback idle."""
    failing_primary = _RaisingChatModel(exc_to_raise=KeyError("missing_state_field"))
    idle_fallback = _RecordingChatModel(response_text="should-never-arrive")

    agent = _build_agent(failing_primary, idle_fallback)

    with pytest.raises(KeyError, match="missing_state_field"):
        await agent.ainvoke({"messages": [("user", "hi")]})

    # The fallback chain must never have been consulted.
    assert idle_fallback.call_count == 0
|
||||
|
|
@@ -202,6 +202,15 @@ class FakeBudgetLLM:
|
|||
|
||||
|
||||
class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||
@pytest.fixture(autouse=True)
def _disable_planner_runnable(self, monkeypatch):
    """Pin every test in this class to the legacy ``self.llm.ainvoke`` planner path.

    With the planner-Runnable flag enabled, ``create_agent`` calls
    ``.bind()`` on the model — which the duck-typed ``FakeLLM`` mock does
    not implement — and these tests additionally assert against
    ``llm.calls[0]["config"]``, which only the legacy path records.
    """
    monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")
|
||||
|
||||
def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
|
||||
messages = [
|
||||
HumanMessage(content="old user context " * 40),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue