mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
Merge remote-tracking branch 'origin/main' into pr-316-search-docs-main-merged
This commit is contained in:
commit
4618af20b8
146 changed files with 7800 additions and 3848 deletions
|
|
@ -15,16 +15,14 @@ import pytest
|
|||
|
||||
from api.services.workflow.dto import (
|
||||
AgentNodeData,
|
||||
AgentRFNode,
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
ExtractionVariableDTO,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
VariableType,
|
||||
)
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
|
@ -270,8 +268,9 @@ def simple_workflow() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -290,8 +289,9 @@ def simple_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -333,8 +333,9 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -353,8 +354,9 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
AgentRFNode(
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type="agentNode",
|
||||
position=Position(x=0, y=200),
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
|
|
@ -372,8 +374,9 @@ def three_node_workflow() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=400),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -424,8 +427,9 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -444,8 +448,9 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
],
|
||||
),
|
||||
),
|
||||
AgentRFNode(
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type="agentNode",
|
||||
position=Position(x=0, y=200),
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
|
|
@ -455,8 +460,9 @@ def three_node_workflow_extraction_start_only() -> WorkflowGraph:
|
|||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=400),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -503,8 +509,9 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -515,8 +522,9 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
AgentRFNode(
|
||||
RFNodeDTO(
|
||||
id="agent",
|
||||
type="agentNode",
|
||||
position=Position(x=0, y=200),
|
||||
data=AgentNodeData(
|
||||
name="Collect Info",
|
||||
|
|
@ -526,8 +534,9 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
extraction_enabled=False, # Explicitly disabled for testing
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=400),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
|
|||
|
|
@ -63,7 +63,6 @@
|
|||
},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": true,
|
||||
"name": "Start Call",
|
||||
"is_start": true
|
||||
},
|
||||
|
|
@ -83,7 +82,6 @@
|
|||
},
|
||||
"data": {
|
||||
"prompt": "Thank you for calling Dograh. Have a great day!",
|
||||
"is_static": true,
|
||||
"name": "End Call"
|
||||
},
|
||||
"measured": {
|
||||
|
|
@ -161,4 +159,4 @@
|
|||
"y": 0,
|
||||
"zoom": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
0
api/tests/support/__init__.py
Normal file
0
api/tests/support/__init__.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""A real FastMCP server exposing 2 tools over streamable-HTTP, run in a
|
||||
background uvicorn thread on an ephemeral port. Used to exercise the real
|
||||
MCP protocol path in tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastmcp import FastMCP
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
def _build_app(required_headers: dict[str, str] | None = None):
|
||||
mcp = FastMCP("mock-mcp")
|
||||
|
||||
@mcp.tool()
|
||||
def echo(text: str) -> str:
|
||||
"""Echo the provided text back."""
|
||||
return f"echo:{text}"
|
||||
|
||||
@mcp.tool()
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two integers."""
|
||||
return a + b
|
||||
|
||||
# FastMCP 3.x: ASGI app for streamable-HTTP transport at "/mcp".
|
||||
app = mcp.http_app()
|
||||
if not required_headers:
|
||||
return app
|
||||
|
||||
normalized = {k.lower(): v for k, v in required_headers.items()}
|
||||
|
||||
async def guarded_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = {
|
||||
key.decode("latin-1").lower(): value.decode("latin-1")
|
||||
for key, value in scope.get("headers", [])
|
||||
}
|
||||
for header_name, expected_value in normalized.items():
|
||||
if headers.get(header_name) != expected_value:
|
||||
response = JSONResponse(
|
||||
{"detail": f"Missing or invalid header: {header_name}"},
|
||||
status_code=401,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
await app(scope, receive, send)
|
||||
|
||||
return guarded_app
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def running_mcp_server(
|
||||
*, required_headers: dict[str, str] | None = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Yield the base streamable-HTTP URL of a live mock MCP server."""
|
||||
port = _free_port()
|
||||
config = uvicorn.Config(
|
||||
_build_app(required_headers), host="127.0.0.1", port=port, log_level="warning"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}/mcp"
|
||||
server_ready = False
|
||||
for _ in range(50):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.get(base_url, timeout=0.5)
|
||||
server_ready = True
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(0.1)
|
||||
if not server_ready:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
raise RuntimeError(f"Mock MCP server at {base_url} failed to start within 5s")
|
||||
try:
|
||||
yield base_url
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
if thread.is_alive():
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Mock MCP server thread did not terminate within 5s",
|
||||
ResourceWarning,
|
||||
)
|
||||
|
|
@ -153,45 +153,16 @@ async def test_verify_inbound_signature_rejects_missing_config_public_key():
|
|||
_, headers = _signed_headers(body)
|
||||
provider = _provider()
|
||||
|
||||
# REMOVE-AFTER 2026-05-15: drop the patch wrapper once
|
||||
# TELNYX_WEBHOOK_VERIFICATION_OPTIONAL is removed; the bare call below
|
||||
# will then assert the only path.
|
||||
with patch(
|
||||
"api.services.telephony.providers.telnyx.provider.TELNYX_WEBHOOK_VERIFICATION_OPTIONAL",
|
||||
False,
|
||||
):
|
||||
result = await provider.verify_inbound_signature(
|
||||
"https://example.test/api/v1/telephony/inbound/run",
|
||||
json.loads(body),
|
||||
headers,
|
||||
body,
|
||||
)
|
||||
result = await provider.verify_inbound_signature(
|
||||
"https://example.test/api/v1/telephony/inbound/run",
|
||||
json.loads(body),
|
||||
headers,
|
||||
body,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# REMOVE-AFTER 2026-05-15: delete this whole test along with the
|
||||
# TELNYX_WEBHOOK_VERIFICATION_OPTIONAL flag.
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_inbound_signature_allows_missing_key_when_optional_flag_set():
|
||||
body = _body()
|
||||
_, headers = _signed_headers(body)
|
||||
provider = _provider()
|
||||
|
||||
with patch(
|
||||
"api.services.telephony.providers.telnyx.provider.TELNYX_WEBHOOK_VERIFICATION_OPTIONAL",
|
||||
True,
|
||||
):
|
||||
result = await provider.verify_inbound_signature(
|
||||
"https://example.test/api/v1/telephony/inbound/run",
|
||||
json.loads(body),
|
||||
headers,
|
||||
body,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_inbound_signature_reads_headers_case_insensitively():
|
||||
body = _body()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from dograh_sdk.typed import (
|
|||
Qa,
|
||||
StartCall,
|
||||
Trigger,
|
||||
Tuner,
|
||||
TypedNode,
|
||||
Webhook,
|
||||
)
|
||||
|
|
@ -50,6 +51,7 @@ def client() -> _StubClient:
|
|||
(Trigger, "trigger"),
|
||||
(Webhook, "webhook"),
|
||||
(Qa, "qa"),
|
||||
(Tuner, "tuner"),
|
||||
],
|
||||
ids=lambda v: v.__name__ if isinstance(v, type) else v,
|
||||
)
|
||||
|
|
@ -68,8 +70,15 @@ def test_typed_class_declares_spec_name(cls: type[TypedNode], expected_type: str
|
|||
inst = cls(name="t")
|
||||
elif cls is Webhook:
|
||||
inst = cls(name="wh")
|
||||
else: # Qa
|
||||
elif cls is Qa:
|
||||
inst = cls(name="qa")
|
||||
else: # Tuner
|
||||
inst = cls(
|
||||
name="tuner",
|
||||
tuner_agent_id="agent",
|
||||
tuner_workspace_id=1,
|
||||
tuner_api_key="secret",
|
||||
)
|
||||
assert inst.type == expected_type
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,37 @@ async def test_dto():
|
|||
assert dto is not None
|
||||
|
||||
|
||||
def test_dto_ignores_legacy_unknown_node_data_fields():
|
||||
dto = ReactFlowDTO.model_validate(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "n1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"name": "Start",
|
||||
"prompt": "Hello",
|
||||
"is_static": True,
|
||||
"detect_voicemail": True,
|
||||
"wait_for_user_response": False,
|
||||
"wait_for_user_response_timeout": 2.5,
|
||||
"legacy_field": "ignored",
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
)
|
||||
|
||||
data = dto.nodes[0].data.model_dump()
|
||||
assert "is_static" not in data
|
||||
assert "detect_voicemail" not in data
|
||||
assert "wait_for_user_response" not in data
|
||||
assert "wait_for_user_response_timeout" not in data
|
||||
assert "legacy_field" not in data
|
||||
|
||||
|
||||
def test_sanitize_strips_ui_runtime_fields():
|
||||
definition = {
|
||||
"viewport": {"x": 0, "y": 0, "zoom": 1},
|
||||
|
|
|
|||
63
api/tests/test_mcp_auth.py
Normal file
63
api/tests/test_mcp_auth.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_bearer_authorization():
|
||||
user = MagicMock()
|
||||
user.id = 1
|
||||
user.selected_organization_id = 90
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"authorization": "Bearer secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_x_api_key():
|
||||
user = MagicMock()
|
||||
user.id = 2
|
||||
user.selected_organization_id = 91
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"x-api-key": "secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_rejects_missing_api_key():
|
||||
with patch("api.mcp_server.auth.get_http_headers", return_value={}) as get_headers:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await authenticate_mcp_request()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing API key" in str(exc_info.value.detail)
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool():
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {"type": "mcp", "config": {"url": "https://x/mcp"}}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
reg_kwargs = {}
|
||||
|
||||
def _reg(name, fn, **kw):
|
||||
registered[name] = fn
|
||||
reg_kwargs[name] = kw
|
||||
|
||||
engine.llm.register_function = _reg
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
|
||||
try:
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert "mcp__acme_mcp__echo" in registered
|
||||
assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx(
|
||||
15.0
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
class P:
|
||||
function_name = "mcp__acme_mcp__echo"
|
||||
arguments = {"text": "yo"}
|
||||
|
||||
async def result_callback(self, r, *, properties=None):
|
||||
captured["r"] = r
|
||||
|
||||
await registered["mcp__acme_mcp__echo"](P())
|
||||
assert "echo:yo" in str(captured["r"])
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unavailable_mcp_session_contributes_nothing(monkeypatch):
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
await session.start() # degrades
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert schemas == []
|
||||
await mgr.register_handlers([tool.tool_uuid]) # must not raise
|
||||
|
||||
|
||||
def test_call_timeout_secs_is_read_timeout_plus_buffer():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-abc123",
|
||||
tool_name="Acme MCP",
|
||||
url="https://x/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert session.call_timeout_secs == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_node_mcp_filter_intersection(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
engine.llm.register_function = lambda name, fn, **kw: registered.__setitem__(
|
||||
name, fn
|
||||
)
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
try:
|
||||
# Allow only raw "echo" for this node
|
||||
filters = {tool.tool_uuid: ["echo"]}
|
||||
schemas = await mgr.get_tool_schemas(
|
||||
[tool.tool_uuid], mcp_tool_filters=filters
|
||||
)
|
||||
# Check only "echo" schema returned (namespaced name depends on tool.name)
|
||||
assert len(schemas) == 1
|
||||
assert all("echo" in s.name for s in schemas)
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters)
|
||||
assert len(registered) == 1
|
||||
assert all("echo" in k for k in registered)
|
||||
|
||||
# No filter entry for this uuid = none (default-none)
|
||||
registered.clear()
|
||||
result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert result == []
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert registered == {}
|
||||
|
||||
# mcp_tool_filters=None = backward-compatible (all tools)
|
||||
registered.clear()
|
||||
all_schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert len(all_schemas) == 2 # both echo and add
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert len(registered) == 2 # both handlers registered
|
||||
finally:
|
||||
await session.close()
|
||||
107
api/tests/test_mcp_integration.py
Normal file
107
api/tests/test_mcp_integration.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool(url: str):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": url},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool(base_url)
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_credential_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
await engine._open_mcp_sessions()
|
||||
try:
|
||||
assert tool.tool_uuid in engine._mcp_sessions
|
||||
sess = engine._mcp_sessions[tool.tool_uuid]
|
||||
assert sess.available is True
|
||||
assert len(sess.function_schemas()) == 2
|
||||
finally:
|
||||
await engine._close_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_swallows_db_error(monkeypatch):
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = ["uuid-deadbeef"]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_tools_by_uuids",
|
||||
AsyncMock(side_effect=RuntimeError("db down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch):
|
||||
tool = _mcp_tool("http://127.0.0.1:1/mcp")
|
||||
tool.definition["config"]["credential_uuid"] = "cred-1234"
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_credential_by_uuid",
|
||||
AsyncMock(side_effect=RuntimeError("cred store down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise, and must skip the tool (no futile unauthenticated start)
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
112
api/tests/test_mcp_tool_definition.py
Normal file
112
api/tests/test_mcp_tool_definition.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.routes.tool import McpToolConfig as RouteMcpToolConfig
|
||||
from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
namespace_function_name,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_category_exists():
|
||||
assert ToolCategory.MCP.value == "mcp"
|
||||
assert ToolCategory("mcp") is ToolCategory.MCP
|
||||
|
||||
|
||||
def test_mcp_migration_present_and_chained(monkeypatch):
|
||||
mod = importlib.import_module(
|
||||
"api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory"
|
||||
)
|
||||
assert mod.revision == "0a1b2c3d4e5f"
|
||||
assert mod.down_revision == "4c1f1e3e8ef2"
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_sync_enum_values(**kwargs):
|
||||
calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(mod.op, "sync_enum_values", fake_sync_enum_values)
|
||||
|
||||
mod.upgrade()
|
||||
mod.downgrade()
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["enum_name"] == "tool_category"
|
||||
assert "mcp" in calls[0]["new_values"]
|
||||
assert "mcp" not in calls[1]["new_values"]
|
||||
|
||||
|
||||
def test_route_reuses_shared_mcp_models():
|
||||
assert RouteMcpToolConfig is McpToolConfig
|
||||
assert RouteMcpToolDefinition is McpToolDefinition
|
||||
|
||||
|
||||
def test_validate_mcp_definition_ok():
|
||||
cfg = validate_mcp_definition(
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-123",
|
||||
"tools_filter": ["lookup_patient"],
|
||||
"timeout_secs": 30,
|
||||
"sse_read_timeout_secs": 300,
|
||||
},
|
||||
}
|
||||
)
|
||||
assert cfg["url"] == "https://acme.example.com/mcp"
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == ["lookup_patient"]
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] == "cred-123"
|
||||
|
||||
|
||||
def test_validate_mcp_definition_defaults():
|
||||
cfg = validate_mcp_definition({"type": "mcp", "config": {"url": "https://x/mcp"}})
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == []
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
{"type": "mcp", "config": {}},
|
||||
{"type": "mcp", "config": {"url": ""}},
|
||||
{"type": "mcp", "config": {"url": "ftp://x"}},
|
||||
{"type": "mcp"},
|
||||
{"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}},
|
||||
],
|
||||
)
|
||||
def test_validate_mcp_definition_rejects(definition):
|
||||
with pytest.raises(McpDefinitionError):
|
||||
validate_mcp_definition(definition)
|
||||
|
||||
|
||||
def test_validate_mcp_definition_zero_timeout_preserved():
|
||||
cfg = validate_mcp_definition(
|
||||
{"type": "mcp", "config": {"url": "https://x/mcp", "timeout_secs": 0}}
|
||||
)
|
||||
assert cfg["timeout_secs"] == 0
|
||||
|
||||
|
||||
def test_namespace_function_name():
|
||||
assert (
|
||||
namespace_function_name("Acme MCP", "lookup_patient")
|
||||
== "mcp__acme_mcp__lookup_patient"
|
||||
)
|
||||
assert (
|
||||
namespace_function_name("", "ping", fallback="abcd1234")
|
||||
== "mcp__abcd1234__ping"
|
||||
)
|
||||
437
api/tests/test_mcp_tool_route.py
Normal file
437
api/tests/test_mcp_tool_route.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
"""Route-level tests for the MCP tool definition schema.
|
||||
|
||||
These tests exercise the Pydantic request models (CreateToolRequest /
|
||||
UpdateToolRequest) to catch schema gaps at the route/request-model layer —
|
||||
the layer where the pre-fix defect lived (HTTP 422 on every MCP tool
|
||||
creation attempt).
|
||||
|
||||
Test coverage:
|
||||
- CreateToolRequest validates a valid MCP definition (was 422 before Part A).
|
||||
- UpdateToolRequest validates a valid MCP definition.
|
||||
- Invalid MCP bodies are rejected (ftp:// url, missing url).
|
||||
- Round-trip: validated definition dict passes through validate_mcp_definition
|
||||
unchanged, proving the request schema and call-time validator agree.
|
||||
- Full HTTP round-trip via the ASGI test client (POST /api/v1/tools/).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
# ── Canonical valid MCP request body ─────────────────────────────────────────
|
||||
|
||||
VALID_MCP_DEFINITION = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://x/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── Part A regression: CreateToolRequest / UpdateToolRequest validation ───────
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_definition():
|
||||
"""CreateToolRequest must accept an MCP definition (was HTTP 422 before fix)."""
|
||||
req = CreateToolRequest(
|
||||
name="My MCP Tool",
|
||||
description="Integration via MCP",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
assert req.definition.config.transport == "streamable_http"
|
||||
assert req.definition.config.credential_uuid is None
|
||||
assert req.definition.config.tools_filter == []
|
||||
assert req.definition.config.timeout_secs == 30
|
||||
assert req.definition.config.sse_read_timeout_secs == 300
|
||||
|
||||
|
||||
def test_update_tool_request_accepts_mcp_definition():
|
||||
"""UpdateToolRequest must also accept an MCP definition."""
|
||||
req = UpdateToolRequest(
|
||||
name="Updated MCP Tool",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_with_all_fields():
|
||||
"""All optional MCP config fields are accepted and preserved."""
|
||||
req = CreateToolRequest(
|
||||
name="Full MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-abc-123",
|
||||
"tools_filter": ["lookup_patient", "schedule_appointment"],
|
||||
"timeout_secs": 60,
|
||||
"sse_read_timeout_secs": 600,
|
||||
},
|
||||
},
|
||||
)
|
||||
cfg = req.definition.config # type: ignore[union-attr]
|
||||
assert cfg.url == "https://acme.example.com/mcp"
|
||||
assert cfg.credential_uuid == "cred-abc-123"
|
||||
assert cfg.tools_filter == ["lookup_patient", "schedule_appointment"]
|
||||
assert cfg.timeout_secs == 60
|
||||
assert cfg.sse_read_timeout_secs == 600
|
||||
|
||||
|
||||
# ── Invalid bodies are rejected ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
# ftp:// URL — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": "ftp://x/mcp"},
|
||||
},
|
||||
# Empty url — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": ""},
|
||||
},
|
||||
# Missing url — rejected by McpToolConfig (required field)
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http"},
|
||||
},
|
||||
# Unsupported transport — rejected because Literal["streamable_http"] constraint
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "transport": "stdio"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_create_tool_request_rejects_invalid_mcp_definition(definition):
|
||||
"""Invalid MCP definitions must raise ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateToolRequest(
|
||||
name="Bad MCP Tool",
|
||||
category="mcp",
|
||||
definition=definition,
|
||||
)
|
||||
|
||||
|
||||
# ── Round-trip compatibility: request schema ↔ validate_mcp_definition ───────
|
||||
|
||||
|
||||
def test_mcp_definition_round_trips_through_validate_mcp_definition():
|
||||
"""The dict produced by CreateToolRequest.definition.model_dump() must be
|
||||
accepted by validate_mcp_definition without raising, and the result must
|
||||
contain the expected fields. This proves the request-layer schema and the
|
||||
call-time validator agree on the stored config shape."""
|
||||
req = CreateToolRequest(
|
||||
name="Round-Trip MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": "cred-rt-456",
|
||||
"tools_filter": ["ping"],
|
||||
"timeout_secs": 45,
|
||||
"sse_read_timeout_secs": 400,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Simulate what the route does: persist definition as a plain dict
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
|
||||
# validate_mcp_definition must accept the persisted shape without raising
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["credential_uuid"] == "cred-rt-456"
|
||||
assert normalized["tools_filter"] == ["ping"]
|
||||
assert normalized["timeout_secs"] == 45
|
||||
assert normalized["sse_read_timeout_secs"] == 400
|
||||
|
||||
|
||||
def test_mcp_definition_round_trip_defaults():
|
||||
"""Round-trip with minimal body: defaults fill in correctly and
|
||||
validate_mcp_definition agrees on them."""
|
||||
req = CreateToolRequest(
|
||||
name="Minimal MCP Tool",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["tools_filter"] == []
|
||||
assert normalized["timeout_secs"] == 30
|
||||
assert normalized["sse_read_timeout_secs"] == 300
|
||||
assert normalized["credential_uuid"] is None
|
||||
# Part B: auth_header / auth_scheme must NOT be present in the normalized
|
||||
# config dict (they were dead config removed in the fix)
|
||||
assert "auth_header" not in normalized
|
||||
assert "auth_scheme" not in normalized
|
||||
|
||||
|
||||
# ── Full HTTP round-trip via ASGI test client ─────────────────────────────────
|
||||
|
||||
|
||||
async def test_post_tool_mcp_returns_200(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an MCP definition must return HTTP 200 and
|
||||
persist the definition with type='mcp'. Before Part A this always
|
||||
returned 422."""
|
||||
# Create a user and an organization, then link them so the route's
|
||||
# selected_organization_id check passes.
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id("mcp_route_test_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
# Reload the user so selected_organization_id is populated on the object.
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "HTTP Round-Trip MCP Tool",
|
||||
"description": "Testing the full route",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, (
|
||||
f"Expected 200, got {response.status_code}: {response.text}"
|
||||
)
|
||||
body = response.json()
|
||||
assert body["definition"]["type"] == "mcp"
|
||||
assert body["definition"]["config"]["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert body["category"] == "mcp"
|
||||
|
||||
|
||||
async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an ftp:// URL must return HTTP 422."""
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id(
|
||||
"mcp_route_test_user_422"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org_422", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "Bad MCP Tool",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "ftp://invalid.example.com/mcp",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from api.routes.tool import McpToolConfig, _populate_discovered_tools
|
||||
|
||||
|
||||
def test_mcp_config_accepts_discovered_tools():
|
||||
cfg = McpToolConfig(
|
||||
url="https://x/mcp",
|
||||
discovered_tools=[{"name": "echo", "description": "Echo"}],
|
||||
)
|
||||
assert cfg.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
# Defaults to [] when omitted
|
||||
assert McpToolConfig(url="https://x/mcp").discovered_tools == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"url": "https://x/mcp",
|
||||
"tools_filter": [],
|
||||
"discovered_tools": [{"name": "stale", "description": "old"}],
|
||||
},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == [
|
||||
{"name": "echo", "description": "Echo"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_non_mcp_is_noop():
|
||||
definition = {"schema_version": 1, "type": "http_api", "config": {}}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out == definition # untouched
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(side_effect=RuntimeError("connection refused")),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == []
|
||||
|
||||
|
||||
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.routes.tool import refresh_mcp_tools
|
||||
|
||||
|
||||
def _fake_user(org_id=1):
|
||||
u = MagicMock()
|
||||
u.selected_organization_id = org_id
|
||||
u.id = 1
|
||||
u.provider_id = "p1"
|
||||
return u
|
||||
|
||||
|
||||
def _mcp_tool_model(org_id=1):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "tu-mcp"
|
||||
t.name = "Mock MCP"
|
||||
t.category = "mcp"
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client,
|
||||
"update_tool",
|
||||
AsyncMock(return_value=tool),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
assert resp.error is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == []
|
||||
assert resp.error # non-empty human-readable message
|
||||
# update_tool should NOT be called when discovery returns empty
|
||||
tool_mod.db_client.update_tool.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_non_mcp_is_400(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
tool.category = "http_api"
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert ei.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_not_found_is_404(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("nope", user=_fake_user())
|
||||
assert ei.value.status_code == 404
|
||||
274
api/tests/test_mcp_tool_session.py
Normal file
274
api/tests/test_mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.mcp_tool_session import (
|
||||
McpToolSession,
|
||||
build_streamable_http_params,
|
||||
discover_mcp_tools,
|
||||
)
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_server_starts_and_serves():
|
||||
async with running_mcp_server() as base_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(base_url, timeout=5.0)
|
||||
assert resp.status_code in (400, 404, 405, 406)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_with_credential():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=cred,
|
||||
timeout_secs=30,
|
||||
sse_read_timeout_secs=300,
|
||||
)
|
||||
assert params.url == "https://acme.example.com/mcp"
|
||||
assert params.headers == {"Authorization": "Bearer abc"}
|
||||
assert params.timeout == timedelta(seconds=30)
|
||||
assert params.sse_read_timeout == timedelta(seconds=300)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_no_credential():
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert params.headers is None or params.headers == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_passes_auth_header_to_real_server():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-ok",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=cred,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__secure_mcp__add", "mcp__secure_mcp__echo"]
|
||||
result = await session.call("mcp__secure_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_auth_failure_degrades_not_raises():
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-fail",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must degrade instead of raising on 401
|
||||
try:
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_lists_and_calls_real_server():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
schemas = session.function_schemas()
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
result = await session.call("mcp__acme_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_tools_filter_applied():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=["echo"],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__acme_mcp__echo"]
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_unreachable_degrades_not_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must NOT raise
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_on_unavailable_session_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__echo", {"text": "x"})
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_function_raises():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__does_not_exist", {})
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_schemas_filter_by_raw_name():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="t-filter",
|
||||
tool_name="Mock MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
# No arg = all (backward compatible)
|
||||
all_names = sorted(s.name for s in session.function_schemas())
|
||||
assert all_names == ["mcp__mock_mcp__add", "mcp__mock_mcp__echo"]
|
||||
|
||||
# Allow only raw "echo"
|
||||
only_echo = session.function_schemas(allowed_raw_names={"echo"})
|
||||
assert [s.name for s in only_echo] == ["mcp__mock_mcp__echo"]
|
||||
|
||||
# Empty set = none (default-none semantics)
|
||||
assert session.function_schemas(allowed_raw_names=set()) == []
|
||||
|
||||
# Unknown raw name = skipped (pure intersection)
|
||||
assert session.function_schemas(allowed_raw_names={"nope"}) == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_success():
|
||||
async with running_mcp_server() as base_url:
|
||||
tools = await discover_mcp_tools(
|
||||
url=base_url,
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
names = sorted(t["name"] for t in tools)
|
||||
assert names == ["add", "echo"]
|
||||
by_name = {t["name"]: t for t in tools}
|
||||
assert by_name["echo"]["description"] # non-empty description
|
||||
assert set(by_name["echo"]) == {"name", "description"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_server_down_returns_empty():
|
||||
# Unroutable port, short timeouts: must degrade to [] (never raise).
|
||||
tools = await discover_mcp_tools(
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_agent_node_data_carries_mcp_tool_filters():
|
||||
from api.services.workflow.dto import AgentNodeData, NodeType
|
||||
from api.services.workflow.workflow_graph import Node
|
||||
|
||||
data = AgentNodeData(
|
||||
name="N1",
|
||||
tool_uuids=["tu-1"],
|
||||
mcp_tool_filters={"tu-1": ["echo"]},
|
||||
)
|
||||
assert data.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
node = Node("n1", NodeType.agentNode, data)
|
||||
assert node.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
# Absent field defaults to None (backward compatible)
|
||||
data2 = AgentNodeData(name="N2")
|
||||
assert data2.mcp_tool_filters is None
|
||||
assert Node("n2", NodeType.agentNode, data2).mcp_tool_filters is None
|
||||
|
|
@ -14,7 +14,12 @@ import re
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import NodeType, ReactFlowDTO
|
||||
from api.services.workflow.dto import (
|
||||
ReactFlowDTO,
|
||||
all_node_type_names,
|
||||
get_node_data_model,
|
||||
)
|
||||
from api.services.workflow.node_data import BaseNodeData
|
||||
from api.services.workflow.node_specs import (
|
||||
NodeSpec,
|
||||
PropertySpec,
|
||||
|
|
@ -118,9 +123,9 @@ def test_fixed_collection_has_sub_properties(spec: NodeSpec):
|
|||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_spec_name_matches_dto_discriminator(spec: NodeSpec):
|
||||
valid_names = {t.value for t in NodeType}
|
||||
valid_names = all_node_type_names()
|
||||
assert spec.name in valid_names, (
|
||||
f"NodeSpec {spec.name!r} doesn't match any NodeType discriminator. "
|
||||
f"NodeSpec {spec.name!r} doesn't match any registered node type. "
|
||||
f"Valid: {sorted(valid_names)}"
|
||||
)
|
||||
|
||||
|
|
@ -187,10 +192,226 @@ def test_examples_validate_against_dto(spec: NodeSpec):
|
|||
|
||||
|
||||
def test_all_dto_types_have_specs():
|
||||
"""Every NodeType discriminator value must have a registered NodeSpec —
|
||||
catches the case where someone adds a new node type to dto.py but
|
||||
forgets to author a spec."""
|
||||
"""Every registered node type must have a registered NodeSpec."""
|
||||
spec_names = {s.name for s in all_specs()}
|
||||
type_values = {t.value for t in NodeType}
|
||||
type_values = all_node_type_names()
|
||||
missing = type_values - spec_names
|
||||
assert not missing, f"NodeType discriminators without specs: {sorted(missing)}"
|
||||
assert not missing, f"Registered node types without specs: {sorted(missing)}"
|
||||
|
||||
|
||||
def test_all_registered_node_models_inherit_base_node_data():
|
||||
for type_name in sorted(all_node_type_names()):
|
||||
data_model = get_node_data_model(type_name)
|
||||
assert data_model is not None, f"{type_name}: missing node data model"
|
||||
assert issubclass(data_model, BaseNodeData), (
|
||||
f"{type_name}: node data model must inherit BaseNodeData"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("spec_name", "expected_order"),
|
||||
[
|
||||
(
|
||||
"startCall",
|
||||
[
|
||||
"name",
|
||||
"greeting_type",
|
||||
"greeting",
|
||||
"greeting_recording_id",
|
||||
"prompt",
|
||||
"allow_interrupt",
|
||||
"add_global_prompt",
|
||||
"delayed_start",
|
||||
"delayed_start_duration",
|
||||
"extraction_enabled",
|
||||
"extraction_prompt",
|
||||
"extraction_variables",
|
||||
"tool_uuids",
|
||||
"document_uuids",
|
||||
"pre_call_fetch_enabled",
|
||||
"pre_call_fetch_url",
|
||||
"pre_call_fetch_credential_uuid",
|
||||
],
|
||||
),
|
||||
(
|
||||
"agentNode",
|
||||
[
|
||||
"name",
|
||||
"prompt",
|
||||
"allow_interrupt",
|
||||
"add_global_prompt",
|
||||
"extraction_enabled",
|
||||
"extraction_prompt",
|
||||
"extraction_variables",
|
||||
"tool_uuids",
|
||||
"document_uuids",
|
||||
],
|
||||
),
|
||||
(
|
||||
"endCall",
|
||||
[
|
||||
"name",
|
||||
"prompt",
|
||||
"add_global_prompt",
|
||||
"extraction_enabled",
|
||||
"extraction_prompt",
|
||||
"extraction_variables",
|
||||
],
|
||||
),
|
||||
("globalNode", ["name", "prompt"]),
|
||||
("trigger", ["name", "enabled", "trigger_path"]),
|
||||
(
|
||||
"webhook",
|
||||
[
|
||||
"name",
|
||||
"enabled",
|
||||
"http_method",
|
||||
"endpoint_url",
|
||||
"credential_uuid",
|
||||
"custom_headers",
|
||||
"payload_template",
|
||||
],
|
||||
),
|
||||
(
|
||||
"qa",
|
||||
[
|
||||
"name",
|
||||
"qa_enabled",
|
||||
"qa_system_prompt",
|
||||
"qa_min_call_duration",
|
||||
"qa_voicemail_calls",
|
||||
"qa_sample_rate",
|
||||
"qa_use_workflow_llm",
|
||||
"qa_provider",
|
||||
"qa_model",
|
||||
"qa_api_key",
|
||||
"qa_endpoint",
|
||||
],
|
||||
),
|
||||
(
|
||||
"tuner",
|
||||
[
|
||||
"name",
|
||||
"tuner_enabled",
|
||||
"tuner_agent_id",
|
||||
"tuner_workspace_id",
|
||||
"tuner_api_key",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_node_spec_property_order_stable(spec_name: str, expected_order: list[str]):
|
||||
spec = next(spec for spec in all_specs() if spec.name == spec_name)
|
||||
assert [prop.name for prop in spec.properties] == expected_order
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# `to_mcp_dict` projection — the lean view served by the `get_node_type`
|
||||
# MCP tool. UI-only metadata is dropped so it doesn't poison LLM context;
|
||||
# the full spec stays available to the frontend and SDK via other paths.
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Keys that are UI-rendering concerns and must never reach the LLM view, at
|
||||
# either the node or property level.
|
||||
_UI_ONLY_KEYS = frozenset(
|
||||
{
|
||||
"display_name",
|
||||
"icon",
|
||||
"category",
|
||||
"version",
|
||||
"placeholder",
|
||||
"display_options",
|
||||
"editor",
|
||||
"extra",
|
||||
"label", # PropertyOption display string
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _walk_dicts(node):
|
||||
"""Yield every dict nested anywhere inside a projected structure."""
|
||||
if isinstance(node, dict):
|
||||
yield node
|
||||
for value in node.values():
|
||||
yield from _walk_dicts(value)
|
||||
elif isinstance(node, list):
|
||||
for item in node:
|
||||
yield from _walk_dicts(item)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_to_mcp_dict_drops_ui_only_keys(spec: NodeSpec):
|
||||
projected = spec.to_mcp_dict()
|
||||
for d in _walk_dicts(projected):
|
||||
leaked = _UI_ONLY_KEYS & d.keys()
|
||||
assert not leaked, f"{spec.name}: UI-only keys leaked into LLM view: {leaked}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_to_mcp_dict_omits_null_and_empty(spec: NodeSpec):
|
||||
"""The lean view never emits null values — absent means unset/optional,
|
||||
which is what halves the noise versus the full model dump."""
|
||||
for d in _walk_dicts(spec.to_mcp_dict()):
|
||||
for key, value in d.items():
|
||||
assert value is not None, f"{spec.name}: {key!r} emitted as null"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", all_specs(), ids=lambda s: s.name)
|
||||
def test_to_mcp_dict_keeps_property_essentials(spec: NodeSpec):
|
||||
"""Every property in the LLM view carries the minimum an LLM needs to
|
||||
author a value: machine name, type, and a description."""
|
||||
|
||||
def _check(props: list[dict]):
|
||||
for prop in props:
|
||||
assert prop.get("name"), f"{spec.name}: property missing name"
|
||||
assert prop.get("type"), f"{spec.name}.{prop.get('name')}: missing type"
|
||||
assert prop.get("description"), (
|
||||
f"{spec.name}.{prop.get('name')}: missing description"
|
||||
)
|
||||
if prop.get("properties"):
|
||||
_check(prop["properties"])
|
||||
|
||||
_check(spec.to_mcp_dict()["properties"])
|
||||
|
||||
|
||||
def test_to_mcp_dict_retains_authoring_signal_startcall():
|
||||
"""startCall is the richest core node — lock in that the projection
|
||||
keeps the fields an LLM actually authors against while shedding the rest."""
|
||||
spec = next(s for s in all_specs() if s.name == "startCall")
|
||||
projected = spec.to_mcp_dict()
|
||||
|
||||
assert set(projected) == {
|
||||
"name",
|
||||
"description",
|
||||
"llm_hint",
|
||||
"properties",
|
||||
"examples",
|
||||
"graph_constraints",
|
||||
}
|
||||
|
||||
props = {p["name"]: p for p in projected["properties"]}
|
||||
|
||||
# Required field keeps `required`; optional fields omit it.
|
||||
assert props["prompt"]["required"] is True
|
||||
assert "required" not in props["greeting"]
|
||||
|
||||
# Enum options project to bare values, dropping the UI label.
|
||||
assert props["greeting_type"]["options"] == [{"value": "text"}, {"value": "audio"}]
|
||||
|
||||
# Validation bounds survive (they constrain valid authored values).
|
||||
assert props["delayed_start_duration"]["min_value"] == 0.1
|
||||
assert props["delayed_start_duration"]["max_value"] == 10.0
|
||||
|
||||
# llm_hint survives where present (catalog-tool references).
|
||||
assert "list_recordings" in props["greeting_recording_id"]["llm_hint"]
|
||||
|
||||
# fixed_collection rows recurse through the same projection.
|
||||
var_rows = {p["name"]: p for p in props["extraction_variables"]["properties"]}
|
||||
assert var_rows["type"]["options"] == [
|
||||
{"value": "string"},
|
||||
{"value": "number"},
|
||||
{"value": "boolean"},
|
||||
]
|
||||
|
||||
# graph_constraints drops its null sub-fields.
|
||||
assert projected["graph_constraints"] == {"min_incoming": 0, "max_incoming": 0}
|
||||
|
|
|
|||
|
|
@ -45,12 +45,11 @@ from api.enums import ToolCategory
|
|||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
|
@ -1014,8 +1013,9 @@ class TestEndCallExtractionBehavior:
|
|||
# Create a workflow where start node has NO extraction
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -1026,8 +1026,9 @@ class TestEndCallExtractionBehavior:
|
|||
extraction_enabled=False, # No extraction
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
|
|||
|
|
@ -34,12 +34,11 @@ from api.services.pipecat.recording_audio_cache import RecordingAudio
|
|||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
EndCallRFNode,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
StartCallNodeData,
|
||||
StartCallRFNode,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
|
@ -65,8 +64,9 @@ def text_workflow() -> WorkflowGraph:
|
|||
"""Start->End workflow with text greeting and text transition speech."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -79,8 +79,9 @@ def text_workflow() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -114,8 +115,9 @@ def audio_workflow() -> WorkflowGraph:
|
|||
"""Start->End workflow with audio greeting and audio transition speech."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start Call",
|
||||
|
|
@ -128,8 +130,9 @@ def audio_workflow() -> WorkflowGraph:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End Call",
|
||||
|
|
@ -290,8 +293,9 @@ class TestStartGreeting:
|
|||
"""No greeting configured should return None."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
|
|
@ -301,8 +305,9 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
|
|
@ -333,8 +338,9 @@ class TestStartGreeting:
|
|||
"""Text greeting with {{variable}} placeholders should be rendered."""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
StartCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type="startCall",
|
||||
position=Position(x=0, y=0),
|
||||
data=StartCallNodeData(
|
||||
name="Start",
|
||||
|
|
@ -346,8 +352,9 @@ class TestStartGreeting:
|
|||
extraction_enabled=False,
|
||||
),
|
||||
),
|
||||
EndCallRFNode(
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type="endCall",
|
||||
position=Position(x=0, y=200),
|
||||
data=EndCallNodeData(
|
||||
name="End",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,25 @@ def _qa_node(node_id="qa-1", api_key="", **extra_data):
|
|||
return {"id": node_id, "type": "qa", "position": {"x": 0, "y": 0}, "data": data}
|
||||
|
||||
|
||||
def _tuner_node(node_id="tuner-1", api_key="", **extra_data):
|
||||
"""Helper to build a Tuner node."""
|
||||
data = {
|
||||
"name": "Tuner",
|
||||
"tuner_enabled": True,
|
||||
"tuner_agent_id": "sales-bot",
|
||||
"tuner_workspace_id": 7,
|
||||
**extra_data,
|
||||
}
|
||||
if api_key:
|
||||
data["tuner_api_key"] = api_key
|
||||
return {
|
||||
"id": node_id,
|
||||
"type": "tuner",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
def _agent_node(node_id="agent-1"):
|
||||
"""Helper to build a non-QA node."""
|
||||
return {
|
||||
|
|
@ -66,6 +85,19 @@ class TestMaskWorkflowDefinition:
|
|||
assert "qa_api_key" not in masked["nodes"][0]["data"]
|
||||
assert masked["nodes"][1]["data"]["qa_api_key"] == mask_key("sk-secret1234")
|
||||
|
||||
def test_masks_tuner_api_key(self):
|
||||
"""Tuner node api_key is masked, showing only last 4 chars."""
|
||||
real_key = "tuner_live_abcdefghijklmnop"
|
||||
wf = _make_workflow_def([_tuner_node(api_key=real_key)])
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
masked_key = masked["nodes"][0]["data"]["tuner_api_key"]
|
||||
assert masked_key == mask_key(real_key)
|
||||
assert masked_key.endswith("mnop")
|
||||
assert masked_key.startswith("*")
|
||||
assert real_key not in str(masked)
|
||||
|
||||
def test_qa_node_without_api_key(self):
|
||||
"""QA node with no api_key is left as-is."""
|
||||
wf = _make_workflow_def([_qa_node()])
|
||||
|
|
@ -154,6 +186,16 @@ class TestMergeWorkflowApiKeys:
|
|||
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == new_key
|
||||
|
||||
def test_masked_tuner_key_is_restored(self):
|
||||
"""Masked Tuner keys round-trip without losing the stored secret."""
|
||||
real_key = "tuner_live_abcdefghijklmnop"
|
||||
existing = _make_workflow_def([_tuner_node(api_key=real_key)])
|
||||
incoming = _make_workflow_def([_tuner_node(api_key=mask_key(real_key))])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["data"]["tuner_api_key"] == real_key
|
||||
|
||||
def test_no_incoming_api_key(self):
|
||||
"""QA node without api_key in incoming is left alone."""
|
||||
existing = _make_workflow_def([_qa_node(api_key="sk-existing-key1")])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue