feat(mcp): generic MCP tool source with per-node function filtering (#301)

* feat(mcp): generic MCP tool source with per-node function filtering

Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.

- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles

* feat: refactor for code standardisation and documentation

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
Paulo Busato Favarato 2026-05-19 07:40:00 -03:00 committed by GitHub
parent 0097974444
commit 75839f9de5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 3028 additions and 137 deletions

View file

View file

@ -0,0 +1,103 @@
"""A real FastMCP server exposing 2 tools over streamable-HTTP, run in a
background uvicorn thread on an ephemeral port. Used to exercise the real
MCP protocol path in tests.
"""
from __future__ import annotations
import asyncio
import contextlib
import socket
import threading
from typing import AsyncIterator
import httpx
import uvicorn
from fastmcp import FastMCP
from starlette.responses import JSONResponse
def _build_app(required_headers: dict[str, str] | None = None):
mcp = FastMCP("mock-mcp")
@mcp.tool()
def echo(text: str) -> str:
"""Echo the provided text back."""
return f"echo:{text}"
@mcp.tool()
def add(a: int, b: int) -> int:
"""Add two integers."""
return a + b
# FastMCP 3.x: ASGI app for streamable-HTTP transport at "/mcp".
app = mcp.http_app()
if not required_headers:
return app
normalized = {k.lower(): v for k, v in required_headers.items()}
async def guarded_app(scope, receive, send):
if scope["type"] == "http":
headers = {
key.decode("latin-1").lower(): value.decode("latin-1")
for key, value in scope.get("headers", [])
}
for header_name, expected_value in normalized.items():
if headers.get(header_name) != expected_value:
response = JSONResponse(
{"detail": f"Missing or invalid header: {header_name}"},
status_code=401,
)
await response(scope, receive, send)
return
await app(scope, receive, send)
return guarded_app
def _free_port() -> int:
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
@contextlib.asynccontextmanager
async def running_mcp_server(
*, required_headers: dict[str, str] | None = None
) -> AsyncIterator[str]:
"""Yield the base streamable-HTTP URL of a live mock MCP server."""
port = _free_port()
config = uvicorn.Config(
_build_app(required_headers), host="127.0.0.1", port=port, log_level="warning"
)
server = uvicorn.Server(config)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
base_url = f"http://127.0.0.1:{port}/mcp"
server_ready = False
for _ in range(50):
try:
async with httpx.AsyncClient() as client:
await client.get(base_url, timeout=0.5)
server_ready = True
break
except Exception:
await asyncio.sleep(0.1)
if not server_ready:
server.should_exit = True
thread.join(timeout=5)
raise RuntimeError(f"Mock MCP server at {base_url} failed to start within 5s")
try:
yield base_url
finally:
server.should_exit = True
thread.join(timeout=5)
if thread.is_alive():
import warnings
warnings.warn(
"Mock MCP server thread did not terminate within 5s",
ResourceWarning,
)

View file

@ -0,0 +1,63 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from api.mcp_server.auth import authenticate_mcp_request
@pytest.mark.asyncio
async def test_authenticate_mcp_request_accepts_bearer_authorization():
user = MagicMock()
user.id = 1
user.selected_organization_id = 90
with (
patch(
"api.mcp_server.auth.get_http_headers",
return_value={"authorization": "Bearer secret-api-key"},
) as get_headers,
patch(
"api.mcp_server.auth._handle_api_key_auth",
AsyncMock(return_value=user),
) as handle_auth,
):
authed = await authenticate_mcp_request()
assert authed is user
get_headers.assert_called_once_with(include={"authorization"})
handle_auth.assert_awaited_once_with("secret-api-key")
@pytest.mark.asyncio
async def test_authenticate_mcp_request_accepts_x_api_key():
user = MagicMock()
user.id = 2
user.selected_organization_id = 91
with (
patch(
"api.mcp_server.auth.get_http_headers",
return_value={"x-api-key": "secret-api-key"},
) as get_headers,
patch(
"api.mcp_server.auth._handle_api_key_auth",
AsyncMock(return_value=user),
) as handle_auth,
):
authed = await authenticate_mcp_request()
assert authed is user
get_headers.assert_called_once_with(include={"authorization"})
handle_auth.assert_awaited_once_with("secret-api-key")
@pytest.mark.asyncio
async def test_authenticate_mcp_request_rejects_missing_api_key():
with patch("api.mcp_server.auth.get_http_headers", return_value={}) as get_headers:
with pytest.raises(HTTPException) as exc_info:
await authenticate_mcp_request()
assert exc_info.value.status_code == 401
assert "Missing API key" in str(exc_info.value.detail)
get_headers.assert_called_once_with(include={"authorization"})

View file

@ -0,0 +1,181 @@
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from api.enums import ToolCategory
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.tests.support.mcp_mock_server import running_mcp_server
def _mcp_tool():
t = MagicMock()
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
t.name = "Acme MCP"
t.category = ToolCategory.MCP.value
t.definition = {"type": "mcp", "config": {"url": "https://x/mcp"}}
return t
@pytest.mark.asyncio
async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch):
async with running_mcp_server() as base_url:
tool = _mcp_tool()
session = McpToolSession(
tool_uuid=tool.tool_uuid,
tool_name=tool.name,
url=base_url,
credential=None,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=10,
)
await session.start()
engine = MagicMock()
engine._mcp_sessions = {tool.tool_uuid: session}
registered = {}
reg_kwargs = {}
def _reg(name, fn, **kw):
registered[name] = fn
reg_kwargs[name] = kw
engine.llm.register_function = _reg
mgr = CustomToolManager(engine)
mgr.get_organization_id = AsyncMock(return_value=42)
from api.db import db_client
monkeypatch.setattr(
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
)
try:
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
names = sorted(s.name for s in schemas)
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
await mgr.register_handlers([tool.tool_uuid])
assert "mcp__acme_mcp__echo" in registered
assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx(
15.0
)
captured = {}
class P:
function_name = "mcp__acme_mcp__echo"
arguments = {"text": "yo"}
async def result_callback(self, r, *, properties=None):
captured["r"] = r
await registered["mcp__acme_mcp__echo"](P())
assert "echo:yo" in str(captured["r"])
finally:
await session.close()
@pytest.mark.asyncio
async def test_unavailable_mcp_session_contributes_nothing(monkeypatch):
tool = _mcp_tool()
session = McpToolSession(
tool_uuid=tool.tool_uuid,
tool_name=tool.name,
url="http://127.0.0.1:1/mcp",
credential=None,
tools_filter=[],
timeout_secs=1,
sse_read_timeout_secs=1,
)
await session.start() # degrades
engine = MagicMock()
engine._mcp_sessions = {tool.tool_uuid: session}
mgr = CustomToolManager(engine)
mgr.get_organization_id = AsyncMock(return_value=42)
from api.db import db_client
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
assert schemas == []
await mgr.register_handlers([tool.tool_uuid]) # must not raise
def test_call_timeout_secs_is_read_timeout_plus_buffer():
session = McpToolSession(
tool_uuid="uuid-abc123",
tool_name="Acme MCP",
url="https://x/mcp",
credential=None,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=20,
)
assert session.call_timeout_secs == 25.0
@pytest.mark.asyncio
async def test_per_node_mcp_filter_intersection(monkeypatch):
async with running_mcp_server() as base_url:
tool = _mcp_tool()
session = McpToolSession(
tool_uuid=tool.tool_uuid,
tool_name=tool.name,
url=base_url,
credential=None,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=10,
)
await session.start()
engine = MagicMock()
engine._mcp_sessions = {tool.tool_uuid: session}
registered = {}
engine.llm.register_function = lambda name, fn, **kw: registered.__setitem__(
name, fn
)
mgr = CustomToolManager(engine)
mgr.get_organization_id = AsyncMock(return_value=42)
from api.db import db_client
monkeypatch.setattr(
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
)
try:
# Allow only raw "echo" for this node
filters = {tool.tool_uuid: ["echo"]}
schemas = await mgr.get_tool_schemas(
[tool.tool_uuid], mcp_tool_filters=filters
)
# Check only "echo" schema returned (namespaced name depends on tool.name)
assert len(schemas) == 1
assert all("echo" in s.name for s in schemas)
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters)
assert len(registered) == 1
assert all("echo" in k for k in registered)
# No filter entry for this uuid = none (default-none)
registered.clear()
result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={})
assert result == []
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={})
assert registered == {}
# mcp_tool_filters=None = backward-compatible (all tools)
registered.clear()
all_schemas = await mgr.get_tool_schemas([tool.tool_uuid])
assert len(all_schemas) == 2 # both echo and add
await mgr.register_handlers([tool.tool_uuid])
assert len(registered) == 2 # both handlers registered
finally:
await session.close()

View file

@ -0,0 +1,107 @@
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from api.enums import ToolCategory
from api.services.workflow.pipecat_engine import PipecatEngine
from api.tests.support.mcp_mock_server import running_mcp_server
def _mcp_tool(url: str):
t = MagicMock()
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
t.name = "Acme MCP"
t.category = ToolCategory.MCP.value
t.definition = {
"schema_version": 1,
"type": "mcp",
"config": {"transport": "streamable_http", "url": url},
}
return t
@pytest.mark.asyncio
async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
async with running_mcp_server() as base_url:
tool = _mcp_tool(base_url)
engine = PipecatEngine.__new__(PipecatEngine)
node = MagicMock()
node.tool_uuids = [tool.tool_uuid]
workflow = MagicMock()
workflow.nodes = {"n1": node}
engine.workflow = workflow
engine._mcp_sessions = {}
from api.db import db_client
monkeypatch.setattr(
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
)
monkeypatch.setattr(
db_client, "get_credential_by_uuid", AsyncMock(return_value=None)
)
engine._get_organization_id = AsyncMock(return_value=42)
await engine._open_mcp_sessions()
try:
assert tool.tool_uuid in engine._mcp_sessions
sess = engine._mcp_sessions[tool.tool_uuid]
assert sess.available is True
assert len(sess.function_schemas()) == 2
finally:
await engine._close_mcp_sessions()
assert engine._mcp_sessions == {}
@pytest.mark.asyncio
async def test_open_mcp_sessions_swallows_db_error(monkeypatch):
engine = PipecatEngine.__new__(PipecatEngine)
node = MagicMock()
node.tool_uuids = ["uuid-deadbeef"]
workflow = MagicMock()
workflow.nodes = {"n1": node}
engine.workflow = workflow
engine._mcp_sessions = {}
from api.db import db_client
monkeypatch.setattr(
db_client,
"get_tools_by_uuids",
AsyncMock(side_effect=RuntimeError("db down")),
)
engine._get_organization_id = AsyncMock(return_value=42)
# Must NOT raise
await engine._open_mcp_sessions()
assert engine._mcp_sessions == {}
@pytest.mark.asyncio
async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch):
tool = _mcp_tool("http://127.0.0.1:1/mcp")
tool.definition["config"]["credential_uuid"] = "cred-1234"
engine = PipecatEngine.__new__(PipecatEngine)
node = MagicMock()
node.tool_uuids = [tool.tool_uuid]
workflow = MagicMock()
workflow.nodes = {"n1": node}
engine.workflow = workflow
engine._mcp_sessions = {}
from api.db import db_client
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
monkeypatch.setattr(
db_client,
"get_credential_by_uuid",
AsyncMock(side_effect=RuntimeError("cred store down")),
)
engine._get_organization_id = AsyncMock(return_value=42)
# Must NOT raise, and must skip the tool (no futile unauthenticated start)
await engine._open_mcp_sessions()
assert engine._mcp_sessions == {}

View file

@ -0,0 +1,112 @@
import importlib
import pytest
from api.enums import ToolCategory
from api.routes.tool import McpToolConfig as RouteMcpToolConfig
from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition
from api.services.workflow.tools.mcp_tool import (
McpDefinitionError,
McpToolConfig,
McpToolDefinition,
namespace_function_name,
validate_mcp_definition,
)
def test_mcp_category_exists():
assert ToolCategory.MCP.value == "mcp"
assert ToolCategory("mcp") is ToolCategory.MCP
def test_mcp_migration_present_and_chained(monkeypatch):
mod = importlib.import_module(
"api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory"
)
assert mod.revision == "0a1b2c3d4e5f"
assert mod.down_revision == "4c1f1e3e8ef2"
calls = []
def fake_sync_enum_values(**kwargs):
calls.append(kwargs)
monkeypatch.setattr(mod.op, "sync_enum_values", fake_sync_enum_values)
mod.upgrade()
mod.downgrade()
assert len(calls) == 2
assert calls[0]["enum_name"] == "tool_category"
assert "mcp" in calls[0]["new_values"]
assert "mcp" not in calls[1]["new_values"]
def test_route_reuses_shared_mcp_models():
assert RouteMcpToolConfig is McpToolConfig
assert RouteMcpToolDefinition is McpToolDefinition
def test_validate_mcp_definition_ok():
cfg = validate_mcp_definition(
{
"schema_version": 1,
"type": "mcp",
"config": {
"transport": "streamable_http",
"url": "https://acme.example.com/mcp",
"credential_uuid": "cred-123",
"tools_filter": ["lookup_patient"],
"timeout_secs": 30,
"sse_read_timeout_secs": 300,
},
}
)
assert cfg["url"] == "https://acme.example.com/mcp"
assert cfg["transport"] == "streamable_http"
assert cfg["tools_filter"] == ["lookup_patient"]
assert cfg["timeout_secs"] == 30
assert cfg["sse_read_timeout_secs"] == 300
assert cfg["credential_uuid"] == "cred-123"
def test_validate_mcp_definition_defaults():
cfg = validate_mcp_definition({"type": "mcp", "config": {"url": "https://x/mcp"}})
assert cfg["transport"] == "streamable_http"
assert cfg["tools_filter"] == []
assert cfg["timeout_secs"] == 30
assert cfg["sse_read_timeout_secs"] == 300
assert cfg["credential_uuid"] is None
@pytest.mark.parametrize(
"definition",
[
{"type": "mcp", "config": {}},
{"type": "mcp", "config": {"url": ""}},
{"type": "mcp", "config": {"url": "ftp://x"}},
{"type": "mcp"},
{"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}},
],
)
def test_validate_mcp_definition_rejects(definition):
with pytest.raises(McpDefinitionError):
validate_mcp_definition(definition)
def test_validate_mcp_definition_zero_timeout_preserved():
cfg = validate_mcp_definition(
{"type": "mcp", "config": {"url": "https://x/mcp", "timeout_secs": 0}}
)
assert cfg["timeout_secs"] == 0
def test_namespace_function_name():
assert (
namespace_function_name("Acme MCP", "lookup_patient")
== "mcp__acme_mcp__lookup_patient"
)
assert (
namespace_function_name("", "ping", fallback="abcd1234")
== "mcp__abcd1234__ping"
)

View file

@ -0,0 +1,437 @@
"""Route-level tests for the MCP tool definition schema.
These tests exercise the Pydantic request models (CreateToolRequest /
UpdateToolRequest) to catch schema gaps at the route/request-model layer
the layer where the pre-fix defect lived (HTTP 422 on every MCP tool
creation attempt).
Test coverage:
- CreateToolRequest validates a valid MCP definition (was 422 before Part A).
- UpdateToolRequest validates a valid MCP definition.
- Invalid MCP bodies are rejected (ftp:// url, missing url).
- Round-trip: validated definition dict passes through validate_mcp_definition
unchanged, proving the request schema and call-time validator agree.
- Full HTTP round-trip via the ASGI test client (POST /api/v1/tools/).
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
from api.services.workflow.tools.mcp_tool import (
validate_mcp_definition,
)
# ── Canonical valid MCP request body ─────────────────────────────────────────
VALID_MCP_DEFINITION = {
"schema_version": 1,
"type": "mcp",
"config": {
"transport": "streamable_http",
"url": "https://x/mcp",
"credential_uuid": None,
"tools_filter": [],
},
}
# ── Part A regression: CreateToolRequest / UpdateToolRequest validation ───────
def test_create_tool_request_accepts_mcp_definition():
"""CreateToolRequest must accept an MCP definition (was HTTP 422 before fix)."""
req = CreateToolRequest(
name="My MCP Tool",
description="Integration via MCP",
category="mcp",
definition=VALID_MCP_DEFINITION,
)
assert isinstance(req.definition, McpToolDefinition)
assert req.definition.type == "mcp"
assert req.definition.config.url == "https://x/mcp"
assert req.definition.config.transport == "streamable_http"
assert req.definition.config.credential_uuid is None
assert req.definition.config.tools_filter == []
assert req.definition.config.timeout_secs == 30
assert req.definition.config.sse_read_timeout_secs == 300
def test_update_tool_request_accepts_mcp_definition():
"""UpdateToolRequest must also accept an MCP definition."""
req = UpdateToolRequest(
name="Updated MCP Tool",
definition=VALID_MCP_DEFINITION,
)
assert isinstance(req.definition, McpToolDefinition)
assert req.definition.type == "mcp"
assert req.definition.config.url == "https://x/mcp"
def test_create_tool_request_accepts_mcp_with_all_fields():
"""All optional MCP config fields are accepted and preserved."""
req = CreateToolRequest(
name="Full MCP Tool",
category="mcp",
definition={
"schema_version": 1,
"type": "mcp",
"config": {
"transport": "streamable_http",
"url": "https://acme.example.com/mcp",
"credential_uuid": "cred-abc-123",
"tools_filter": ["lookup_patient", "schedule_appointment"],
"timeout_secs": 60,
"sse_read_timeout_secs": 600,
},
},
)
cfg = req.definition.config # type: ignore[union-attr]
assert cfg.url == "https://acme.example.com/mcp"
assert cfg.credential_uuid == "cred-abc-123"
assert cfg.tools_filter == ["lookup_patient", "schedule_appointment"]
assert cfg.timeout_secs == 60
assert cfg.sse_read_timeout_secs == 600
# ── Invalid bodies are rejected ───────────────────────────────────────────────
@pytest.mark.parametrize(
"definition",
[
# ftp:// URL — rejected by McpToolConfig.validate_url
{
"schema_version": 1,
"type": "mcp",
"config": {"transport": "streamable_http", "url": "ftp://x/mcp"},
},
# Empty url — rejected by McpToolConfig.validate_url
{
"schema_version": 1,
"type": "mcp",
"config": {"transport": "streamable_http", "url": ""},
},
# Missing url — rejected by McpToolConfig (required field)
{
"schema_version": 1,
"type": "mcp",
"config": {"transport": "streamable_http"},
},
# Unsupported transport — rejected because Literal["streamable_http"] constraint
{
"schema_version": 1,
"type": "mcp",
"config": {"url": "https://x/mcp", "transport": "stdio"},
},
],
)
def test_create_tool_request_rejects_invalid_mcp_definition(definition):
"""Invalid MCP definitions must raise ValidationError."""
with pytest.raises(ValidationError):
CreateToolRequest(
name="Bad MCP Tool",
category="mcp",
definition=definition,
)
# ── Round-trip compatibility: request schema ↔ validate_mcp_definition ───────
def test_mcp_definition_round_trips_through_validate_mcp_definition():
"""The dict produced by CreateToolRequest.definition.model_dump() must be
accepted by validate_mcp_definition without raising, and the result must
contain the expected fields. This proves the request-layer schema and the
call-time validator agree on the stored config shape."""
req = CreateToolRequest(
name="Round-Trip MCP Tool",
category="mcp",
definition={
"schema_version": 1,
"type": "mcp",
"config": {
"transport": "streamable_http",
"url": "https://roundtrip.example.com/mcp",
"credential_uuid": "cred-rt-456",
"tools_filter": ["ping"],
"timeout_secs": 45,
"sse_read_timeout_secs": 400,
},
},
)
# Simulate what the route does: persist definition as a plain dict
persisted = req.definition.model_dump() # type: ignore[union-attr]
# validate_mcp_definition must accept the persisted shape without raising
normalized = validate_mcp_definition(persisted)
assert normalized["url"] == "https://roundtrip.example.com/mcp"
assert normalized["transport"] == "streamable_http"
assert normalized["credential_uuid"] == "cred-rt-456"
assert normalized["tools_filter"] == ["ping"]
assert normalized["timeout_secs"] == 45
assert normalized["sse_read_timeout_secs"] == 400
def test_mcp_definition_round_trip_defaults():
"""Round-trip with minimal body: defaults fill in correctly and
validate_mcp_definition agrees on them."""
req = CreateToolRequest(
name="Minimal MCP Tool",
category="mcp",
definition=VALID_MCP_DEFINITION,
)
persisted = req.definition.model_dump() # type: ignore[union-attr]
normalized = validate_mcp_definition(persisted)
assert normalized["transport"] == "streamable_http"
assert normalized["tools_filter"] == []
assert normalized["timeout_secs"] == 30
assert normalized["sse_read_timeout_secs"] == 300
assert normalized["credential_uuid"] is None
# Part B: auth_header / auth_scheme must NOT be present in the normalized
# config dict (they were dead config removed in the fix)
assert "auth_header" not in normalized
assert "auth_scheme" not in normalized
# ── Full HTTP round-trip via ASGI test client ─────────────────────────────────
async def test_post_tool_mcp_returns_200(test_client_factory, db_session):
"""POST /api/v1/tools/ with an MCP definition must return HTTP 200 and
persist the definition with type='mcp'. Before Part A this always
returned 422."""
# Create a user and an organization, then link them so the route's
# selected_organization_id check passes.
user, _ = await db_session.get_or_create_user_by_provider_id("mcp_route_test_user")
org, _ = await db_session.get_or_create_organization_by_provider_id(
"mcp_route_test_org", user.id
)
await db_session.update_user_selected_organization(user.id, org.id)
# Reload the user so selected_organization_id is populated on the object.
user = await db_session.get_user_by_id(user.id)
async with test_client_factory(user) as client:
response = await client.post(
"/api/v1/tools/",
json={
"name": "HTTP Round-Trip MCP Tool",
"description": "Testing the full route",
"category": "mcp",
"definition": {
"schema_version": 1,
"type": "mcp",
"config": {
"transport": "streamable_http",
"url": "https://roundtrip.example.com/mcp",
"credential_uuid": None,
"tools_filter": [],
},
},
},
)
assert response.status_code == 200, (
f"Expected 200, got {response.status_code}: {response.text}"
)
body = response.json()
assert body["definition"]["type"] == "mcp"
assert body["definition"]["config"]["url"] == "https://roundtrip.example.com/mcp"
assert body["category"] == "mcp"
async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_session):
"""POST /api/v1/tools/ with an ftp:// URL must return HTTP 422."""
user, _ = await db_session.get_or_create_user_by_provider_id(
"mcp_route_test_user_422"
)
org, _ = await db_session.get_or_create_organization_by_provider_id(
"mcp_route_test_org_422", user.id
)
await db_session.update_user_selected_organization(user.id, org.id)
user = await db_session.get_user_by_id(user.id)
async with test_client_factory(user) as client:
response = await client.post(
"/api/v1/tools/",
json={
"name": "Bad MCP Tool",
"category": "mcp",
"definition": {
"schema_version": 1,
"type": "mcp",
"config": {
"transport": "streamable_http",
"url": "ftp://invalid.example.com/mcp",
},
},
},
)
assert response.status_code == 422
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
from unittest.mock import AsyncMock, MagicMock
from api.routes.tool import McpToolConfig, _populate_discovered_tools
def test_mcp_config_accepts_discovered_tools():
cfg = McpToolConfig(
url="https://x/mcp",
discovered_tools=[{"name": "echo", "description": "Echo"}],
)
assert cfg.discovered_tools == [{"name": "echo", "description": "Echo"}]
# Defaults to [] when omitted
assert McpToolConfig(url="https://x/mcp").discovered_tools == []
@pytest.mark.asyncio
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
import api.routes.tool as tool_mod
monkeypatch.setattr(
tool_mod,
"discover_mcp_tools",
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
)
definition = {
"schema_version": 1,
"type": "mcp",
"config": {
"url": "https://x/mcp",
"tools_filter": [],
"discovered_tools": [{"name": "stale", "description": "old"}],
},
}
out = await _populate_discovered_tools(definition, organization_id=1)
assert out["config"]["discovered_tools"] == [
{"name": "echo", "description": "Echo"}
]
@pytest.mark.asyncio
async def test_populate_discovered_tools_non_mcp_is_noop():
definition = {"schema_version": 1, "type": "http_api", "config": {}}
out = await _populate_discovered_tools(definition, organization_id=1)
assert out == definition # untouched
@pytest.mark.asyncio
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
import api.routes.tool as tool_mod
monkeypatch.setattr(
tool_mod,
"discover_mcp_tools",
AsyncMock(side_effect=RuntimeError("connection refused")),
)
definition = {
"schema_version": 1,
"type": "mcp",
"config": {"url": "https://x/mcp", "tools_filter": []},
}
out = await _populate_discovered_tools(definition, organization_id=1)
assert out["config"]["discovered_tools"] == []
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
from fastapi import HTTPException
from api.routes.tool import refresh_mcp_tools
def _fake_user(org_id=1):
u = MagicMock()
u.selected_organization_id = org_id
u.id = 1
u.provider_id = "p1"
return u
def _mcp_tool_model(org_id=1):
t = MagicMock()
t.tool_uuid = "tu-mcp"
t.name = "Mock MCP"
t.category = "mcp"
t.definition = {
"schema_version": 1,
"type": "mcp",
"config": {"url": "https://x/mcp", "tools_filter": []},
}
return t
@pytest.mark.asyncio
async def test_refresh_success(monkeypatch):
import api.routes.tool as tool_mod
tool = _mcp_tool_model()
monkeypatch.setattr(
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
)
monkeypatch.setattr(
tool_mod.db_client,
"update_tool",
AsyncMock(return_value=tool),
)
monkeypatch.setattr(
tool_mod,
"discover_mcp_tools",
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
)
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
assert resp.discovered_tools == [{"name": "echo", "description": "Echo"}]
assert resp.error is None
@pytest.mark.asyncio
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
import api.routes.tool as tool_mod
tool = _mcp_tool_model()
monkeypatch.setattr(
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
)
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
assert resp.discovered_tools == []
assert resp.error # non-empty human-readable message
# update_tool should NOT be called when discovery returns empty
tool_mod.db_client.update_tool.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_non_mcp_is_400(monkeypatch):
import api.routes.tool as tool_mod
tool = _mcp_tool_model()
tool.category = "http_api"
monkeypatch.setattr(
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
)
with pytest.raises(HTTPException) as ei:
await refresh_mcp_tools("tu-mcp", user=_fake_user())
assert ei.value.status_code == 400
@pytest.mark.asyncio
async def test_refresh_not_found_is_404(monkeypatch):
import api.routes.tool as tool_mod
monkeypatch.setattr(
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
)
with pytest.raises(HTTPException) as ei:
await refresh_mcp_tools("nope", user=_fake_user())
assert ei.value.status_code == 404

View file

@ -0,0 +1,274 @@
from datetime import timedelta
from unittest.mock import MagicMock
import httpx
import pytest
from api.services.workflow.mcp_tool_session import (
McpToolSession,
build_streamable_http_params,
discover_mcp_tools,
)
from api.tests.support.mcp_mock_server import running_mcp_server
@pytest.mark.asyncio
async def test_mock_server_starts_and_serves():
async with running_mcp_server() as base_url:
async with httpx.AsyncClient() as client:
resp = await client.get(base_url, timeout=5.0)
assert resp.status_code in (400, 404, 405, 406)
def test_build_streamable_http_params_with_credential():
cred = MagicMock()
cred.credential_type = "bearer_token"
cred.credential_data = {"token": "abc"}
params = build_streamable_http_params(
url="https://acme.example.com/mcp",
credential=cred,
timeout_secs=30,
sse_read_timeout_secs=300,
)
assert params.url == "https://acme.example.com/mcp"
assert params.headers == {"Authorization": "Bearer abc"}
assert params.timeout == timedelta(seconds=30)
assert params.sse_read_timeout == timedelta(seconds=300)
def test_build_streamable_http_params_no_credential():
params = build_streamable_http_params(
url="https://acme.example.com/mcp",
credential=None,
timeout_secs=10,
sse_read_timeout_secs=20,
)
assert params.headers is None or params.headers == {}
@pytest.mark.asyncio
async def test_session_start_passes_auth_header_to_real_server():
cred = MagicMock()
cred.credential_type = "bearer_token"
cred.credential_data = {"token": "abc"}
async with running_mcp_server(
required_headers={"Authorization": "Bearer abc"}
) as base_url:
session = McpToolSession(
tool_uuid="uuid-auth-ok",
tool_name="Secure MCP",
url=base_url,
credential=cred,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=20,
)
await session.start()
try:
assert session.available is True
names = sorted(s.name for s in session.function_schemas())
assert names == ["mcp__secure_mcp__add", "mcp__secure_mcp__echo"]
result = await session.call("mcp__secure_mcp__echo", {"text": "hi"})
assert "echo:hi" in result
finally:
await session.close()
@pytest.mark.asyncio
async def test_session_auth_failure_degrades_not_raises():
async with running_mcp_server(
required_headers={"Authorization": "Bearer abc"}
) as base_url:
session = McpToolSession(
tool_uuid="uuid-auth-fail",
tool_name="Secure MCP",
url=base_url,
credential=None,
tools_filter=[],
timeout_secs=2,
sse_read_timeout_secs=2,
)
await session.start() # must degrade instead of raising on 401
try:
assert session.available is False
assert session.function_schemas() == []
finally:
await session.close()
@pytest.mark.asyncio
async def test_session_start_lists_and_calls_real_server():
async with running_mcp_server() as base_url:
session = McpToolSession(
tool_uuid="uuid-1234abcd",
tool_name="Acme MCP",
url=base_url,
credential=None,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=20,
)
await session.start()
try:
assert session.available is True
schemas = session.function_schemas()
names = sorted(s.name for s in schemas)
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
result = await session.call("mcp__acme_mcp__echo", {"text": "hi"})
assert "echo:hi" in result
finally:
await session.close()
@pytest.mark.asyncio
async def test_session_tools_filter_applied():
async with running_mcp_server() as base_url:
session = McpToolSession(
tool_uuid="uuid-1234abcd",
tool_name="Acme MCP",
url=base_url,
credential=None,
tools_filter=["echo"],
timeout_secs=10,
sse_read_timeout_secs=20,
)
await session.start()
try:
names = sorted(s.name for s in session.function_schemas())
assert names == ["mcp__acme_mcp__echo"]
finally:
await session.close()
@pytest.mark.asyncio
async def test_session_unreachable_degrades_not_raises():
session = McpToolSession(
tool_uuid="uuid-1234abcd",
tool_name="Acme MCP",
url="http://127.0.0.1:1/mcp",
credential=None,
tools_filter=[],
timeout_secs=2,
sse_read_timeout_secs=2,
)
await session.start() # must NOT raise
assert session.available is False
assert session.function_schemas() == []
await session.close()
@pytest.mark.asyncio
async def test_call_on_unavailable_session_raises():
session = McpToolSession(
tool_uuid="uuid-1234abcd",
tool_name="Acme MCP",
url="http://127.0.0.1:1/mcp",
credential=None,
tools_filter=[],
timeout_secs=2,
sse_read_timeout_secs=2,
)
await session.start()
with pytest.raises(RuntimeError):
await session.call("mcp__acme_mcp__echo", {"text": "x"})
await session.close()
@pytest.mark.asyncio
async def test_call_unknown_function_raises():
async with running_mcp_server() as base_url:
session = McpToolSession(
tool_uuid="uuid-1234abcd",
tool_name="Acme MCP",
url=base_url,
credential=None,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=10,
)
await session.start()
try:
with pytest.raises(RuntimeError):
await session.call("mcp__acme_mcp__does_not_exist", {})
finally:
await session.close()
@pytest.mark.asyncio
async def test_function_schemas_filter_by_raw_name():
async with running_mcp_server() as base_url:
session = McpToolSession(
tool_uuid="t-filter",
tool_name="Mock MCP",
url=base_url,
credential=None,
tools_filter=[],
timeout_secs=10,
sse_read_timeout_secs=10,
)
await session.start()
try:
# No arg = all (backward compatible)
all_names = sorted(s.name for s in session.function_schemas())
assert all_names == ["mcp__mock_mcp__add", "mcp__mock_mcp__echo"]
# Allow only raw "echo"
only_echo = session.function_schemas(allowed_raw_names={"echo"})
assert [s.name for s in only_echo] == ["mcp__mock_mcp__echo"]
# Empty set = none (default-none semantics)
assert session.function_schemas(allowed_raw_names=set()) == []
# Unknown raw name = skipped (pure intersection)
assert session.function_schemas(allowed_raw_names={"nope"}) == []
finally:
await session.close()
@pytest.mark.asyncio
async def test_discover_mcp_tools_success():
async with running_mcp_server() as base_url:
tools = await discover_mcp_tools(
url=base_url,
credential=None,
timeout_secs=10,
sse_read_timeout_secs=10,
)
names = sorted(t["name"] for t in tools)
assert names == ["add", "echo"]
by_name = {t["name"]: t for t in tools}
assert by_name["echo"]["description"] # non-empty description
assert set(by_name["echo"]) == {"name", "description"}
@pytest.mark.asyncio
async def test_discover_mcp_tools_server_down_returns_empty():
# Unroutable port, short timeouts: must degrade to [] (never raise).
tools = await discover_mcp_tools(
url="http://127.0.0.1:1/mcp",
credential=None,
timeout_secs=1,
sse_read_timeout_secs=1,
)
assert tools == []
def test_agent_node_data_carries_mcp_tool_filters():
from api.services.workflow.dto import AgentNodeData, NodeType
from api.services.workflow.workflow_graph import Node
data = AgentNodeData(
name="N1",
tool_uuids=["tu-1"],
mcp_tool_filters={"tu-1": ["echo"]},
)
assert data.mcp_tool_filters == {"tu-1": ["echo"]}
node = Node("n1", NodeType.agentNode, data)
assert node.mcp_tool_filters == {"tu-1": ["echo"]}
# Absent field defaults to None (backward compatible)
data2 = AgentNodeData(name="N2")
assert data2.mcp_tool_filters is None
assert Node("n2", NodeType.agentNode, data2).mcp_tool_filters is None