feat(mcp): generic MCP tool source with per-node function filtering (#301)

* feat(mcp): generic MCP tool source with per-node function filtering

Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.

- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles

* feat: refactor for code standardisation and documentation

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
Paulo Busato Favarato 2026-05-19 07:40:00 -03:00 committed by GitHub
parent 0097974444
commit 75839f9de5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 3028 additions and 137 deletions

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, ValidationError, model_validator
@ -69,6 +69,7 @@ class _ExtractionNodeDataMixin(BaseModel):
class _ToolDocumentRefsMixin(BaseModel):
tool_uuids: Optional[List[str]] = None
document_uuids: Optional[List[str]] = None
mcp_tool_filters: Optional[Dict[str, List[str]]] = None
class StartCallNodeData(

View file

@ -0,0 +1,254 @@
"""Single unit that knows the MCP protocol + credentials.
Wraps the vendored Pipecat ``MCPClient`` for connection/session, builds
streamable-HTTP params from a Dograh credential, exposes namespaced
``FunctionSchema``s, and proxies tool calls. Connection failures degrade
(``available = False``) instead of raising the call must survive a
dead MCP server.
"""
from __future__ import annotations
import asyncio
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from loguru import logger
from mcp.client.session_group import StreamableHttpParameters
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.services.mcp_service import MCPClient
from api.services.workflow.tools.mcp_tool import namespace_function_name
from api.utils.credential_auth import build_auth_header
if TYPE_CHECKING:
from api.db.models import ExternalCredentialModel
def build_streamable_http_params(
*,
url: str,
credential: Optional["ExternalCredentialModel"],
timeout_secs: int,
sse_read_timeout_secs: int,
) -> StreamableHttpParameters:
"""Build Pipecat/MCP streamable-HTTP params, injecting the auth header
from an ExternalCredentialModel (reuses the http_api credential path)."""
headers: Optional[Dict[str, str]] = None
if credential is not None:
auth = build_auth_header(credential)
headers = auth or None
return StreamableHttpParameters(
url=url,
headers=headers,
timeout=timedelta(seconds=timeout_secs),
sse_read_timeout=timedelta(seconds=sse_read_timeout_secs),
)
class McpToolSession:
"""One live MCP server connection for the duration of a call."""
def __init__(
self,
*,
tool_uuid: str,
tool_name: str,
url: str,
credential: Optional["ExternalCredentialModel"],
tools_filter: List[str],
timeout_secs: int,
sse_read_timeout_secs: int,
) -> None:
self._tool_uuid = tool_uuid
self._tool_name = tool_name
self._url = url
self._credential = credential
# An empty list is intentionally treated as "no filter (expose all
# tools)" — Pipecat's MCPClient applies a filter only when this is a
# non-empty list, so [] and None are equivalent ("all tools").
self._tools_filter = tools_filter or None
self._timeout_secs = timeout_secs
self._sse_read_timeout_secs = sse_read_timeout_secs
self._client: Optional[MCPClient] = None
self._session: Any = None # mcp.ClientSession (read once after start)
self._schemas: List[FunctionSchema] = []
# namespaced LLM name -> original MCP tool name
self._name_map: Dict[str, str] = {}
self.available: bool = False
async def start(self) -> None:
"""Connect, initialize, and cache the tool list. Never raises —
on any failure the session is marked unavailable."""
try:
params = build_streamable_http_params(
url=self._url,
credential=self._credential,
timeout_secs=self._timeout_secs,
sse_read_timeout_secs=self._sse_read_timeout_secs,
)
self._client = MCPClient(params, tools_filter=self._tools_filter)
await self._client.start()
# Single, isolated touch of Pipecat internals (vendored submodule).
self._session = self._client._active_session
tools_schema = await self._client.get_tools_schema()
fallback = self._tool_uuid[:8] if self._tool_uuid else "server"
for fs in tools_schema.standard_tools:
ns_name = namespace_function_name(
self._tool_name, fs.name, fallback=fallback
)
self._name_map[ns_name] = fs.name
self._schemas.append(
FunctionSchema(
name=ns_name,
description=fs.description,
properties=fs.properties,
required=fs.required,
)
)
self.available = True
logger.info(
f"MCP session ready for tool '{self._tool_name}' "
f"({self._tool_uuid}): {sorted(self._name_map)}"
)
except (KeyboardInterrupt, SystemExit):
raise
except asyncio.CancelledError as e:
# Empirically, a dead/unreachable MCP server does NOT surface as a
# plain Exception here. The real failure is httpx.ConnectError, but
# anyio's streamablehttp_client task group, while tearing down that
# ConnectError, re-surfaces it to our frame as an *internal*
# cancel-scope CancelledError carrying the signature message
# "Cancelled via cancel scope <id>". A genuine *external*
# cancellation (call teardown / shutdown) is a bare CancelledError
# (empty args) or one with an application-chosen message. Type, MRO,
# context chain, and asyncio task.cancelling() are all identical
# between the two, so the anyio scope-signature message is the only
# reliable discriminator. Re-raise genuine external cancellation to
# preserve structured concurrency; degrade only on the anyio
# connect-teardown artifact.
msg = "" if not e.args else str(e.args[0] or "")
if not msg.startswith("Cancelled via cancel scope"):
raise
await self._degrade(e)
except Exception as e: # noqa: BLE001 — see _degrade docstring
# Defensive: if a future Pipecat/httpx version surfaces the connect
# failure directly (e.g. httpx.ConnectError) instead of via the
# anyio cancel-scope artifact above, still degrade gracefully.
await self._degrade(e)
async def _degrade(self, e: BaseException) -> None:
"""Mark this session unavailable and tear down any dangling client so
start() leaves self._client either fully usable or None. The contract
requires graceful degradation on any *connect* failure (never raising
for a dead MCP server) while genuine external cancellation /
KeyboardInterrupt / SystemExit are re-raised by the caller."""
self.available = False
self._schemas = []
self._name_map = {}
# Self-contained cleanup: _client.start() may have succeeded before a
# later step (e.g. get_tools_schema()) failed, leaving an open client.
if self._client is not None:
try:
await self._client.close()
except Exception:
pass
finally:
self._client = None
self._session = None
logger.warning(
f"MCP session unavailable for tool '{self._tool_name}' "
f"({self._tool_uuid}) at {self._url}: {e!r}. "
f"Call proceeds without these tools."
)
@property
def call_timeout_secs(self) -> float:
"""Pipecat function-call timeout for this server's tools. Slightly
longer than the transport read timeout so a slow MCP call surfaces
as a structured tool error (handled in the handler) rather than a
hard pipeline timeout."""
return float(self._sse_read_timeout_secs) + 5.0
def function_schemas(
self, allowed_raw_names: Optional[Set[str]] = None
) -> List[FunctionSchema]:
"""Return cached FunctionSchemas, optionally filtered by raw MCP tool name.
``allowed_raw_names=None`` returns all schemas. An empty set returns none.
Raw names are the pre-namespace MCP tool names (e.g. ``echo``, not
``mcp__slug__echo``).
"""
if allowed_raw_names is None:
return list(self._schemas)
return [
s for s in self._schemas if self._name_map.get(s.name) in allowed_raw_names
]
def discovered_tools(self) -> List[Dict[str, str]]:
"""Raw MCP tool catalog for UI/cache: ``[{name, description}]``
using the *raw* server names (not the namespaced LLM names).
Empty if the session is unavailable."""
out: List[Dict[str, str]] = []
for s in self._schemas:
raw = self._name_map.get(s.name)
if raw is None:
continue
out.append({"name": raw, "description": s.description or ""})
return out
async def call(self, namespaced_name: str, arguments: Dict[str, Any]) -> str:
"""Invoke an MCP tool by its namespaced LLM name. Returns a string
(flattened text content). Raises if the session is unavailable so
the caller can map it to a structured error for the LLM."""
if not self.available or self._session is None:
raise RuntimeError(f"MCP session unavailable for {namespaced_name}")
original = self._name_map.get(namespaced_name)
if original is None:
raise RuntimeError(f"Unknown MCP function {namespaced_name}")
result = await self._session.call_tool(original, arguments=arguments)
text = ""
for content in getattr(result, "content", []) or []:
if getattr(content, "text", None):
text += content.text
return text or "Sorry, the MCP tool returned no content."
async def close(self) -> None:
if self._client is not None:
try:
await self._client.close()
except Exception as e:
logger.warning(f"Error closing MCP session {self._tool_uuid}: {e}")
finally:
self._client = None
self._session = None
async def discover_mcp_tools(
*,
url: str,
credential: Optional["ExternalCredentialModel"],
timeout_secs: int,
sse_read_timeout_secs: int,
) -> List[Dict[str, str]]:
"""Open an ephemeral MCP session, list its tools, close it. Returns
``[{name, description}]`` (raw names). Never raises on any connect
failure returns ``[]``."""
session = McpToolSession(
tool_uuid="discover",
tool_name="discover",
url=url,
credential=credential,
tools_filter=[],
timeout_secs=timeout_secs,
sse_read_timeout_secs=sse_read_timeout_secs,
)
await session.start()
try:
if not session.available:
return []
return session.discovered_tools()
finally:
await session.close()

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import (
@ -16,6 +16,7 @@ from pipecat.services.settings import LLMSettings
from pipecat.utils.enums import EndTaskReason
from api.db import db_client
from api.enums import ToolCategory
from api.services.pipecat.audio_playback import play_audio
from api.services.workflow.disposition_mapper import apply_disposition_mapping
from api.services.workflow.workflow_graph import Node, WorkflowGraph
@ -34,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
compose_functions_for_node,
compose_system_prompt_for_node,
@ -116,6 +118,9 @@ class PipecatEngine:
# Cached organization ID (resolved lazily from workflow run)
self._organization_id: Optional[int] = None
# Open MCP tool sessions for this call, keyed by tool_uuid
self._mcp_sessions: Dict[str, McpToolSession] = {}
# Embeddings configuration (passed from run_pipeline.py)
self._embeddings_api_key: Optional[str] = embeddings_api_key
self._embeddings_model: Optional[str] = embeddings_model
@ -178,6 +183,9 @@ class PipecatEngine:
# Helper that encapsulates custom tool management
self._custom_tool_manager = CustomToolManager(self)
# Open persistent MCP server sessions for this call (degrades on failure)
await self._open_mcp_sessions()
# Helper that encapsulates context summarization
if self._context_compaction_enabled:
self._context_summarization_manager = ContextSummarizationManager(self)
@ -503,7 +511,10 @@ class PipecatEngine:
# Register custom tool handlers for this node
if node.tool_uuids and self._custom_tool_manager:
await self._custom_tool_manager.register_handlers(node.tool_uuids)
await self._custom_tool_manager.register_handlers(
node.tool_uuids,
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
)
# Register knowledge base retrieval handler if node has documents
if node.document_uuids:
@ -814,6 +825,79 @@ class PipecatEngine:
"""Get the gathered context including extracted variables."""
return self._gathered_context.copy()
async def _open_mcp_sessions(self) -> None:
"""Connect every MCP-category tool referenced by any workflow node.
Failures degrade (session marked unavailable); never raises."""
from api.services.workflow.tools.mcp_tool import (
McpDefinitionError,
validate_mcp_definition,
)
try:
tool_uuids: set[str] = set()
for node in self.workflow.nodes.values():
for tu in getattr(node, "tool_uuids", None) or []:
tool_uuids.add(tu)
if not tool_uuids:
return
organization_id = await self._get_organization_id()
if not organization_id:
logger.warning("Cannot open MCP sessions: organization_id missing")
return
tools = await db_client.get_tools_by_uuids(
list(tool_uuids), organization_id
)
for tool in tools:
if tool.category != ToolCategory.MCP.value:
continue
try:
cfg = validate_mcp_definition(tool.definition)
except McpDefinitionError as e:
logger.warning(
f"Skipping MCP tool '{tool.name}' ({tool.tool_uuid}): "
f"invalid definition: {e}"
)
continue
credential = None
if cfg["credential_uuid"]:
try:
credential = await db_client.get_credential_by_uuid(
cfg["credential_uuid"], organization_id
)
except Exception as e:
logger.warning(
f"MCP tool '{tool.name}': credential fetch failed: {e}"
)
continue
session = McpToolSession(
tool_uuid=tool.tool_uuid,
tool_name=tool.name,
url=cfg["url"],
credential=credential,
tools_filter=cfg["tools_filter"],
timeout_secs=cfg["timeout_secs"],
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
)
await session.start()
self._mcp_sessions[tool.tool_uuid] = session
except Exception as e:
logger.warning(
f"Failed to open MCP sessions; call proceeds without MCP tools: {e}",
exc_info=True,
)
async def _close_mcp_sessions(self) -> None:
for tool_uuid, session in list(self._mcp_sessions.items()):
try:
await session.close()
except Exception as e:
logger.warning(f"Error closing MCP session {tool_uuid}: {e}")
self._mcp_sessions = {}
async def cleanup(self):
"""Clean up engine resources on disconnect."""
# Cancel any pending timeout tasks
@ -823,6 +907,12 @@ class PipecatEngine:
):
self._user_response_timeout_task.cancel()
# Cancel any in-flight background summarization
if self._context_summarization_manager:
await self._context_summarization_manager.cleanup()
# Cancel any in-flight background summarization.
# MCP sessions are closed in a finally block so they are guaranteed to
# run even if the summarization cleanup raises an exception.
try:
if self._context_summarization_manager:
await self._context_summarization_manager.cleanup()
finally:
# Close any open MCP tool sessions
await self._close_mcp_sessions()

View file

@ -117,7 +117,8 @@ async def compose_functions_for_node(
# Custom tools
if node.tool_uuids and custom_tool_manager:
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
node.tool_uuids
node.tool_uuids,
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
)
functions.extend(custom_tool_schemas)

View file

@ -34,6 +34,7 @@ from api.services.workflow.tools.custom_tool import (
)
if TYPE_CHECKING:
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine import PipecatEngine
@ -121,11 +122,18 @@ class CustomToolManager:
"""Get the organization ID from the engine (shared cache)."""
return await self._engine._get_organization_id()
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
async def get_tool_schemas(
self,
tool_uuids: list[str],
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
) -> list[FunctionSchema]:
"""Fetch custom tools and convert them to function schemas.
Args:
tool_uuids: List of tool UUIDs to fetch
mcp_tool_filters: Optional per-node filter mapping tool_uuid list of
raw MCP tool names to expose. None (default) exposes all tools.
Empty dict or entry with [] suppresses all tools for that uuid.
Returns:
List of FunctionSchema objects for LLM
@ -154,6 +162,22 @@ class CustomToolManager:
)
continue
if tool.category == ToolCategory.MCP.value:
session = self._engine._mcp_sessions.get(tool.tool_uuid)
if session is None or not session.available:
logger.warning(
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
f"unavailable; skipping"
)
continue
allowed = (
None
if mcp_tool_filters is None
else set(mcp_tool_filters.get(tool.tool_uuid, []))
)
schemas.extend(session.function_schemas(allowed))
continue
raw_schema = tool_to_function_schema(tool)
function_name = raw_schema["function"]["name"]
@ -178,11 +202,18 @@ class CustomToolManager:
logger.error(f"Failed to fetch custom tools: {e}")
return []
async def register_handlers(self, tool_uuids: list[str]) -> None:
async def register_handlers(
self,
tool_uuids: list[str],
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
) -> None:
"""Register custom tool execution handlers with the LLM.
Args:
tool_uuids: List of tool UUIDs to register handlers for
mcp_tool_filters: Optional per-node filter mapping tool_uuid list of
raw MCP tool names to expose. None (default) exposes all tools.
Empty dict or entry with [] suppresses all tools for that uuid.
"""
organization_id = await self.get_organization_id()
if not organization_id:
@ -203,6 +234,32 @@ class CustomToolManager:
)
continue
if tool.category == ToolCategory.MCP.value:
session = self._engine._mcp_sessions.get(tool.tool_uuid)
if session is None or not session.available:
logger.warning(
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
f"unavailable; skipping handler registration"
)
continue
allowed = (
None
if mcp_tool_filters is None
else set(mcp_tool_filters.get(tool.tool_uuid, []))
)
mcp_schemas = session.function_schemas(allowed)
for fs in mcp_schemas:
self._engine.llm.register_function(
fs.name,
self._create_mcp_handler(session, fs.name),
timeout_secs=session.call_timeout_secs,
)
logger.debug(
f"Registered {len(mcp_schemas)} MCP "
f"handlers for tool '{tool.name}' ({tool.tool_uuid})"
)
continue
schema = tool_to_function_schema(tool)
function_name = schema["function"]["name"]
@ -335,6 +392,29 @@ class CustomToolManager:
return http_tool_handler
def _create_mcp_handler(self, session: "McpToolSession", function_name: str):
"""Create a handler that proxies an LLM function call to a live MCP
session. Errors are returned to the LLM as structured text so the
agent can recover verbally; the call is never crashed."""
async def mcp_tool_handler(
function_call_params: FunctionCallParams,
) -> None:
logger.info(f"MCP Tool EXECUTED: {function_name}")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
result = await session.call(
function_name, function_call_params.arguments or {}
)
await function_call_params.result_callback(result)
except Exception as e:
logger.error(f"MCP tool '{function_name}' failed: {e}")
await function_call_params.result_callback(
{"status": "error", "error": str(e)}
)
return mcp_tool_handler
def _create_end_call_handler(self, tool: Any, function_name: str):
"""Create a handler function for an end call tool.

View file

@ -0,0 +1,116 @@
"""Pure helpers for MCP-category tools: definition validation and
LLM-function-name namespacing. No I/O, no MCP protocol here."""
from __future__ import annotations
import re
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, Field, ValidationError, field_validator
DEFAULT_TIMEOUT_SECS = 30
DEFAULT_SSE_READ_TIMEOUT_SECS = 300
class McpDefinitionError(ValueError):
"""Raised when an MCP tool definition is structurally invalid."""
class McpToolConfig(BaseModel):
"""Configuration for an MCP tool definition."""
transport: Literal["streamable_http"] = Field(
default="streamable_http", description="MCP transport protocol"
)
url: str = Field(description="MCP server URL (must be http:// or https://)")
credential_uuid: Optional[str] = Field(
default=None, description="Reference to ExternalCredentialModel for auth"
)
tools_filter: list[str] = Field(
default_factory=list,
description="Allowlist of MCP tool names to expose (empty = all tools)",
)
timeout_secs: int = Field(
default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds"
)
sse_read_timeout_secs: int = Field(
default=DEFAULT_SSE_READ_TIMEOUT_SECS,
description="SSE read timeout in seconds",
)
discovered_tools: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"Server-managed cache of the MCP server's tool catalog "
"[{name, description}]. Populated best-effort by the backend."
),
)
@field_validator("url")
@classmethod
def validate_url(cls, v: str) -> str:
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
raise ValueError("config.url must be an http(s) URL")
return v
@field_validator("tools_filter")
@classmethod
def validate_tools_filter(cls, v: list[str]) -> list[str]:
if not all(isinstance(tool_name, str) for tool_name in v):
raise ValueError("config.tools_filter must be a list of strings")
return v
class McpToolDefinition(BaseModel):
"""Persisted MCP tool definition."""
schema_version: int = Field(default=1, description="Schema version")
type: Literal["mcp"] = Field(description="Tool type")
config: McpToolConfig = Field(description="MCP server configuration")
def _format_validation_error(error: ValidationError) -> str:
parts: list[str] = []
for item in error.errors():
location = ".".join(str(part) for part in item["loc"])
parts.append(f"{location}: {item['msg']}")
return "; ".join(parts)
def validate_mcp_definition(definition: Dict[str, Any]) -> Dict[str, Any]:
"""Validate a ``type: "mcp"`` ToolModel definition and return a
normalized config dict with defaults applied.
Raises:
McpDefinitionError: if the definition is missing required fields
or uses an unsupported transport.
"""
if not isinstance(definition, dict) or definition.get("type") != "mcp":
raise McpDefinitionError("definition.type must be 'mcp'")
config = definition.get("config")
if not isinstance(config, dict):
raise McpDefinitionError("definition.config is required and must be an object")
try:
parsed = McpToolDefinition.model_validate(definition)
except ValidationError as e:
raise McpDefinitionError(_format_validation_error(e)) from e
return parsed.config.model_dump(exclude={"discovered_tools"})
def _slugify(value: str) -> str:
slug = re.sub(r"[^a-z0-9]+", "_", value.strip().lower()).strip("_")
return slug
def namespace_function_name(
tool_name: str, mcp_tool_name: str, *, fallback: str = "server"
) -> str:
"""Build a collision-safe LLM function name: ``mcp__<slug>__<tool>``.
``slug`` is derived from the Dograh ToolModel name; if it slugifies to
empty, ``fallback`` (e.g. first 8 chars of tool_uuid) is used instead.
"""
slug = _slugify(tool_name) or _slugify(fallback) or "server"
return f"mcp__{slug}__{mcp_tool_name}"

View file

@ -89,6 +89,7 @@ class Node:
self.delayed_start_duration = getattr(data, "delayed_start_duration", None)
self.tool_uuids = getattr(data, "tool_uuids", None)
self.document_uuids = getattr(data, "document_uuids", None)
self.mcp_tool_filters = getattr(data, "mcp_tool_filters", None)
self.pre_call_fetch_enabled = getattr(data, "pre_call_fetch_enabled", False)
self.pre_call_fetch_url = getattr(data, "pre_call_fetch_url", None)
self.pre_call_fetch_credential_uuid = getattr(