dograh/api/tests/test_mcp_custom_tool_manager.py
Paulo Busato Favarato 75839f9de5
feat(mcp): generic MCP tool source with per-node function filtering (#301)
* feat(mcp): generic MCP tool source with per-node function filtering

Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.

- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles

* feat: refactor for code standardisation and documentation

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
2026-05-19 16:10:00 +05:30

181 lines
5.8 KiB
Python

import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from api.enums import ToolCategory
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.tests.support.mcp_mock_server import running_mcp_server
def _mcp_tool():
    """Build a MagicMock standing in for a stored MCP tool record.

    Every call gets a fresh, randomised ``tool_uuid`` so tests never share
    session or registry keys. The display name "Acme MCP" is what the
    namespaced function names asserted in these tests (``mcp__acme_mcp__…``)
    are derived from.
    """
    fake_tool = MagicMock()
    suffix = uuid.uuid4().hex[:8]
    fake_tool.tool_uuid = f"uuid-{suffix}"
    fake_tool.name = "Acme MCP"
    fake_tool.category = ToolCategory.MCP.value
    fake_tool.definition = {
        "type": "mcp",
        "config": {"url": "https://x/mcp"},
    }
    return fake_tool
@pytest.mark.asyncio
async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch):
    """A live MCP session yields namespaced schemas and a working proxy handler.

    Starts the in-process mock MCP server, exposes it through a real
    ``McpToolSession`` and checks that ``CustomToolManager``:

    * returns one schema per server function, namespaced ``mcp__<tool>__<fn>``;
    * registers each handler with ``timeout_secs`` equal to the session's
      call timeout (read timeout 10 s + buffer = 15 s);
    * proxies an invocation to the server and reports the result through
      ``result_callback``.
    """
    async with running_mcp_server() as base_url:
        tool = _mcp_tool()
        session = McpToolSession(
            tool_uuid=tool.tool_uuid,
            tool_name=tool.name,
            url=base_url,
            credential=None,
            tools_filter=[],
            timeout_secs=10,
            sse_read_timeout_secs=10,
        )
        await session.start()
        # Everything after start() lives inside try/finally so the session is
        # closed even if building the engine/manager mocks fails.
        try:
            engine = MagicMock()
            engine._mcp_sessions = {tool.tool_uuid: session}
            registered = {}
            reg_kwargs = {}

            def _reg(name, fn, **kw):
                registered[name] = fn
                reg_kwargs[name] = kw

            engine.llm.register_function = _reg
            mgr = CustomToolManager(engine)
            mgr.get_organization_id = AsyncMock(return_value=42)
            # Imported lazily, as elsewhere in this module.
            from api.db import db_client

            monkeypatch.setattr(
                db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
            )

            schemas = await mgr.get_tool_schemas([tool.tool_uuid])
            names = sorted(s.name for s in schemas)
            assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]

            await mgr.register_handlers([tool.tool_uuid])
            assert "mcp__acme_mcp__echo" in registered
            assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx(
                15.0
            )

            captured = {}

            class P:
                # Minimal stand-in for the params object handed to a handler.
                function_name = "mcp__acme_mcp__echo"
                arguments = {"text": "yo"}

                async def result_callback(self, r, *, properties=None):
                    captured["r"] = r

            await registered["mcp__acme_mcp__echo"](P())
            # Fail with a clear message rather than a KeyError if the handler
            # never reported a result.
            assert "r" in captured, "handler did not call result_callback"
            assert "echo:yo" in str(captured["r"])
        finally:
            await session.close()
@pytest.mark.asyncio
async def test_unavailable_mcp_session_contributes_nothing(monkeypatch):
    """An MCP session that failed to start exposes no schemas and no handlers.

    The session points at port 1 on localhost with a 1 s timeout, so
    ``start()`` is expected to degrade rather than raise. The manager must
    then return no schemas, must not raise while registering handlers, and
    must not register any function.
    """
    tool = _mcp_tool()
    session = McpToolSession(
        tool_uuid=tool.tool_uuid,
        tool_name=tool.name,
        url="http://127.0.0.1:1/mcp",
        credential=None,
        tools_filter=[],
        timeout_secs=1,
        sse_read_timeout_secs=1,
    )
    await session.start()  # degrades
    # Close the session like the other tests do, even though it never connected.
    try:
        engine = MagicMock()
        engine._mcp_sessions = {tool.tool_uuid: session}
        mgr = CustomToolManager(engine)
        mgr.get_organization_id = AsyncMock(return_value=42)
        from api.db import db_client

        monkeypatch.setattr(
            db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
        )

        schemas = await mgr.get_tool_schemas([tool.tool_uuid])
        assert schemas == []

        await mgr.register_handlers([tool.tool_uuid])  # must not raise
        # "Contributes nothing" also means no handler was registered.
        engine.llm.register_function.assert_not_called()
    finally:
        await session.close()
def test_call_timeout_secs_is_read_timeout_plus_buffer():
    """``call_timeout_secs`` equals the SSE read timeout plus a buffer (5 s here).

    The session is only constructed, never started: the value under test is
    checked straight from the constructor arguments.
    """
    read_timeout = 20
    session = McpToolSession(
        tool_uuid="uuid-abc123",
        tool_name="Acme MCP",
        url="https://x/mcp",
        credential=None,
        tools_filter=[],
        timeout_secs=10,
        sse_read_timeout_secs=read_timeout,
    )
    expected = read_timeout + 5.0
    assert session.call_timeout_secs == expected
@pytest.mark.asyncio
async def test_per_node_mcp_filter_intersection(monkeypatch):
    """Per-node ``mcp_tool_filters`` narrow which MCP functions are exposed.

    Three cases are exercised against the same live session:

    * an allow-list of raw (un-namespaced) names exposes only those functions;
    * a filter mapping with no entry for the tool exposes nothing
      (default-none);
    * omitting ``mcp_tool_filters`` (``None``) exposes every function,
      preserving the behaviour from before per-node filtering existed.
    """
    async with running_mcp_server() as base_url:
        tool = _mcp_tool()
        session = McpToolSession(
            tool_uuid=tool.tool_uuid,
            tool_name=tool.name,
            url=base_url,
            credential=None,
            tools_filter=[],
            timeout_secs=10,
            sse_read_timeout_secs=10,
        )
        await session.start()
        # Everything after start() lives inside try/finally so the session is
        # closed even if building the engine/manager mocks fails.
        try:
            engine = MagicMock()
            engine._mcp_sessions = {tool.tool_uuid: session}
            registered = {}

            def _reg(name, fn, **kw):
                registered[name] = fn

            engine.llm.register_function = _reg
            mgr = CustomToolManager(engine)
            mgr.get_organization_id = AsyncMock(return_value=42)
            from api.db import db_client

            monkeypatch.setattr(
                db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
            )

            # Allow only raw "echo" for this node. The namespaced name is fixed
            # because _mcp_tool() always uses the display name "Acme MCP".
            filters = {tool.tool_uuid: ["echo"]}
            schemas = await mgr.get_tool_schemas(
                [tool.tool_uuid], mcp_tool_filters=filters
            )
            assert [s.name for s in schemas] == ["mcp__acme_mcp__echo"]
            await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters)
            assert sorted(registered) == ["mcp__acme_mcp__echo"]

            # No filter entry for this uuid = none (default-none).
            registered.clear()
            result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={})
            assert result == []
            await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={})
            assert registered == {}

            # mcp_tool_filters=None = backward-compatible (all tools).
            registered.clear()
            all_names = ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
            all_schemas = await mgr.get_tool_schemas([tool.tool_uuid])
            assert sorted(s.name for s in all_schemas) == all_names
            await mgr.register_handlers([tool.tool_uuid])
            assert sorted(registered) == all_names
        finally:
            await session.close()