mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
multi_agent_chat/loader: MCP tools as flat list[BaseTool] per agent
This commit is contained in:
parent
5a00df8e48
commit
014801c764
7 changed files with 30 additions and 133 deletions
|
|
@ -14,9 +14,6 @@ from langgraph.types import Checkpointer
|
|||
from app.agents.multi_agent_chat.middleware.stack import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
|
|
@ -42,7 +39,7 @@ def build_compiled_agent_graph_sync(
|
|||
flags: AgentFeatureFlags,
|
||||
checkpointer: Checkpointer,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
||||
mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
):
|
||||
"""Sync compile: middleware + ``create_agent`` (run via ``asyncio.to_thread``)."""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import ToolsPermissions
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
get_cache,
|
||||
|
|
@ -25,14 +24,12 @@ from app.db import ChatVisibility
|
|||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
||||
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, ToolsPermissions]) -> str:
|
||||
def mcp_signature(mcp_tools_by_agent: dict[str, list[BaseTool]]) -> str:
|
||||
"""Hash the per-agent MCP tool surface so a change rotates the cache key."""
|
||||
rows = []
|
||||
for agent_name in sorted(mcp_tools_by_agent.keys()):
|
||||
perms = mcp_tools_by_agent[agent_name]
|
||||
allow_names = sorted(item.get("name", "") for item in perms.get("allow", []))
|
||||
ask_names = sorted(item.get("name", "") for item in perms.get("ask", []))
|
||||
rows.append((agent_name, allow_names, ask_names))
|
||||
names = sorted(getattr(t, "name", "") or "" for t in mcp_tools_by_agent[agent_name])
|
||||
rows.append((agent_name, names))
|
||||
return stable_hash(rows)
|
||||
|
||||
|
||||
|
|
@ -55,7 +52,7 @@ async def build_agent_with_cache(
|
|||
flags: AgentFeatureFlags,
|
||||
checkpointer: Checkpointer,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions],
|
||||
mcp_tools_by_agent: dict[str, list[BaseTool]],
|
||||
disabled_tools: list[str] | None,
|
||||
config_id: str | None,
|
||||
) -> Any:
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ async def create_multi_agent_chat_deep_agent(
|
|||
)
|
||||
mcp_tools_by_agent = {}
|
||||
_perf_log.info(
|
||||
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d buckets)",
|
||||
"[create_agent] load_mcp_tools_by_connector in %.3fs (%d agents)",
|
||||
time.perf_counter() - _t0,
|
||||
len(mcp_tools_by_agent),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.agent import
|
|||
from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
|
||||
build_ask_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import ToolsPermissions
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
|
@ -85,7 +84,7 @@ def build_main_agent_deepagent_middleware(
|
|||
flags: AgentFeatureFlags,
|
||||
subagent_dependencies: dict[str, Any],
|
||||
checkpointer: Checkpointer,
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
||||
mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Load MCP tools, partition by connector agent, apply each subagent's allow/ask permissions."""
|
||||
"""Load MCP tools and partition them by connector agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
"""Discover MCP tools, bucket by connector agent, apply each subagent's allow/ask permissions."""
|
||||
"""Discover MCP tools and bucket them by connector agent.
|
||||
|
||||
Tool gating is no longer the loader's concern: each subagent declares its
|
||||
own :class:`Ruleset` and the per-subagent :class:`PermissionMiddleware`
|
||||
enforces it at runtime. This module just routes flat ``BaseTool`` lists
|
||||
to the right subagents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -15,46 +21,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.agents.multi_agent_chat.constants import (
|
||||
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.connectors.airtable.tools.index import (
|
||||
load_tools as _airtable_permissions,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.connectors.clickup.tools.index import (
|
||||
load_tools as _clickup_permissions,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.connectors.jira.tools.index import (
|
||||
load_tools as _jira_permissions,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.connectors.linear.tools.index import (
|
||||
load_tools as _linear_permissions,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.connectors.slack.tools.index import (
|
||||
load_tools as _slack_permissions,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.middleware_gated import (
|
||||
middleware_gated_tool_permission_row,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import (
|
||||
ToolPermissionItem,
|
||||
ToolsPermissions,
|
||||
)
|
||||
from app.agents.new_chat.tools.mcp_tool import load_mcp_tools
|
||||
from app.db import SearchSourceConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_MCP_PERMISSIONS_BY_AGENT: dict[str, ToolsPermissions] = {
|
||||
"airtable": _airtable_permissions(),
|
||||
"clickup": _clickup_permissions(),
|
||||
"jira": _jira_permissions(),
|
||||
"linear": _linear_permissions(),
|
||||
"slack": _slack_permissions(),
|
||||
}
|
||||
|
||||
|
||||
## Helper functions for fetching connector metadata maps
|
||||
|
||||
|
||||
async def fetch_mcp_connector_metadata_maps(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -80,9 +52,6 @@ async def fetch_mcp_connector_metadata_maps(
|
|||
return id_to_type, name_to_type
|
||||
|
||||
|
||||
## Helper functions for partitioning tools by connector agent
|
||||
|
||||
|
||||
def partition_mcp_tools_by_connector(
|
||||
tools: Sequence[BaseTool],
|
||||
connector_id_to_type: dict[int, str],
|
||||
|
|
@ -130,59 +99,15 @@ def partition_mcp_tools_by_connector(
|
|||
return dict(buckets)
|
||||
|
||||
|
||||
## Helper functions for splitting tools by permissions
|
||||
|
||||
|
||||
def _get_mcp_tool_name(tool: BaseTool) -> str:
|
||||
meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
|
||||
orig = meta.get("mcp_original_tool_name")
|
||||
if isinstance(orig, str) and orig:
|
||||
return orig
|
||||
return getattr(tool, "name", "") or ""
|
||||
|
||||
|
||||
def _split_tools_by_permissions(
|
||||
tools: Sequence[BaseTool],
|
||||
perms: ToolsPermissions,
|
||||
) -> ToolsPermissions:
|
||||
allow_names = frozenset(r["name"] for r in perms["allow"])
|
||||
ask_names = frozenset(r["name"] for r in perms["ask"])
|
||||
allow: list[ToolPermissionItem] = []
|
||||
ask: list[ToolPermissionItem] = []
|
||||
for t in tools:
|
||||
meta: dict[str, Any] = getattr(t, "metadata", None) or {}
|
||||
if meta.get("hitl") is False:
|
||||
allow.append(middleware_gated_tool_permission_row(t))
|
||||
continue
|
||||
key = _get_mcp_tool_name(t)
|
||||
if key in allow_names:
|
||||
allow.append(middleware_gated_tool_permission_row(t))
|
||||
elif key in ask_names:
|
||||
ask.append(middleware_gated_tool_permission_row(t))
|
||||
else:
|
||||
ask.append(middleware_gated_tool_permission_row(t))
|
||||
return {"allow": allow, "ask": ask}
|
||||
|
||||
|
||||
## Main function to load MCP tools and split them by permissions for each connector agent
|
||||
|
||||
|
||||
async def load_mcp_tools_by_connector(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> dict[str, ToolsPermissions]:
|
||||
"""Load MCP tools and split rows per subagent's own allow/ask permissions.
|
||||
) -> dict[str, list[BaseTool]]:
|
||||
"""Load MCP tools and route them to each subagent as a flat list.
|
||||
|
||||
Pass ``bypass_internal_hitl=True`` so the subagent's
|
||||
``HumanInTheLoopMiddleware`` is the single HITL gate.
|
||||
``bypass_internal_hitl=True`` is set so tool gating is uniformly the
|
||||
consuming subagent's :class:`PermissionMiddleware` responsibility.
|
||||
"""
|
||||
flat = await load_mcp_tools(session, search_space_id, bypass_internal_hitl=True)
|
||||
id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id)
|
||||
buckets = partition_mcp_tools_by_connector(flat, id_map, name_map)
|
||||
return {
|
||||
agent: _split_tools_by_permissions(
|
||||
tools,
|
||||
_MCP_PERMISSIONS_BY_AGENT.get(agent, {"allow": [], "ask": []}),
|
||||
)
|
||||
for agent, tools in buckets.items()
|
||||
}
|
||||
return partition_mcp_tools_by_connector(flat, id_map, name_map)
|
||||
|
|
|
|||
|
|
@ -72,9 +72,6 @@ from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
|
|||
read_md_file,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec
|
||||
from app.agents.multi_agent_chat.subagents.shared.tool_kinds import (
|
||||
ToolsPermissions,
|
||||
)
|
||||
|
||||
|
||||
class SubagentBuilder(Protocol):
|
||||
|
|
@ -84,20 +81,8 @@ class SubagentBuilder(Protocol):
|
|||
dependencies: dict[str, Any],
|
||||
model: BaseChatModel | None = None,
|
||||
middleware_stack: dict[str, Any] | None = None,
|
||||
extra_tools_bucket: ToolsPermissions | None = None,
|
||||
) -> SubAgent | SurfSenseSubagentSpec: ...
|
||||
|
||||
|
||||
def _unwrap_spec(result: SubAgent | SurfSenseSubagentSpec) -> SubAgent:
|
||||
"""Project a builder's return value down to the deepagents-shaped dict.
|
||||
|
||||
Transitional helper while subagents migrate to ``SurfSenseSubagentSpec``.
|
||||
Once every builder returns the new container, this becomes a single
|
||||
``return result.spec``.
|
||||
"""
|
||||
if isinstance(result, SurfSenseSubagentSpec):
|
||||
return result.spec
|
||||
return result
|
||||
mcp_tools: list[BaseTool] | None = None,
|
||||
) -> SurfSenseSubagentSpec: ...
|
||||
|
||||
|
||||
SUBAGENT_BUILDERS_BY_NAME: dict[str, SubagentBuilder] = {
|
||||
|
|
@ -167,7 +152,7 @@ def _filter_disabled_tools_in_place(
|
|||
spec: SubAgent,
|
||||
disabled_names: frozenset[str],
|
||||
) -> None:
|
||||
"""Drop UI-disabled tools from ``spec["tools"]`` and ``spec["interrupt_on"]``."""
|
||||
"""Drop UI-disabled tools from ``spec["tools"]``."""
|
||||
if not disabled_names:
|
||||
return
|
||||
tools = spec.get("tools") # type: ignore[typeddict-item]
|
||||
|
|
@ -175,11 +160,6 @@ def _filter_disabled_tools_in_place(
|
|||
spec["tools"] = [ # type: ignore[typeddict-unknown-key]
|
||||
t for t in tools if getattr(t, "name", None) not in disabled_names
|
||||
]
|
||||
interrupt_on = spec.get("interrupt_on") # type: ignore[typeddict-item]
|
||||
if isinstance(interrupt_on, dict):
|
||||
spec["interrupt_on"] = { # type: ignore[typeddict-unknown-key]
|
||||
k: v for k, v in interrupt_on.items() if k not in disabled_names
|
||||
}
|
||||
|
||||
|
||||
def _inject_ask_kb_tool_in_place(spec: SubAgent, ask_kb_tool: BaseTool) -> None:
|
||||
|
|
@ -200,7 +180,7 @@ def build_subagents(
|
|||
dependencies: dict[str, Any],
|
||||
model: BaseChatModel | None = None,
|
||||
middleware_stack: dict[str, Any] | None = None,
|
||||
mcp_tools_by_agent: dict[str, ToolsPermissions] | None = None,
|
||||
mcp_tools_by_agent: dict[str, list[BaseTool]] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
ask_kb_tool: BaseTool | None = None,
|
||||
|
|
@ -216,14 +196,13 @@ def build_subagents(
|
|||
if name in excluded:
|
||||
continue
|
||||
builder = SUBAGENT_BUILDERS_BY_NAME[name]
|
||||
spec = _unwrap_spec(
|
||||
builder(
|
||||
dependencies=dependencies,
|
||||
model=model,
|
||||
middleware_stack=middleware_stack,
|
||||
extra_tools_bucket=mcp.get(name),
|
||||
)
|
||||
result = builder(
|
||||
dependencies=dependencies,
|
||||
model=model,
|
||||
middleware_stack=middleware_stack,
|
||||
mcp_tools=mcp.get(name),
|
||||
)
|
||||
spec = result.spec
|
||||
_filter_disabled_tools_in_place(spec, disabled_names)
|
||||
if ask_kb_tool is not None:
|
||||
_inject_ask_kb_tool_in_place(spec, ask_kb_tool)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue