dograh/api/tests/test_mcp_tool_definition.py
Paulo Busato Favarato 75839f9de5
feat(mcp): generic MCP tool source with per-node function filtering (#301)
* feat(mcp): generic MCP tool source with per-node function filtering

Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.

- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles

* feat: refactor for code standardisation and documentation

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
2026-05-19 16:10:00 +05:30

112 lines
3.3 KiB
Python

import importlib
import pytest
from api.enums import ToolCategory
from api.routes.tool import McpToolConfig as RouteMcpToolConfig
from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition
from api.services.workflow.tools.mcp_tool import (
McpDefinitionError,
McpToolConfig,
McpToolDefinition,
namespace_function_name,
validate_mcp_definition,
)
def test_mcp_category_exists():
    """ToolCategory has an MCP member whose value is the string "mcp"."""
    mcp_member = ToolCategory.MCP
    assert mcp_member.value == "mcp"
    # Looking the member up by value must give back the same enum member.
    assert ToolCategory("mcp") is mcp_member
def test_mcp_migration_present_and_chained(monkeypatch):
    """Check the MCP enum migration's revision ids and its sync_enum_values calls.

    The migration module is imported by name, its ``revision`` and
    ``down_revision`` are compared against the expected identifiers, and
    ``upgrade()`` / ``downgrade()`` are run with ``op.sync_enum_values``
    replaced by a recorder so the keyword arguments of each call can be
    inspected.
    """
    migration = importlib.import_module(
        "api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory"
    )
    assert migration.revision == "0a1b2c3d4e5f"
    assert migration.down_revision == "4c1f1e3e8ef2"

    # Record the keyword arguments of every sync_enum_values call.
    recorded = []
    monkeypatch.setattr(
        migration.op,
        "sync_enum_values",
        lambda **kwargs: recorded.append(kwargs),
    )

    migration.upgrade()
    migration.downgrade()

    # Exactly one call per direction: upgrade first, then downgrade.
    assert len(recorded) == 2
    upgrade_kwargs, downgrade_kwargs = recorded
    assert upgrade_kwargs["enum_name"] == "tool_category"
    assert "mcp" in upgrade_kwargs["new_values"]
    assert "mcp" not in downgrade_kwargs["new_values"]
def test_route_reuses_shared_mcp_models():
    """The MCP models imported from the route module are the shared tool-module objects."""
    model_pairs = (
        (RouteMcpToolConfig, McpToolConfig),
        (RouteMcpToolDefinition, McpToolDefinition),
    )
    for route_model, shared_model in model_pairs:
        # Identity, not equality: both names must refer to one object.
        assert route_model is shared_model
def test_validate_mcp_definition_ok():
    """A fully specified MCP definition validates and returns each config value as given."""
    definition = {
        "schema_version": 1,
        "type": "mcp",
        "config": {
            "transport": "streamable_http",
            "url": "https://acme.example.com/mcp",
            "credential_uuid": "cred-123",
            "tools_filter": ["lookup_patient"],
            "timeout_secs": 30,
            "sse_read_timeout_secs": 300,
        },
    }

    validated = validate_mcp_definition(definition)

    # Expected values are spelled out separately from the input on purpose,
    # so the check does not depend on the input dict staying untouched.
    expected = {
        "url": "https://acme.example.com/mcp",
        "transport": "streamable_http",
        "tools_filter": ["lookup_patient"],
        "timeout_secs": 30,
        "sse_read_timeout_secs": 300,
        "credential_uuid": "cred-123",
    }
    for key, value in expected.items():
        assert validated[key] == value
def test_validate_mcp_definition_defaults():
    """A definition that only supplies a url comes back with the default settings."""
    minimal_definition = {"type": "mcp", "config": {"url": "https://x/mcp"}}

    resolved = validate_mcp_definition(minimal_definition)

    expected_defaults = {
        "transport": "streamable_http",
        "tools_filter": [],
        "timeout_secs": 30,
        "sse_read_timeout_secs": 300,
    }
    for key, default in expected_defaults.items():
        assert resolved[key] == default
    # The credential default is checked by identity against None.
    assert resolved["credential_uuid"] is None
# Definitions that validate_mcp_definition is expected to reject.
_REJECTED_DEFINITIONS = [
    # config present but without a url
    {"type": "mcp", "config": {}},
    # url is the empty string
    {"type": "mcp", "config": {"url": ""}},
    # url uses the ftp:// scheme
    {"type": "mcp", "config": {"url": "ftp://x"}},
    # no config key at all
    {"type": "mcp"},
    # transport set to "stdio"
    {"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}},
]


@pytest.mark.parametrize("definition", _REJECTED_DEFINITIONS)
def test_validate_mcp_definition_rejects(definition):
    """Each listed definition makes validate_mcp_definition raise McpDefinitionError."""
    with pytest.raises(McpDefinitionError):
        validate_mcp_definition(definition)
def test_validate_mcp_definition_zero_timeout_preserved():
    """An explicit timeout_secs of 0 is still 0 in the validated config."""
    definition = {
        "type": "mcp",
        "config": {"url": "https://x/mcp", "timeout_secs": 0},
    }
    # A zero timeout must come back as given.
    assert validate_mcp_definition(definition)["timeout_secs"] == 0
def test_namespace_function_name():
    """Check the namespaced names built from a tool name and from a fallback.

    "Acme MCP" with function "lookup_patient" is expected to become
    "mcp__acme_mcp__lookup_patient"; with an empty first argument the
    ``fallback`` value is expected to appear in the middle segment instead.
    """
    from_tool_name = namespace_function_name("Acme MCP", "lookup_patient")
    assert from_tool_name == "mcp__acme_mcp__lookup_patient"

    from_fallback = namespace_function_name("", "ping", fallback="abcd1234")
    assert from_fallback == "mcp__abcd1234__ping"