Compose supervisor LangChain agent with SurfSense middleware and connector discovery.

This commit is contained in:
CREDO23 2026-04-30 03:53:22 +02:00
parent 33fc457dcc
commit 2ab4c411fe
5 changed files with 524 additions and 15 deletions

View file

@ -86,7 +86,12 @@ def partition_mcp_tools_by_expert_route(
connector_id_to_type: dict[int, str],
connector_name_to_type: dict[str, str],
) -> dict[str, list[BaseTool]]:
"""Bucket MCP tools by expert route key. Supervisor never receives raw MCP tools."""
"""Bucket MCP tools by expert route key. Supervisor never receives raw MCP tools.
Same inclusion rule as ``new_chat.tools.registry.build_tools_async``: all tools returned by
``load_mcp_tools`` are partitioned; connector availability for **registry** builtins is handled via
``get_connector_gated_tools`` / routing gates; MCP tools are not pre-filtered by inventory here.
"""
buckets: dict[str, list[BaseTool]] = defaultdict(list)
for tool in tools:

View file

@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
@ -9,18 +11,69 @@ from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import ChatVisibility
from app.agents.new_chat.chat_deepagent import _map_connectors_to_searchable_types
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.tools.mcp_tool import load_mcp_tools
from app.db import ChatVisibility
from app.agents.multi_agent_chat.core.mcp_partition import (
fetch_mcp_connector_metadata_maps,
partition_mcp_tools_by_expert_route,
)
from app.agents.multi_agent_chat.core.registry import build_registry_dependencies
from app.agents.multi_agent_chat.middleware.supervisor_stack import build_supervisor_middleware_stack
from app.agents.multi_agent_chat.routing.supervisor_routing import build_supervisor_routing_tools
from app.agents.multi_agent_chat.supervisor import build_supervisor_agent
logger = logging.getLogger(__name__)
def _compile_supervisor_chat_blocking(
    *,
    llm: BaseChatModel,
    routing_tools: list[BaseTool],
    checkpointer: Checkpointer | None,
    backend_resolver: Any,
    filesystem_mode: Any,
    search_space_id: int,
    user_id: str,
    thread_id: str | None,
    thread_visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
) -> Any:
    """Assemble the SurfSense middleware stack and compile the supervisor agent.

    CPU-heavy (middleware assembly + ``create_agent``), so it is intended to be
    invoked via ``asyncio.to_thread`` and never run directly on the event loop.
    Returns the compiled supervisor graph.
    """
    stack = build_supervisor_middleware_stack(
        llm=llm,
        tools=routing_tools,
        backend_resolver=backend_resolver,
        filesystem_mode=filesystem_mode,
        search_space_id=search_space_id,
        user_id=user_id,
        thread_id=thread_id,
        visibility=thread_visibility,
        anon_session_id=anon_session_id,
        available_connectors=available_connectors,
        available_document_types=available_document_types,
        mentioned_document_ids=mentioned_document_ids,
        max_input_tokens=max_input_tokens,
        flags=get_flags(),  # resolved here, inside the worker thread
    )
    return build_supervisor_agent(
        llm,
        tools=routing_tools,
        checkpointer=checkpointer,
        middleware=stack,
        context_schema=SurfSenseContextSchema,
    )
async def create_multi_agent_chat(
llm: BaseChatModel,
@ -36,18 +89,54 @@ async def create_multi_agent_chat(
available_document_types: list[str] | None = None,
thread_visibility: ChatVisibility = ChatVisibility.PRIVATE,
include_mcp_tools: bool = True,
filesystem_selection: FilesystemSelection | None = None,
anon_session_id: str | None = None,
mentioned_document_ids: list[int] | None = None,
max_input_tokens: int | None = None,
surfsense_stack: bool = True,
):
"""Build the full multi-agent chat graph (supervisor + domain subgraphs via routing tools).
**Builtins** (:mod:`expert_agent.builtins`): registry-grouped **categories** (research, memory, deliverables).
**Connectors** (:mod:`expert_agent.connectors`): **vendor integrations** one subgraph each where split
(e.g. Gmail, Calendar, Discord, Teams, Notion, Confluence, Google Drive, Dropbox, OneDrive, Luma).
**Connectors** (:mod:`expert_agent.connectors`): **vendor integrations** one subgraph per route in
``TOOL_NAMES_BY_CATEGORY`` (e.g. calendar, confluence, discord, dropbox, gmail, google_drive, luma, notion, onedrive, teams).
MCP tools from ``new_chat`` (``load_mcp_tools``) are partitioned inside this package and attached only
to the matching expert subgraphs not to the supervisor tool list as raw MCP calls.
to the matching expert subgraphs — not to the supervisor tool list as raw MCP calls. Inclusion matches
``new_chat.tools.registry.build_tools_async``: all tools returned by ``load_mcp_tools`` are merged
after partitioning (no extra inventory filter on MCP). Connector routing uses ``available_connectors``:
pass explicitly, or provide ``connector_service`` so lists are resolved like
``create_surfsense_deep_agent`` (``get_available_connectors`` searchable types).
Deliverables (thread-scoped reports, podcasts, etc.) are registered only when ``thread_id`` is set.
When ``surfsense_stack`` is true (default), the supervisor uses the same SurfSense middleware shell as
``new_chat`` (KB priority/tree, filesystem, compaction, permissions, etc.) except ``SubAgentMiddleware`` /
``task``, since experts are separate graphs behind routing tools. Graph compilation runs in
``asyncio.to_thread`` so heavy CPU work does not block the event loop.
"""
resolved_connectors = available_connectors
resolved_doc_types = available_document_types
if connector_service is not None:
try:
if resolved_connectors is None:
connector_types = await connector_service.get_available_connectors(
search_space_id
)
if connector_types:
resolved_connectors = _map_connectors_to_searchable_types(
connector_types
)
if resolved_doc_types is None:
resolved_doc_types = (
await connector_service.get_available_document_types(search_space_id)
)
except Exception as exc:
logger.warning(
"Failed to discover available connectors/document types: %s",
exc,
)
mcp_tools_by_route: dict[str, list[BaseTool]] | None = None
if include_mcp_tools:
mcp_flat = await load_mcp_tools(db_session, search_space_id)
@ -62,8 +151,8 @@ async def create_multi_agent_chat(
llm=llm,
firecrawl_api_key=firecrawl_api_key,
connector_service=connector_service,
available_connectors=available_connectors,
available_document_types=available_document_types,
available_connectors=resolved_connectors,
available_document_types=resolved_doc_types,
thread_visibility=thread_visibility,
)
routing_tools = build_supervisor_routing_tools(
@ -71,5 +160,31 @@ async def create_multi_agent_chat(
registry_dependencies=registry_dependencies,
include_deliverables=thread_id is not None,
mcp_tools_by_route=mcp_tools_by_route,
available_connectors=resolved_connectors,
)
fs_sel = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(fs_sel, search_space_id=search_space_id)
if not surfsense_stack:
return build_supervisor_agent(
llm, tools=routing_tools, checkpointer=checkpointer
)
return await asyncio.to_thread(
_compile_supervisor_chat_blocking,
llm=llm,
routing_tools=routing_tools,
checkpointer=checkpointer,
backend_resolver=backend_resolver,
filesystem_mode=fs_sel.mode,
search_space_id=search_space_id,
user_id=user_id,
thread_id=thread_id,
thread_visibility=thread_visibility,
anon_session_id=anon_session_id,
available_connectors=resolved_connectors,
available_document_types=resolved_doc_types,
mentioned_document_ids=mentioned_document_ids,
max_input_tokens=max_input_tokens,
)
return build_supervisor_agent(llm, tools=routing_tools, checkpointer=checkpointer)

View file

@ -0,0 +1,11 @@
"""SurfSense supervisor middleware (parity with ``new_chat`` main agent, minus subagents)."""
from app.agents.multi_agent_chat.middleware.supervisor_stack import (
build_supervisor_middleware_stack,
parse_thread_id_for_action_log,
)
__all__ = [
"build_supervisor_middleware_stack",
"parse_thread_id_for_action_log",
]

View file

@ -0,0 +1,363 @@
"""Supervisor middleware stack matching ``new_chat`` main agent (no ``SubAgentMiddleware`` / ``task``)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
from deepagents.backends import StateBackend
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.skills import SkillsMiddleware
from langchain.agents.middleware import (
LLMToolSelectorMiddleware,
ModelCallLimitMiddleware,
ModelFallbackMiddleware,
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware import (
ActionLogMiddleware,
AnonymousDocumentMiddleware,
BusyMutexMiddleware,
ClearToolUsesEdit,
DedupHITLToolCallsMiddleware,
DoomLoopMiddleware,
FileIntentMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
MemoryInjectionMiddleware,
NoopInjectionMiddleware,
OtelSpanMiddleware,
RetryAfterMiddleware,
SpillingContextEditingMiddleware,
SpillToBackendEdit,
SurfSenseFilesystemMiddleware,
ToolCallNameRepairMiddleware,
build_skills_backend_factory,
create_surfsense_compaction_middleware,
default_skills_sources,
)
from app.agents.new_chat.plugin_loader import (
PluginContext,
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.tools.registry import BUILTIN_TOOLS
from app.db import ChatVisibility
logger = logging.getLogger(__name__)
# Routing tools with heavy outputs — never prune via context editing when bound.
_SUPERVISOR_PRUNE_PROTECTED: frozenset[str] = frozenset(
{
"deliverables",
"invalid",
# Align with single-agent surfacing of costly connector reads if names overlap later.
"read_email",
"search_emails",
"generate_report",
"generate_resume",
"generate_podcast",
"generate_video_presentation",
"generate_image",
}
)
def _safe_exclude_tools_supervisor(tools: Sequence[BaseTool]) -> tuple[str, ...]:
enabled = {t.name for t in tools}
return tuple(n for n in _SUPERVISOR_PRUNE_PROTECTED if n in enabled)
def parse_thread_id_for_action_log(thread_id: int | str | None) -> int | None:
    """Coerce *thread_id* to a numeric DB id, or ``None`` when not numeric.

    Only numeric DB thread ids can be action-logged (the log row carries a
    foreign key); UUID-style string ids therefore map to ``None``.
    """
    if thread_id is None:
        return None
    if isinstance(thread_id, int):
        return thread_id
    text = str(thread_id).strip()
    return int(text) if text.isdigit() else None
def build_supervisor_middleware_stack(
    *,
    llm: BaseChatModel,
    tools: Sequence[BaseTool],
    backend_resolver: Any,
    filesystem_mode: FilesystemMode,
    search_space_id: int,
    user_id: str | None,
    thread_id: int | str | None,
    visibility: ChatVisibility,
    anon_session_id: str | None,
    available_connectors: list[str] | None,
    available_document_types: list[str] | None,
    mentioned_document_ids: list[int] | None,
    max_input_tokens: int | None,
    flags: AgentFeatureFlags | None = None,
) -> list[Any]:
    """Build the middleware list for the multi-agent supervisor.

    Parity with ``new_chat``'s ``_build_compiled_agent_blocking`` minus
    ``SubAgentMiddleware`` / ``task`` (experts are separate graphs behind
    routing tools, so no subagent middleware is installed here).

    Args:
        llm: Chat model shared by LLM-backed middlewares (compaction,
            knowledge priority/tree, file intent, plugin context).
        tools: Supervisor routing tools; used to derive prune-protected
            names, repairable tool names, and HITL dedup inputs.
        backend_resolver: Filesystem backend resolver (opaque here;
            forwarded to spill/filesystem middlewares).
        filesystem_mode: Gates the cloud-only middlewares (anonymous
            documents, knowledge tree, KB persistence, skills scoping).
        search_space_id: Scope for knowledge/skills/action-log middlewares.
        user_id: May be ``None`` for anonymous sessions.
        thread_id: Only numeric ids enable action logging (see
            ``parse_thread_id_for_action_log``).
        visibility: Thread visibility, forwarded to memory and plugins.
        anon_session_id: Session key for ``AnonymousDocumentMiddleware``.
        available_connectors: Connector types for knowledge prioritization.
        available_document_types: Document types for knowledge prioritization.
        mentioned_document_ids: Explicitly referenced documents to prioritize.
        max_input_tokens: Must be truthy for context editing to be enabled;
            spill/clear thresholds are derived from it.
        flags: Optional pre-resolved feature flags; resolved via
            ``get_flags()`` when omitted.

    Returns:
        Ordered middleware list with all disabled (``None``) entries removed.
    """
    flags = flags or get_flags()
    # Always-on: per-user memory injection scoped to this search space.
    memory_middleware = MemoryInjectionMiddleware(
        user_id=user_id,
        search_space_id=search_space_id,
        thread_visibility=visibility,
    )
    summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
    # NOTE(review): deliberate no-op read — compaction-v2 gating is applied via
    # ``noop_mw`` below; this appears to only "touch" the flag. TODO confirm intent.
    _ = flags.enable_compaction_v2
    # Context editing (spill-to-backend first, clear as fallback) is enabled
    # only when a token budget is known: trigger at 55% of the budget, trim at
    # least 15%, always keeping the 5 most recent tool results.
    context_edit_mw = None
    if (
        flags.enable_context_editing
        and not flags.disable_new_agent_stack
        and max_input_tokens
    ):
        spill_edit = SpillToBackendEdit(
            trigger=int(max_input_tokens * 0.55),
            clear_at_least=int(max_input_tokens * 0.15),
            keep=5,
            exclude_tools=_safe_exclude_tools_supervisor(tools),
            clear_tool_inputs=True,
        )
        clear_edit = ClearToolUsesEdit(
            trigger=int(max_input_tokens * 0.55),
            clear_at_least=int(max_input_tokens * 0.15),
            keep=5,
            exclude_tools=_safe_exclude_tools_supervisor(tools),
            clear_tool_inputs=True,
            placeholder="[cleared - older tool output trimmed for context]",
        )
        context_edit_mw = SpillingContextEditingMiddleware(
            edits=[spill_edit, clear_edit],  # order: spill before destructive clear
            backend_resolver=backend_resolver,
        )
    # Flag-gated resilience middlewares; all skipped under disable_new_agent_stack.
    retry_mw = (
        RetryAfterMiddleware(max_retries=3)
        if flags.enable_retry_after and not flags.disable_new_agent_stack
        else None
    )
    fallback_mw: ModelFallbackMiddleware | None = None
    if flags.enable_model_fallback and not flags.disable_new_agent_stack:
        try:
            fallback_mw = ModelFallbackMiddleware(
                "openai:gpt-4o-mini",
                "anthropic:claude-3-5-haiku-20241022",
            )
        except Exception:
            # Fallback models are best-effort; a bad provider spec must not
            # block supervisor construction.
            logger.warning("ModelFallbackMiddleware init failed; skipping.")
            fallback_mw = None
    model_call_limit_mw = (
        ModelCallLimitMiddleware(
            thread_limit=120,
            run_limit=80,
            exit_behavior="end",
        )
        if flags.enable_model_call_limit and not flags.disable_new_agent_stack
        else None
    )
    tool_call_limit_mw = (
        ToolCallLimitMiddleware(
            thread_limit=300, run_limit=80, exit_behavior="continue"
        )
        if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
        else None
    )
    noop_mw = (
        NoopInjectionMiddleware()
        if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
        else None
    )
    # Tool-call name repair: the known-good name set is routing tools plus the
    # builtin filesystem/todo tool names (no fuzzy matching — see below).
    repair_mw = None
    if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
        registered_names: set[str] = {t.name for t in tools}
        registered_names |= {
            "write_todos",
            "ls",
            "read_file",
            "write_file",
            "edit_file",
            "glob",
            "grep",
            "execute",
            # No ``task`` — multi-agent uses routing tools instead of SubAgentMiddleware.
        }
        repair_mw = ToolCallNameRepairMiddleware(
            registered_tool_names=registered_names,
            fuzzy_match_threshold=None,
        )
    doom_loop_mw = (
        DoomLoopMiddleware(threshold=3)
        if flags.enable_doom_loop and not flags.disable_new_agent_stack
        else None
    )
    # Action logging requires a numeric DB thread id (FK target); UUID-style
    # thread ids resolve to None and the middleware is skipped.
    thread_id_action_log = parse_thread_id_for_action_log(thread_id)
    action_log_mw: ActionLogMiddleware | None = None
    if (
        flags.enable_action_log
        and not flags.disable_new_agent_stack
        and thread_id_action_log is not None
    ):
        try:
            tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
            action_log_mw = ActionLogMiddleware(
                thread_id=thread_id_action_log,
                search_space_id=search_space_id,
                user_id=user_id,
                tool_definitions=tool_defs_by_name,
            )
        except Exception:  # pragma: no cover - defensive
            logger.warning(
                "ActionLogMiddleware init failed; running without it.",
                exc_info=True,
            )
            action_log_mw = None
    busy_mutex_mw: BusyMutexMiddleware | None = (
        BusyMutexMiddleware()
        if flags.enable_busy_mutex and not flags.disable_new_agent_stack
        else None
    )
    otel_mw: OtelSpanMiddleware | None = (
        OtelSpanMiddleware()
        if flags.enable_otel and not flags.disable_new_agent_stack
        else None
    )
    # Plugin middlewares come from an env allowlist; any loader failure is
    # non-fatal and simply yields an empty plugin list.
    plugin_middlewares: list[Any] = []
    if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
        try:
            allowed_names = load_allowed_plugin_names_from_env()
            if allowed_names:
                plugin_middlewares = load_plugin_middlewares(
                    PluginContext.build(
                        search_space_id=search_space_id,
                        user_id=user_id,
                        thread_visibility=visibility,
                        llm=llm,
                    ),
                    allowed_plugin_names=allowed_names,
                )
        except Exception:  # pragma: no cover - defensive
            logger.warning(
                "Plugin loader failed; continuing without plugins.",
                exc_info=True,
            )
            plugin_middlewares = []
    # Skills backend is search-space-scoped only in CLOUD filesystem mode.
    skills_mw: SkillsMiddleware | None = None
    if flags.enable_skills and not flags.disable_new_agent_stack:
        try:
            skills_factory = build_skills_backend_factory(
                search_space_id=search_space_id
                if filesystem_mode == FilesystemMode.CLOUD
                else None,
            )
            skills_mw = SkillsMiddleware(
                backend=skills_factory,
                sources=default_skills_sources(),
            )
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("SkillsMiddleware init failed; skipping: %s", exc)
            skills_mw = None
    # LLM-driven tool selection is only worthwhile past 30 bound tools; the
    # core routing tools are always kept in the selection when present.
    names = {t.name for t in tools}
    selector_mw: LLMToolSelectorMiddleware | None = None
    if (
        flags.enable_llm_tool_selector
        and not flags.disable_new_agent_stack
        and len(tools) > 30
    ):
        try:
            selector_mw = LLMToolSelectorMiddleware(
                model="openai:gpt-4o-mini",
                max_tools=12,
                always_include=[
                    n
                    for n in (
                        "research",
                        "memory",
                        "update_memory",
                        "get_connected_accounts",
                        "scrape_webpage",
                    )
                    if n in names
                ],
            )
        except Exception:
            logger.warning("LLMToolSelectorMiddleware init failed; skipping.")
            selector_mw = None
    # Final ordered stack; disabled features are None placeholders and are
    # filtered out at the end. Middleware position is significant: mutex and
    # tracing first, knowledge/filesystem shaping in the middle, limits and
    # context editing after, repair/logging near the end, prompt caching last.
    deepagent_middleware = [
        busy_mutex_mw,
        otel_mw,
        TodoListMiddleware(),
        memory_middleware,
        AnonymousDocumentMiddleware(anon_session_id=anon_session_id)
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        KnowledgeTreeMiddleware(
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            llm=llm,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        KnowledgePriorityMiddleware(
            llm=llm,
            search_space_id=search_space_id,
            filesystem_mode=filesystem_mode,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            mentioned_document_ids=mentioned_document_ids,
        ),
        FileIntentMiddleware(llm=llm),
        SurfSenseFilesystemMiddleware(
            backend=backend_resolver,
            filesystem_mode=filesystem_mode,
            search_space_id=search_space_id,
            created_by_id=user_id,
            thread_id=thread_id,
        ),
        KnowledgeBasePersistenceMiddleware(
            search_space_id=search_space_id,
            created_by_id=user_id,
            filesystem_mode=filesystem_mode,
        )
        if filesystem_mode == FilesystemMode.CLOUD
        else None,
        skills_mw,
        selector_mw,
        model_call_limit_mw,
        tool_call_limit_mw,
        context_edit_mw,
        summarization_mw,
        noop_mw,
        retry_mw,
        fallback_mw,
        repair_mw,
        doom_loop_mw,
        action_log_mw,
        PatchToolCallsMiddleware(),
        DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
        *plugin_middlewares,
        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
    ]
    return [m for m in deepagent_middleware if m is not None]

View file

@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import app.agents.multi_agent_chat.supervisor as supervisor_pkg
@ -19,12 +20,26 @@ def build_supervisor_agent(
*,
tools: Sequence[BaseTool],
checkpointer: Checkpointer | None = None,
middleware: Sequence[Any] | None = None,
context_schema: Any | None = None,
):
"""Compile the supervisor **agent** (graph). ``tools`` = output of ``build_supervisor_routing_tools``."""
system_prompt = read_prompt_md(supervisor_pkg.__name__, "supervisor_prompt")
return create_agent(
llm,
system_prompt=system_prompt,
tools=list(tools),
checkpointer=checkpointer,
)
kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"tools": list(tools),
"checkpointer": checkpointer,
}
if middleware is not None:
kwargs["middleware"] = list(middleware)
if context_schema is not None:
kwargs["context_schema"] = context_schema
agent = create_agent(llm, **kwargs)
if middleware is not None or context_schema is not None:
return agent.with_config(
{
"recursion_limit": 10_000,
"metadata": {"ls_integration": "multi_agent_supervisor"},
}
)
return agent