mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 06:42:39 +02:00
Add MCP tool loading and connector partitioning.
This commit is contained in:
parent
cf3acd87aa
commit
7080b787d1
2 changed files with 176 additions and 0 deletions
|
|
@ -0,0 +1,20 @@
|
|||
"""Load MCP tools, partition by connector agent, apply allow/ask name rules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_with_deepagents.subagents.mcp_tools.permissions import (
|
||||
TOOLS_PERMISSIONS_BY_AGENT,
|
||||
)
|
||||
|
||||
from .index import (
|
||||
fetch_mcp_connector_metadata_maps,
|
||||
load_mcp_tools_by_connector,
|
||||
partition_mcp_tools_by_connector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TOOLS_PERMISSIONS_BY_AGENT",
|
||||
"fetch_mcp_connector_metadata_maps",
|
||||
"load_mcp_tools_by_connector",
|
||||
"partition_mcp_tools_by_connector",
|
||||
]
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
"""Discover MCP tools, bucket by connector agent, apply allow/ask from policy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.multi_agent_with_deepagents.constants import (
|
||||
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
|
||||
)
|
||||
from app.agents.multi_agent_with_deepagents.subagents.mcp_tools.permissions import (
|
||||
TOOLS_PERMISSIONS_BY_AGENT,
|
||||
)
|
||||
from app.agents.multi_agent_with_deepagents.subagents.shared.permissions import (
|
||||
ToolPermissionItem,
|
||||
ToolsPermissions,
|
||||
tool_permission_row,
|
||||
)
|
||||
from app.agents.new_chat.tools.mcp_tool import load_mcp_tools
|
||||
from app.db import SearchSourceConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
## Helper functions for fetching connector metadata maps
|
||||
|
||||
async def fetch_mcp_connector_metadata_maps(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> tuple[dict[int, str], dict[str, str]]:
|
||||
"""Resolve connector id and display name to connector type for MCP tool routing."""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
cast(SearchSourceConnector.config, JSONB).has_key("server_config"),
|
||||
),
|
||||
)
|
||||
id_to_type: dict[int, str] = {}
|
||||
name_to_type: dict[str, str] = {}
|
||||
for connector in result.scalars():
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
id_to_type[connector.id] = ct
|
||||
if connector.name:
|
||||
name_to_type[connector.name] = ct
|
||||
return id_to_type, name_to_type
|
||||
|
||||
|
||||
## Helper functions for partitioning tools by connector agent
|
||||
|
||||
def partition_mcp_tools_by_connector(
|
||||
tools: Sequence[BaseTool],
|
||||
connector_id_to_type: dict[int, str],
|
||||
connector_name_to_type: dict[str, str],
|
||||
) -> dict[str, list[BaseTool]]:
|
||||
"""Assign each MCP tool to one connector-agent bucket from connector metadata."""
|
||||
buckets: dict[str, list[BaseTool]] = defaultdict(list)
|
||||
|
||||
for tool in tools:
|
||||
meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
|
||||
connector_type: str | None = None
|
||||
|
||||
cid = meta.get("mcp_connector_id")
|
||||
if cid is not None:
|
||||
try:
|
||||
cid_int = int(cid)
|
||||
except (TypeError, ValueError):
|
||||
cid_int = None
|
||||
if cid_int is not None:
|
||||
connector_type = connector_id_to_type.get(cid_int)
|
||||
|
||||
if connector_type is None and meta.get("mcp_transport") == "stdio":
|
||||
cname = meta.get("mcp_connector_name")
|
||||
if cname:
|
||||
connector_type = connector_name_to_type.get(str(cname))
|
||||
|
||||
if connector_type is None:
|
||||
logger.debug(
|
||||
"Skipping MCP tool %r — could not resolve connector type from metadata",
|
||||
getattr(tool, "name", None),
|
||||
)
|
||||
continue
|
||||
|
||||
connector_agent = CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS.get(connector_type)
|
||||
if connector_agent is None:
|
||||
logger.warning(
|
||||
"MCP tool %r has unmapped connector type %s — skipped",
|
||||
getattr(tool, "name", None),
|
||||
connector_type,
|
||||
)
|
||||
continue
|
||||
|
||||
buckets[connector_agent].append(tool)
|
||||
|
||||
return dict(buckets)
|
||||
|
||||
## Helper functions for splitting tools by permissions
|
||||
|
||||
def _get_mcp_tool_name(tool: BaseTool) -> str:
|
||||
meta: dict[str, Any] = getattr(tool, "metadata", None) or {}
|
||||
orig = meta.get("mcp_original_tool_name")
|
||||
if isinstance(orig, str) and orig:
|
||||
return orig
|
||||
return getattr(tool, "name", "") or ""
|
||||
|
||||
|
||||
def _split_tools_by_permissions(
|
||||
tools: Sequence[BaseTool],
|
||||
perms: ToolsPermissions,
|
||||
) -> ToolsPermissions:
|
||||
allow_names = frozenset(r["name"] for r in perms["allow"])
|
||||
ask_names = frozenset(r["name"] for r in perms["ask"])
|
||||
allow: list[ToolPermissionItem] = []
|
||||
ask: list[ToolPermissionItem] = []
|
||||
for t in tools:
|
||||
meta: dict[str, Any] = getattr(t, "metadata", None) or {}
|
||||
if meta.get("hitl") is False:
|
||||
allow.append(tool_permission_row(t))
|
||||
continue
|
||||
key = _get_mcp_tool_name(t)
|
||||
if key in allow_names:
|
||||
allow.append(tool_permission_row(t))
|
||||
elif key in ask_names:
|
||||
ask.append(tool_permission_row(t))
|
||||
else:
|
||||
ask.append(tool_permission_row(t))
|
||||
return {"allow": allow, "ask": ask}
|
||||
|
||||
|
||||
## Main function to load MCP tools and split them by permissions for each connector agent
|
||||
|
||||
async def load_mcp_tools_by_connector(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> dict[str, ToolsPermissions]:
|
||||
"""Load MCP tools and split rows using ``TOOLS_PERMISSIONS_BY_AGENT`` name sets."""
|
||||
flat = await load_mcp_tools(session, search_space_id)
|
||||
id_map, name_map = await fetch_mcp_connector_metadata_maps(session, search_space_id)
|
||||
buckets = partition_mcp_tools_by_connector(flat, id_map, name_map)
|
||||
return {
|
||||
agent: _split_tools_by_permissions(
|
||||
tools,
|
||||
TOOLS_PERMISSIONS_BY_AGENT.get(agent, {"allow": [], "ask": []}),
|
||||
)
|
||||
for agent, tools in buckets.items()
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue