mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* feat(mcp): generic MCP tool source with per-node function filtering
Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.
- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles
* feat: refactor for code standardisation and documentation
---------
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from api.enums import ToolCategory
|
|
from api.services.workflow.pipecat_engine import PipecatEngine
|
|
from api.tests.support.mcp_mock_server import running_mcp_server
|
|
|
|
|
|
def _mcp_tool(url: str):
    """Return a ``MagicMock`` standing in for an MCP tool record.

    The mock carries a fresh ``tool_uuid`` (``"uuid-"`` plus eight random hex
    characters), a fixed display name, the MCP category value and a
    definition dict that points the ``streamable_http`` transport at *url*.
    """
    mcp_config = {"transport": "streamable_http", "url": url}

    fake_tool = MagicMock()
    fake_tool.definition = {
        "schema_version": 1,
        "type": "mcp",
        "config": mcp_config,
    }
    fake_tool.category = ToolCategory.MCP.value
    # ``name`` is assigned after construction: passing it to the MagicMock
    # constructor would name the mock itself instead of setting the attribute.
    fake_tool.name = "Acme MCP"
    fake_tool.tool_uuid = f"uuid-{uuid.uuid4().hex[:8]}"
    return fake_tool
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
    """Exercise the open/close MCP session lifecycle against a running server.

    The engine is created with ``__new__`` so ``__init__`` never runs; only
    the attributes assigned in this test exist on it. Both ``db_client``
    lookups are replaced with ``AsyncMock`` stubs.
    """
    async with running_mcp_server() as server_url:
        mcp_tool = _mcp_tool(server_url)

        # A one-node workflow whose only node references the MCP tool.
        pipeline = PipecatEngine.__new__(PipecatEngine)
        wf_node = MagicMock(tool_uuids=[mcp_tool.tool_uuid])
        pipeline.workflow = MagicMock(nodes={"n1": wf_node})
        pipeline._mcp_sessions = {}

        from api.db import db_client

        db_stubs = {
            "get_tools_by_uuids": AsyncMock(return_value=[mcp_tool]),
            "get_credential_by_uuid": AsyncMock(return_value=None),
        }
        for attr_name, stub in db_stubs.items():
            monkeypatch.setattr(db_client, attr_name, stub)
        pipeline._get_organization_id = AsyncMock(return_value=42)

        await pipeline._open_mcp_sessions()
        try:
            assert mcp_tool.tool_uuid in pipeline._mcp_sessions
            session = pipeline._mcp_sessions[mcp_tool.tool_uuid]
            assert session.available is True
            assert len(session.function_schemas()) == 2
        finally:
            # Tear the sessions down even when an assertion above fails.
            await pipeline._close_mcp_sessions()
        assert pipeline._mcp_sessions == {}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_open_mcp_sessions_swallows_db_error(monkeypatch):
    """A failing tool lookup must not escape ``_open_mcp_sessions``.

    ``db_client.get_tools_by_uuids`` is stubbed to raise ``RuntimeError``;
    the call is expected to return normally and leave ``_mcp_sessions`` empty.
    """
    wf_node = MagicMock(tool_uuids=["uuid-deadbeef"])

    pipeline = PipecatEngine.__new__(PipecatEngine)
    pipeline.workflow = MagicMock(nodes={"n1": wf_node})
    pipeline._mcp_sessions = {}

    from api.db import db_client

    failing_lookup = AsyncMock(side_effect=RuntimeError("db down"))
    monkeypatch.setattr(db_client, "get_tools_by_uuids", failing_lookup)
    pipeline._get_organization_id = AsyncMock(return_value=42)

    # Must NOT raise
    await pipeline._open_mcp_sessions()
    assert pipeline._mcp_sessions == {}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch):
    """A tool whose credential lookup fails must be skipped, not started.

    The tool definition references a ``credential_uuid`` and
    ``db_client.get_credential_by_uuid`` is stubbed to raise ``RuntimeError``;
    ``_open_mcp_sessions`` is expected to return normally with no session
    registered.
    """
    mcp_tool = _mcp_tool("http://127.0.0.1:1/mcp")
    mcp_tool.definition["config"]["credential_uuid"] = "cred-1234"

    wf_node = MagicMock(tool_uuids=[mcp_tool.tool_uuid])
    pipeline = PipecatEngine.__new__(PipecatEngine)
    pipeline.workflow = MagicMock(nodes={"n1": wf_node})
    pipeline._mcp_sessions = {}

    from api.db import db_client

    tools_lookup = AsyncMock(return_value=[mcp_tool])
    broken_credentials = AsyncMock(side_effect=RuntimeError("cred store down"))
    monkeypatch.setattr(db_client, "get_tools_by_uuids", tools_lookup)
    monkeypatch.setattr(db_client, "get_credential_by_uuid", broken_credentials)
    pipeline._get_organization_id = AsyncMock(return_value=42)

    # Must NOT raise, and must skip the tool (no futile unauthenticated start)
    await pipeline._open_mcp_sessions()
    assert pipeline._mcp_sessions == {}
|