mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat(mcp): generic MCP tool source with per-node function filtering (#301)
* feat(mcp): generic MCP tool source with per-node function filtering
Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.
- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles
* feat: refactor for code standardisation and documentation
---------
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
0097974444
commit
75839f9de5
40 changed files with 3028 additions and 137 deletions
|
|
@ -20,7 +20,7 @@ RUN pip install --user --no-cache-dir -r requirements.txt && \
|
|||
|
||||
# Copy and install pipecat from local submodule
|
||||
COPY pipecat /tmp/pipecat
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' && \
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp]' && \
|
||||
# Swap opencv-python (pulled by pipecat[webrtc]) for opencv-python-headless
|
||||
# to drop X11/Qt dependencies that otherwise require libxcb etc. in runner.
|
||||
pip uninstall -y opencv-python && \
|
||||
|
|
|
|||
64
api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py
Normal file
64
api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""add mcp in ToolCategory
|
||||
|
||||
Revision ID: 0a1b2c3d4e5f
|
||||
Revises: 4c1f1e3e8ef2
|
||||
Create Date: 2026-05-16 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0a1b2c3d4e5f"
|
||||
down_revision: Union[str, None] = "4c1f1e3e8ef2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
"mcp",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
|
|
@ -133,6 +133,7 @@ class ToolCategory(Enum):
|
|||
CALCULATOR = "calculator" # Built-in calculator tool
|
||||
NATIVE = "native" # Built-in integrations (future: dtmf_input)
|
||||
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
|
||||
MCP = "mcp" # Customer-provided MCP server exposing a tool catalog
|
||||
|
||||
|
||||
class ToolStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ async def authenticate_mcp_request() -> UserModel:
|
|||
the `langfuse.user.id` / `langfuse.session.id` attributes make the
|
||||
span filterable in the Langfuse UI.
|
||||
"""
|
||||
headers = get_http_headers()
|
||||
# FastMCP strips Authorization by default unless explicitly included.
|
||||
# Preserve it here so Bearer API keys work for MCP tool invocations.
|
||||
headers = get_http_headers(include={"authorization"})
|
||||
api_key = headers.get("x-api-key")
|
||||
if not api_key:
|
||||
auth = headers.get("authorization", "")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""API routes for managing tools."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.db import db_client
|
||||
|
|
@ -13,9 +15,23 @@ from api.enums import PostHogEvent, ToolCategory, ToolStatus
|
|||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolConfig as SharedMcpToolConfig,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolDefinition as SharedMcpToolDefinition,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tools")
|
||||
|
||||
McpToolConfig = SharedMcpToolConfig
|
||||
McpToolDefinition = SharedMcpToolDefinition
|
||||
|
||||
|
||||
# Request/Response schemas
|
||||
class ToolParameter(BaseModel):
|
||||
|
|
@ -183,6 +199,7 @@ ToolDefinition = Annotated[
|
|||
EndCallToolDefinition,
|
||||
TransferCallToolDefinition,
|
||||
CalculatorToolDefinition,
|
||||
McpToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -248,6 +265,14 @@ class ToolResponse(BaseModel):
|
|||
from_attributes = True
|
||||
|
||||
|
||||
class McpRefreshResponse(BaseModel):
|
||||
"""Result of re-discovering an MCP server's tool catalog."""
|
||||
|
||||
tool_uuid: str
|
||||
discovered_tools: list = Field(default_factory=list)
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
|
||||
"""Build a response from a tool model."""
|
||||
created_by = None
|
||||
|
|
@ -336,6 +361,52 @@ async def list_tools(
|
|||
return [build_tool_response(tool) for tool in tools]
|
||||
|
||||
|
||||
async def _fetch_credential(credential_uuid: Optional[str], organization_id: int):
|
||||
"""Best-effort credential lookup for MCP auth. A missing/failed credential
|
||||
degrades to ``None`` (unauthenticated) rather than failing the request."""
|
||||
if not credential_uuid:
|
||||
return None
|
||||
try:
|
||||
return await db_client.get_credential_by_uuid(credential_uuid, organization_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP: credential fetch failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _populate_discovered_tools(definition: dict, *, organization_id: int) -> dict:
|
||||
"""Best-effort: for an MCP definition, connect to the server, list its
|
||||
tools, and overwrite ``config.discovered_tools``. Never raises and never
|
||||
blocks tool save — a dead server yields ``discovered_tools: []``. Non-MCP
|
||||
definitions pass through untouched."""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
return definition
|
||||
try:
|
||||
cfg = validate_mcp_definition(definition)
|
||||
except McpDefinitionError:
|
||||
return definition
|
||||
|
||||
credential = await _fetch_credential(cfg.get("credential_uuid"), organization_id)
|
||||
|
||||
# Run discovery in an isolated asyncio task so an anyio cancel-scope
|
||||
# CancelledError doesn't bleed into the parent task and corrupt the
|
||||
# subsequent DB write. _run() never raises (degrades to []).
|
||||
async def _run() -> list:
|
||||
try:
|
||||
return await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except BaseException as e: # noqa: BLE001
|
||||
logger.warning(f"MCP discovery failed; caching empty list: {e}")
|
||||
return []
|
||||
|
||||
discovered = await asyncio.ensure_future(_run())
|
||||
definition["config"]["discovered_tools"] = discovered
|
||||
return definition
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_tool(
|
||||
request: CreateToolRequest,
|
||||
|
|
@ -357,11 +428,16 @@ async def create_tool(
|
|||
|
||||
validate_category(request.category)
|
||||
|
||||
definition = await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=request.definition.model_dump(),
|
||||
definition=definition,
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
|
|
@ -410,6 +486,67 @@ async def get_tool(
|
|||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
|
||||
@router.post("/{tool_uuid}/mcp/refresh")
|
||||
async def refresh_mcp_tools(
|
||||
tool_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> McpRefreshResponse:
|
||||
"""Re-discover an MCP tool's server catalog and overwrite the cached
|
||||
``definition.config.discovered_tools``. Server down → 200 with error
|
||||
(cache not overwritten on transient failure)."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
tool = await db_client.get_tool_by_uuid(
|
||||
tool_uuid, user.selected_organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
raise HTTPException(status_code=400, detail="Tool is not an MCP tool")
|
||||
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid MCP definition: {e}")
|
||||
|
||||
credential = await _fetch_credential(
|
||||
cfg.get("credential_uuid"), user.selected_organization_id
|
||||
)
|
||||
|
||||
try:
|
||||
discovered = await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP refresh discovery failed: {e}")
|
||||
discovered = []
|
||||
|
||||
if not discovered:
|
||||
error = (
|
||||
f"Could not reach the MCP server at {cfg['url']} "
|
||||
f"(or it exposes no tools). Previously cached list retained."
|
||||
)
|
||||
# Do NOT clobber a previously-good cache with [] on a transient outage.
|
||||
return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)
|
||||
|
||||
new_def = dict(tool.definition or {})
|
||||
new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered}
|
||||
await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
definition=new_def,
|
||||
)
|
||||
return McpRefreshResponse(
|
||||
tool_uuid=tool_uuid, discovered_tools=discovered, error=None
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{tool_uuid}")
|
||||
async def update_tool(
|
||||
tool_uuid: str,
|
||||
|
|
@ -434,12 +571,21 @@ async def update_tool(
|
|||
if request.status:
|
||||
validate_status(request.status)
|
||||
|
||||
definition = (
|
||||
await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
if request.definition
|
||||
else None
|
||||
)
|
||||
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
definition=request.definition.model_dump() if request.definition else None,
|
||||
definition=definition,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
status=request.status,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
from typing import Annotated, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||
|
||||
|
|
@ -69,6 +69,7 @@ class _ExtractionNodeDataMixin(BaseModel):
|
|||
class _ToolDocumentRefsMixin(BaseModel):
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
mcp_tool_filters: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
class StartCallNodeData(
|
||||
|
|
|
|||
254
api/services/workflow/mcp_tool_session.py
Normal file
254
api/services/workflow/mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Single unit that knows the MCP protocol + credentials.
|
||||
|
||||
Wraps the vendored Pipecat ``MCPClient`` for connection/session, builds
|
||||
streamable-HTTP params from a Dograh credential, exposes namespaced
|
||||
``FunctionSchema``s, and proxies tool calls. Connection failures degrade
|
||||
(``available = False``) instead of raising — the call must survive a
|
||||
dead MCP server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.services.mcp_service import MCPClient
|
||||
|
||||
from api.services.workflow.tools.mcp_tool import namespace_function_name
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.db.models import ExternalCredentialModel
|
||||
|
||||
|
||||
def build_streamable_http_params(
|
||||
*,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> StreamableHttpParameters:
|
||||
"""Build Pipecat/MCP streamable-HTTP params, injecting the auth header
|
||||
from an ExternalCredentialModel (reuses the http_api credential path)."""
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
if credential is not None:
|
||||
auth = build_auth_header(credential)
|
||||
headers = auth or None
|
||||
return StreamableHttpParameters(
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=timedelta(seconds=timeout_secs),
|
||||
sse_read_timeout=timedelta(seconds=sse_read_timeout_secs),
|
||||
)
|
||||
|
||||
|
||||
class McpToolSession:
|
||||
"""One live MCP server connection for the duration of a call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_uuid: str,
|
||||
tool_name: str,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
tools_filter: List[str],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> None:
|
||||
self._tool_uuid = tool_uuid
|
||||
self._tool_name = tool_name
|
||||
self._url = url
|
||||
self._credential = credential
|
||||
# An empty list is intentionally treated as "no filter (expose all
|
||||
# tools)" — Pipecat's MCPClient applies a filter only when this is a
|
||||
# non-empty list, so [] and None are equivalent ("all tools").
|
||||
self._tools_filter = tools_filter or None
|
||||
self._timeout_secs = timeout_secs
|
||||
self._sse_read_timeout_secs = sse_read_timeout_secs
|
||||
|
||||
self._client: Optional[MCPClient] = None
|
||||
self._session: Any = None # mcp.ClientSession (read once after start)
|
||||
self._schemas: List[FunctionSchema] = []
|
||||
# namespaced LLM name -> original MCP tool name
|
||||
self._name_map: Dict[str, str] = {}
|
||||
self.available: bool = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect, initialize, and cache the tool list. Never raises —
|
||||
on any failure the session is marked unavailable."""
|
||||
try:
|
||||
params = build_streamable_http_params(
|
||||
url=self._url,
|
||||
credential=self._credential,
|
||||
timeout_secs=self._timeout_secs,
|
||||
sse_read_timeout_secs=self._sse_read_timeout_secs,
|
||||
)
|
||||
self._client = MCPClient(params, tools_filter=self._tools_filter)
|
||||
await self._client.start()
|
||||
# Single, isolated touch of Pipecat internals (vendored submodule).
|
||||
self._session = self._client._active_session
|
||||
tools_schema = await self._client.get_tools_schema()
|
||||
|
||||
fallback = self._tool_uuid[:8] if self._tool_uuid else "server"
|
||||
for fs in tools_schema.standard_tools:
|
||||
ns_name = namespace_function_name(
|
||||
self._tool_name, fs.name, fallback=fallback
|
||||
)
|
||||
self._name_map[ns_name] = fs.name
|
||||
self._schemas.append(
|
||||
FunctionSchema(
|
||||
name=ns_name,
|
||||
description=fs.description,
|
||||
properties=fs.properties,
|
||||
required=fs.required,
|
||||
)
|
||||
)
|
||||
self.available = True
|
||||
logger.info(
|
||||
f"MCP session ready for tool '{self._tool_name}' "
|
||||
f"({self._tool_uuid}): {sorted(self._name_map)}"
|
||||
)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except asyncio.CancelledError as e:
|
||||
# Empirically, a dead/unreachable MCP server does NOT surface as a
|
||||
# plain Exception here. The real failure is httpx.ConnectError, but
|
||||
# anyio's streamablehttp_client task group, while tearing down that
|
||||
# ConnectError, re-surfaces it to our frame as an *internal*
|
||||
# cancel-scope CancelledError carrying the signature message
|
||||
# "Cancelled via cancel scope <id>". A genuine *external*
|
||||
# cancellation (call teardown / shutdown) is a bare CancelledError
|
||||
# (empty args) or one with an application-chosen message. Type, MRO,
|
||||
# context chain, and asyncio task.cancelling() are all identical
|
||||
# between the two, so the anyio scope-signature message is the only
|
||||
# reliable discriminator. Re-raise genuine external cancellation to
|
||||
# preserve structured concurrency; degrade only on the anyio
|
||||
# connect-teardown artifact.
|
||||
msg = "" if not e.args else str(e.args[0] or "")
|
||||
if not msg.startswith("Cancelled via cancel scope"):
|
||||
raise
|
||||
await self._degrade(e)
|
||||
except Exception as e: # noqa: BLE001 — see _degrade docstring
|
||||
# Defensive: if a future Pipecat/httpx version surfaces the connect
|
||||
# failure directly (e.g. httpx.ConnectError) instead of via the
|
||||
# anyio cancel-scope artifact above, still degrade gracefully.
|
||||
await self._degrade(e)
|
||||
|
||||
async def _degrade(self, e: BaseException) -> None:
|
||||
"""Mark this session unavailable and tear down any dangling client so
|
||||
start() leaves self._client either fully usable or None. The contract
|
||||
requires graceful degradation on any *connect* failure (never raising
|
||||
for a dead MCP server) while genuine external cancellation /
|
||||
KeyboardInterrupt / SystemExit are re-raised by the caller."""
|
||||
self.available = False
|
||||
self._schemas = []
|
||||
self._name_map = {}
|
||||
# Self-contained cleanup: _client.start() may have succeeded before a
|
||||
# later step (e.g. get_tools_schema()) failed, leaving an open client.
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._client = None
|
||||
self._session = None
|
||||
logger.warning(
|
||||
f"MCP session unavailable for tool '{self._tool_name}' "
|
||||
f"({self._tool_uuid}) at {self._url}: {e!r}. "
|
||||
f"Call proceeds without these tools."
|
||||
)
|
||||
|
||||
@property
|
||||
def call_timeout_secs(self) -> float:
|
||||
"""Pipecat function-call timeout for this server's tools. Slightly
|
||||
longer than the transport read timeout so a slow MCP call surfaces
|
||||
as a structured tool error (handled in the handler) rather than a
|
||||
hard pipeline timeout."""
|
||||
return float(self._sse_read_timeout_secs) + 5.0
|
||||
|
||||
def function_schemas(
|
||||
self, allowed_raw_names: Optional[Set[str]] = None
|
||||
) -> List[FunctionSchema]:
|
||||
"""Return cached FunctionSchemas, optionally filtered by raw MCP tool name.
|
||||
|
||||
``allowed_raw_names=None`` returns all schemas. An empty set returns none.
|
||||
Raw names are the pre-namespace MCP tool names (e.g. ``echo``, not
|
||||
``mcp__slug__echo``).
|
||||
"""
|
||||
if allowed_raw_names is None:
|
||||
return list(self._schemas)
|
||||
return [
|
||||
s for s in self._schemas if self._name_map.get(s.name) in allowed_raw_names
|
||||
]
|
||||
|
||||
def discovered_tools(self) -> List[Dict[str, str]]:
|
||||
"""Raw MCP tool catalog for UI/cache: ``[{name, description}]``
|
||||
using the *raw* server names (not the namespaced LLM names).
|
||||
Empty if the session is unavailable."""
|
||||
out: List[Dict[str, str]] = []
|
||||
for s in self._schemas:
|
||||
raw = self._name_map.get(s.name)
|
||||
if raw is None:
|
||||
continue
|
||||
out.append({"name": raw, "description": s.description or ""})
|
||||
return out
|
||||
|
||||
async def call(self, namespaced_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""Invoke an MCP tool by its namespaced LLM name. Returns a string
|
||||
(flattened text content). Raises if the session is unavailable so
|
||||
the caller can map it to a structured error for the LLM."""
|
||||
if not self.available or self._session is None:
|
||||
raise RuntimeError(f"MCP session unavailable for {namespaced_name}")
|
||||
original = self._name_map.get(namespaced_name)
|
||||
if original is None:
|
||||
raise RuntimeError(f"Unknown MCP function {namespaced_name}")
|
||||
result = await self._session.call_tool(original, arguments=arguments)
|
||||
text = ""
|
||||
for content in getattr(result, "content", []) or []:
|
||||
if getattr(content, "text", None):
|
||||
text += content.text
|
||||
return text or "Sorry, the MCP tool returned no content."
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session {self._tool_uuid}: {e}")
|
||||
finally:
|
||||
self._client = None
|
||||
self._session = None
|
||||
|
||||
|
||||
async def discover_mcp_tools(
|
||||
*,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Open an ephemeral MCP session, list its tools, close it. Returns
|
||||
``[{name, description}]`` (raw names). Never raises — on any connect
|
||||
failure returns ``[]``."""
|
||||
session = McpToolSession(
|
||||
tool_uuid="discover",
|
||||
tool_name="discover",
|
||||
url=url,
|
||||
credential=credential,
|
||||
tools_filter=[],
|
||||
timeout_secs=timeout_secs,
|
||||
sse_read_timeout_secs=sse_read_timeout_secs,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
if not session.available:
|
||||
return []
|
||||
return session.discovered_tools()
|
||||
finally:
|
||||
await session.close()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
|
|
@ -16,6 +16,7 @@ from pipecat.services.settings import LLMSettings
|
|||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import ToolCategory
|
||||
from api.services.pipecat.audio_playback import play_audio
|
||||
from api.services.workflow.disposition_mapper import apply_disposition_mapping
|
||||
from api.services.workflow.workflow_graph import Node, WorkflowGraph
|
||||
|
|
@ -34,6 +35,7 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
compose_functions_for_node,
|
||||
compose_system_prompt_for_node,
|
||||
|
|
@ -116,6 +118,9 @@ class PipecatEngine:
|
|||
# Cached organization ID (resolved lazily from workflow run)
|
||||
self._organization_id: Optional[int] = None
|
||||
|
||||
# Open MCP tool sessions for this call, keyed by tool_uuid
|
||||
self._mcp_sessions: Dict[str, McpToolSession] = {}
|
||||
|
||||
# Embeddings configuration (passed from run_pipeline.py)
|
||||
self._embeddings_api_key: Optional[str] = embeddings_api_key
|
||||
self._embeddings_model: Optional[str] = embeddings_model
|
||||
|
|
@ -178,6 +183,9 @@ class PipecatEngine:
|
|||
# Helper that encapsulates custom tool management
|
||||
self._custom_tool_manager = CustomToolManager(self)
|
||||
|
||||
# Open persistent MCP server sessions for this call (degrades on failure)
|
||||
await self._open_mcp_sessions()
|
||||
|
||||
# Helper that encapsulates context summarization
|
||||
if self._context_compaction_enabled:
|
||||
self._context_summarization_manager = ContextSummarizationManager(self)
|
||||
|
|
@ -503,7 +511,10 @@ class PipecatEngine:
|
|||
|
||||
# Register custom tool handlers for this node
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
await self._custom_tool_manager.register_handlers(
|
||||
node.tool_uuids,
|
||||
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
|
||||
)
|
||||
|
||||
# Register knowledge base retrieval handler if node has documents
|
||||
if node.document_uuids:
|
||||
|
|
@ -814,6 +825,79 @@ class PipecatEngine:
|
|||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
||||
async def _open_mcp_sessions(self) -> None:
|
||||
"""Connect every MCP-category tool referenced by any workflow node.
|
||||
Failures degrade (session marked unavailable); never raises."""
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
try:
|
||||
tool_uuids: set[str] = set()
|
||||
for node in self.workflow.nodes.values():
|
||||
for tu in getattr(node, "tool_uuids", None) or []:
|
||||
tool_uuids.add(tu)
|
||||
if not tool_uuids:
|
||||
return
|
||||
|
||||
organization_id = await self._get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning("Cannot open MCP sessions: organization_id missing")
|
||||
return
|
||||
|
||||
tools = await db_client.get_tools_by_uuids(
|
||||
list(tool_uuids), organization_id
|
||||
)
|
||||
for tool in tools:
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
continue
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
logger.warning(
|
||||
f"Skipping MCP tool '{tool.name}' ({tool.tool_uuid}): "
|
||||
f"invalid definition: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
credential = None
|
||||
if cfg["credential_uuid"]:
|
||||
try:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
cfg["credential_uuid"], organization_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}': credential fetch failed: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
tools_filter=cfg["tools_filter"],
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
await session.start()
|
||||
self._mcp_sessions[tool.tool_uuid] = session
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to open MCP sessions; call proceeds without MCP tools: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _close_mcp_sessions(self) -> None:
|
||||
for tool_uuid, session in list(self._mcp_sessions.items()):
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session {tool_uuid}: {e}")
|
||||
self._mcp_sessions = {}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up engine resources on disconnect."""
|
||||
# Cancel any pending timeout tasks
|
||||
|
|
@ -823,6 +907,12 @@ class PipecatEngine:
|
|||
):
|
||||
self._user_response_timeout_task.cancel()
|
||||
|
||||
# Cancel any in-flight background summarization
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
# Cancel any in-flight background summarization.
|
||||
# MCP sessions are closed in a finally block so they are guaranteed to
|
||||
# run even if the summarization cleanup raises an exception.
|
||||
try:
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
finally:
|
||||
# Close any open MCP tool sessions
|
||||
await self._close_mcp_sessions()
|
||||
|
|
|
|||
|
|
@ -117,7 +117,8 @@ async def compose_functions_for_node(
|
|||
# Custom tools
|
||||
if node.tool_uuids and custom_tool_manager:
|
||||
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
node.tool_uuids,
|
||||
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from api.services.workflow.tools.custom_tool import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
|
|
@ -121,11 +122,18 @@ class CustomToolManager:
|
|||
"""Get the organization ID from the engine (shared cache)."""
|
||||
return await self._engine._get_organization_id()
|
||||
|
||||
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
|
||||
async def get_tool_schemas(
|
||||
self,
|
||||
tool_uuids: list[str],
|
||||
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
|
||||
) -> list[FunctionSchema]:
|
||||
"""Fetch custom tools and convert them to function schemas.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of
|
||||
raw MCP tool names to expose. None (default) exposes all tools.
|
||||
Empty dict or entry with [] suppresses all tools for that uuid.
|
||||
|
||||
Returns:
|
||||
List of FunctionSchema objects for LLM
|
||||
|
|
@ -154,6 +162,22 @@ class CustomToolManager:
|
|||
)
|
||||
continue
|
||||
|
||||
if tool.category == ToolCategory.MCP.value:
|
||||
session = self._engine._mcp_sessions.get(tool.tool_uuid)
|
||||
if session is None or not session.available:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
|
||||
f"unavailable; skipping"
|
||||
)
|
||||
continue
|
||||
allowed = (
|
||||
None
|
||||
if mcp_tool_filters is None
|
||||
else set(mcp_tool_filters.get(tool.tool_uuid, []))
|
||||
)
|
||||
schemas.extend(session.function_schemas(allowed))
|
||||
continue
|
||||
|
||||
raw_schema = tool_to_function_schema(tool)
|
||||
function_name = raw_schema["function"]["name"]
|
||||
|
||||
|
|
@ -178,11 +202,18 @@ class CustomToolManager:
|
|||
logger.error(f"Failed to fetch custom tools: {e}")
|
||||
return []
|
||||
|
||||
async def register_handlers(self, tool_uuids: list[str]) -> None:
|
||||
async def register_handlers(
|
||||
self,
|
||||
tool_uuids: list[str],
|
||||
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
|
||||
) -> None:
|
||||
"""Register custom tool execution handlers with the LLM.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to register handlers for
|
||||
mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of
|
||||
raw MCP tool names to expose. None (default) exposes all tools.
|
||||
Empty dict or entry with [] suppresses all tools for that uuid.
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
|
|
@ -203,6 +234,32 @@ class CustomToolManager:
|
|||
)
|
||||
continue
|
||||
|
||||
if tool.category == ToolCategory.MCP.value:
|
||||
session = self._engine._mcp_sessions.get(tool.tool_uuid)
|
||||
if session is None or not session.available:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
|
||||
f"unavailable; skipping handler registration"
|
||||
)
|
||||
continue
|
||||
allowed = (
|
||||
None
|
||||
if mcp_tool_filters is None
|
||||
else set(mcp_tool_filters.get(tool.tool_uuid, []))
|
||||
)
|
||||
mcp_schemas = session.function_schemas(allowed)
|
||||
for fs in mcp_schemas:
|
||||
self._engine.llm.register_function(
|
||||
fs.name,
|
||||
self._create_mcp_handler(session, fs.name),
|
||||
timeout_secs=session.call_timeout_secs,
|
||||
)
|
||||
logger.debug(
|
||||
f"Registered {len(mcp_schemas)} MCP "
|
||||
f"handlers for tool '{tool.name}' ({tool.tool_uuid})"
|
||||
)
|
||||
continue
|
||||
|
||||
schema = tool_to_function_schema(tool)
|
||||
function_name = schema["function"]["name"]
|
||||
|
||||
|
|
@ -335,6 +392,29 @@ class CustomToolManager:
|
|||
|
||||
return http_tool_handler
|
||||
|
||||
def _create_mcp_handler(self, session: "McpToolSession", function_name: str):
|
||||
"""Create a handler that proxies an LLM function call to a live MCP
|
||||
session. Errors are returned to the LLM as structured text so the
|
||||
agent can recover verbally; the call is never crashed."""
|
||||
|
||||
async def mcp_tool_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"MCP Tool EXECUTED: {function_name}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
result = await session.call(
|
||||
function_name, function_call_params.arguments or {}
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool '{function_name}' failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)}
|
||||
)
|
||||
|
||||
return mcp_tool_handler
|
||||
|
||||
def _create_end_call_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for an end call tool.
|
||||
|
||||
|
|
|
|||
116
api/services/workflow/tools/mcp_tool.py
Normal file
116
api/services/workflow/tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Pure helpers for MCP-category tools: definition validation and
|
||||
LLM-function-name namespacing. No I/O, no MCP protocol here."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
DEFAULT_TIMEOUT_SECS = 30
|
||||
DEFAULT_SSE_READ_TIMEOUT_SECS = 300
|
||||
|
||||
|
||||
class McpDefinitionError(ValueError):
|
||||
"""Raised when an MCP tool definition is structurally invalid."""
|
||||
|
||||
|
||||
class McpToolConfig(BaseModel):
|
||||
"""Configuration for an MCP tool definition."""
|
||||
|
||||
transport: Literal["streamable_http"] = Field(
|
||||
default="streamable_http", description="MCP transport protocol"
|
||||
)
|
||||
url: str = Field(description="MCP server URL (must be http:// or https://)")
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None, description="Reference to ExternalCredentialModel for auth"
|
||||
)
|
||||
tools_filter: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Allowlist of MCP tool names to expose (empty = all tools)",
|
||||
)
|
||||
timeout_secs: int = Field(
|
||||
default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds"
|
||||
)
|
||||
sse_read_timeout_secs: int = Field(
|
||||
default=DEFAULT_SSE_READ_TIMEOUT_SECS,
|
||||
description="SSE read timeout in seconds",
|
||||
)
|
||||
discovered_tools: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Server-managed cache of the MCP server's tool catalog "
|
||||
"[{name, description}]. Populated best-effort by the backend."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
|
||||
raise ValueError("config.url must be an http(s) URL")
|
||||
return v
|
||||
|
||||
@field_validator("tools_filter")
|
||||
@classmethod
|
||||
def validate_tools_filter(cls, v: list[str]) -> list[str]:
|
||||
if not all(isinstance(tool_name, str) for tool_name in v):
|
||||
raise ValueError("config.tools_filter must be a list of strings")
|
||||
return v
|
||||
|
||||
|
||||
class McpToolDefinition(BaseModel):
|
||||
"""Persisted MCP tool definition."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["mcp"] = Field(description="Tool type")
|
||||
config: McpToolConfig = Field(description="MCP server configuration")
|
||||
|
||||
|
||||
def _format_validation_error(error: ValidationError) -> str:
|
||||
parts: list[str] = []
|
||||
for item in error.errors():
|
||||
location = ".".join(str(part) for part in item["loc"])
|
||||
parts.append(f"{location}: {item['msg']}")
|
||||
return "; ".join(parts)
|
||||
|
||||
|
||||
def validate_mcp_definition(definition: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate a ``type: "mcp"`` ToolModel definition and return a
|
||||
normalized config dict with defaults applied.
|
||||
|
||||
Raises:
|
||||
McpDefinitionError: if the definition is missing required fields
|
||||
or uses an unsupported transport.
|
||||
"""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
raise McpDefinitionError("definition.type must be 'mcp'")
|
||||
|
||||
config = definition.get("config")
|
||||
if not isinstance(config, dict):
|
||||
raise McpDefinitionError("definition.config is required and must be an object")
|
||||
|
||||
try:
|
||||
parsed = McpToolDefinition.model_validate(definition)
|
||||
except ValidationError as e:
|
||||
raise McpDefinitionError(_format_validation_error(e)) from e
|
||||
|
||||
return parsed.config.model_dump(exclude={"discovered_tools"})
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", value.strip().lower()).strip("_")
|
||||
return slug
|
||||
|
||||
|
||||
def namespace_function_name(
|
||||
tool_name: str, mcp_tool_name: str, *, fallback: str = "server"
|
||||
) -> str:
|
||||
"""Build a collision-safe LLM function name: ``mcp__<slug>__<tool>``.
|
||||
|
||||
``slug`` is derived from the Dograh ToolModel name; if it slugifies to
|
||||
empty, ``fallback`` (e.g. first 8 chars of tool_uuid) is used instead.
|
||||
"""
|
||||
slug = _slugify(tool_name) or _slugify(fallback) or "server"
|
||||
return f"mcp__{slug}__{mcp_tool_name}"
|
||||
|
|
@ -89,6 +89,7 @@ class Node:
|
|||
self.delayed_start_duration = getattr(data, "delayed_start_duration", None)
|
||||
self.tool_uuids = getattr(data, "tool_uuids", None)
|
||||
self.document_uuids = getattr(data, "document_uuids", None)
|
||||
self.mcp_tool_filters = getattr(data, "mcp_tool_filters", None)
|
||||
self.pre_call_fetch_enabled = getattr(data, "pre_call_fetch_enabled", False)
|
||||
self.pre_call_fetch_url = getattr(data, "pre_call_fetch_url", None)
|
||||
self.pre_call_fetch_credential_uuid = getattr(
|
||||
|
|
|
|||
0
api/tests/support/__init__.py
Normal file
0
api/tests/support/__init__.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""A real FastMCP server exposing 2 tools over streamable-HTTP, run in a
|
||||
background uvicorn thread on an ephemeral port. Used to exercise the real
|
||||
MCP protocol path in tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastmcp import FastMCP
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
def _build_app(required_headers: dict[str, str] | None = None):
|
||||
mcp = FastMCP("mock-mcp")
|
||||
|
||||
@mcp.tool()
|
||||
def echo(text: str) -> str:
|
||||
"""Echo the provided text back."""
|
||||
return f"echo:{text}"
|
||||
|
||||
@mcp.tool()
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two integers."""
|
||||
return a + b
|
||||
|
||||
# FastMCP 3.x: ASGI app for streamable-HTTP transport at "/mcp".
|
||||
app = mcp.http_app()
|
||||
if not required_headers:
|
||||
return app
|
||||
|
||||
normalized = {k.lower(): v for k, v in required_headers.items()}
|
||||
|
||||
async def guarded_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = {
|
||||
key.decode("latin-1").lower(): value.decode("latin-1")
|
||||
for key, value in scope.get("headers", [])
|
||||
}
|
||||
for header_name, expected_value in normalized.items():
|
||||
if headers.get(header_name) != expected_value:
|
||||
response = JSONResponse(
|
||||
{"detail": f"Missing or invalid header: {header_name}"},
|
||||
status_code=401,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
await app(scope, receive, send)
|
||||
|
||||
return guarded_app
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def running_mcp_server(
|
||||
*, required_headers: dict[str, str] | None = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Yield the base streamable-HTTP URL of a live mock MCP server."""
|
||||
port = _free_port()
|
||||
config = uvicorn.Config(
|
||||
_build_app(required_headers), host="127.0.0.1", port=port, log_level="warning"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}/mcp"
|
||||
server_ready = False
|
||||
for _ in range(50):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.get(base_url, timeout=0.5)
|
||||
server_ready = True
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(0.1)
|
||||
if not server_ready:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
raise RuntimeError(f"Mock MCP server at {base_url} failed to start within 5s")
|
||||
try:
|
||||
yield base_url
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
if thread.is_alive():
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Mock MCP server thread did not terminate within 5s",
|
||||
ResourceWarning,
|
||||
)
|
||||
63
api/tests/test_mcp_auth.py
Normal file
63
api/tests/test_mcp_auth.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_bearer_authorization():
|
||||
user = MagicMock()
|
||||
user.id = 1
|
||||
user.selected_organization_id = 90
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"authorization": "Bearer secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_x_api_key():
|
||||
user = MagicMock()
|
||||
user.id = 2
|
||||
user.selected_organization_id = 91
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"x-api-key": "secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_rejects_missing_api_key():
|
||||
with patch("api.mcp_server.auth.get_http_headers", return_value={}) as get_headers:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await authenticate_mcp_request()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing API key" in str(exc_info.value.detail)
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool():
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {"type": "mcp", "config": {"url": "https://x/mcp"}}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
reg_kwargs = {}
|
||||
|
||||
def _reg(name, fn, **kw):
|
||||
registered[name] = fn
|
||||
reg_kwargs[name] = kw
|
||||
|
||||
engine.llm.register_function = _reg
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
|
||||
try:
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert "mcp__acme_mcp__echo" in registered
|
||||
assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx(
|
||||
15.0
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
class P:
|
||||
function_name = "mcp__acme_mcp__echo"
|
||||
arguments = {"text": "yo"}
|
||||
|
||||
async def result_callback(self, r, *, properties=None):
|
||||
captured["r"] = r
|
||||
|
||||
await registered["mcp__acme_mcp__echo"](P())
|
||||
assert "echo:yo" in str(captured["r"])
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unavailable_mcp_session_contributes_nothing(monkeypatch):
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
await session.start() # degrades
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert schemas == []
|
||||
await mgr.register_handlers([tool.tool_uuid]) # must not raise
|
||||
|
||||
|
||||
def test_call_timeout_secs_is_read_timeout_plus_buffer():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-abc123",
|
||||
tool_name="Acme MCP",
|
||||
url="https://x/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert session.call_timeout_secs == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_node_mcp_filter_intersection(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
engine.llm.register_function = lambda name, fn, **kw: registered.__setitem__(
|
||||
name, fn
|
||||
)
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
try:
|
||||
# Allow only raw "echo" for this node
|
||||
filters = {tool.tool_uuid: ["echo"]}
|
||||
schemas = await mgr.get_tool_schemas(
|
||||
[tool.tool_uuid], mcp_tool_filters=filters
|
||||
)
|
||||
# Check only "echo" schema returned (namespaced name depends on tool.name)
|
||||
assert len(schemas) == 1
|
||||
assert all("echo" in s.name for s in schemas)
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters)
|
||||
assert len(registered) == 1
|
||||
assert all("echo" in k for k in registered)
|
||||
|
||||
# No filter entry for this uuid = none (default-none)
|
||||
registered.clear()
|
||||
result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert result == []
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert registered == {}
|
||||
|
||||
# mcp_tool_filters=None = backward-compatible (all tools)
|
||||
registered.clear()
|
||||
all_schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert len(all_schemas) == 2 # both echo and add
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert len(registered) == 2 # both handlers registered
|
||||
finally:
|
||||
await session.close()
|
||||
107
api/tests/test_mcp_integration.py
Normal file
107
api/tests/test_mcp_integration.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool(url: str):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": url},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool(base_url)
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_credential_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
await engine._open_mcp_sessions()
|
||||
try:
|
||||
assert tool.tool_uuid in engine._mcp_sessions
|
||||
sess = engine._mcp_sessions[tool.tool_uuid]
|
||||
assert sess.available is True
|
||||
assert len(sess.function_schemas()) == 2
|
||||
finally:
|
||||
await engine._close_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_swallows_db_error(monkeypatch):
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = ["uuid-deadbeef"]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_tools_by_uuids",
|
||||
AsyncMock(side_effect=RuntimeError("db down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch):
|
||||
tool = _mcp_tool("http://127.0.0.1:1/mcp")
|
||||
tool.definition["config"]["credential_uuid"] = "cred-1234"
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_credential_by_uuid",
|
||||
AsyncMock(side_effect=RuntimeError("cred store down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise, and must skip the tool (no futile unauthenticated start)
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
112
api/tests/test_mcp_tool_definition.py
Normal file
112
api/tests/test_mcp_tool_definition.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.routes.tool import McpToolConfig as RouteMcpToolConfig
|
||||
from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
namespace_function_name,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_category_exists():
|
||||
assert ToolCategory.MCP.value == "mcp"
|
||||
assert ToolCategory("mcp") is ToolCategory.MCP
|
||||
|
||||
|
||||
def test_mcp_migration_present_and_chained(monkeypatch):
|
||||
mod = importlib.import_module(
|
||||
"api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory"
|
||||
)
|
||||
assert mod.revision == "0a1b2c3d4e5f"
|
||||
assert mod.down_revision == "4c1f1e3e8ef2"
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_sync_enum_values(**kwargs):
|
||||
calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(mod.op, "sync_enum_values", fake_sync_enum_values)
|
||||
|
||||
mod.upgrade()
|
||||
mod.downgrade()
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["enum_name"] == "tool_category"
|
||||
assert "mcp" in calls[0]["new_values"]
|
||||
assert "mcp" not in calls[1]["new_values"]
|
||||
|
||||
|
||||
def test_route_reuses_shared_mcp_models():
|
||||
assert RouteMcpToolConfig is McpToolConfig
|
||||
assert RouteMcpToolDefinition is McpToolDefinition
|
||||
|
||||
|
||||
def test_validate_mcp_definition_ok():
|
||||
cfg = validate_mcp_definition(
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-123",
|
||||
"tools_filter": ["lookup_patient"],
|
||||
"timeout_secs": 30,
|
||||
"sse_read_timeout_secs": 300,
|
||||
},
|
||||
}
|
||||
)
|
||||
assert cfg["url"] == "https://acme.example.com/mcp"
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == ["lookup_patient"]
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] == "cred-123"
|
||||
|
||||
|
||||
def test_validate_mcp_definition_defaults():
|
||||
cfg = validate_mcp_definition({"type": "mcp", "config": {"url": "https://x/mcp"}})
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == []
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
{"type": "mcp", "config": {}},
|
||||
{"type": "mcp", "config": {"url": ""}},
|
||||
{"type": "mcp", "config": {"url": "ftp://x"}},
|
||||
{"type": "mcp"},
|
||||
{"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}},
|
||||
],
|
||||
)
|
||||
def test_validate_mcp_definition_rejects(definition):
|
||||
with pytest.raises(McpDefinitionError):
|
||||
validate_mcp_definition(definition)
|
||||
|
||||
|
||||
def test_validate_mcp_definition_zero_timeout_preserved():
|
||||
cfg = validate_mcp_definition(
|
||||
{"type": "mcp", "config": {"url": "https://x/mcp", "timeout_secs": 0}}
|
||||
)
|
||||
assert cfg["timeout_secs"] == 0
|
||||
|
||||
|
||||
def test_namespace_function_name():
|
||||
assert (
|
||||
namespace_function_name("Acme MCP", "lookup_patient")
|
||||
== "mcp__acme_mcp__lookup_patient"
|
||||
)
|
||||
assert (
|
||||
namespace_function_name("", "ping", fallback="abcd1234")
|
||||
== "mcp__abcd1234__ping"
|
||||
)
|
||||
437
api/tests/test_mcp_tool_route.py
Normal file
437
api/tests/test_mcp_tool_route.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
"""Route-level tests for the MCP tool definition schema.
|
||||
|
||||
These tests exercise the Pydantic request models (CreateToolRequest /
|
||||
UpdateToolRequest) to catch schema gaps at the route/request-model layer —
|
||||
the layer where the pre-fix defect lived (HTTP 422 on every MCP tool
|
||||
creation attempt).
|
||||
|
||||
Test coverage:
|
||||
- CreateToolRequest validates a valid MCP definition (was 422 before Part A).
|
||||
- UpdateToolRequest validates a valid MCP definition.
|
||||
- Invalid MCP bodies are rejected (ftp:// url, missing url).
|
||||
- Round-trip: validated definition dict passes through validate_mcp_definition
|
||||
unchanged, proving the request schema and call-time validator agree.
|
||||
- Full HTTP round-trip via the ASGI test client (POST /api/v1/tools/).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
# ── Canonical valid MCP request body ─────────────────────────────────────────
|
||||
|
||||
VALID_MCP_DEFINITION = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://x/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── Part A regression: CreateToolRequest / UpdateToolRequest validation ───────
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_definition():
|
||||
"""CreateToolRequest must accept an MCP definition (was HTTP 422 before fix)."""
|
||||
req = CreateToolRequest(
|
||||
name="My MCP Tool",
|
||||
description="Integration via MCP",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
assert req.definition.config.transport == "streamable_http"
|
||||
assert req.definition.config.credential_uuid is None
|
||||
assert req.definition.config.tools_filter == []
|
||||
assert req.definition.config.timeout_secs == 30
|
||||
assert req.definition.config.sse_read_timeout_secs == 300
|
||||
|
||||
|
||||
def test_update_tool_request_accepts_mcp_definition():
|
||||
"""UpdateToolRequest must also accept an MCP definition."""
|
||||
req = UpdateToolRequest(
|
||||
name="Updated MCP Tool",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_with_all_fields():
|
||||
"""All optional MCP config fields are accepted and preserved."""
|
||||
req = CreateToolRequest(
|
||||
name="Full MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-abc-123",
|
||||
"tools_filter": ["lookup_patient", "schedule_appointment"],
|
||||
"timeout_secs": 60,
|
||||
"sse_read_timeout_secs": 600,
|
||||
},
|
||||
},
|
||||
)
|
||||
cfg = req.definition.config # type: ignore[union-attr]
|
||||
assert cfg.url == "https://acme.example.com/mcp"
|
||||
assert cfg.credential_uuid == "cred-abc-123"
|
||||
assert cfg.tools_filter == ["lookup_patient", "schedule_appointment"]
|
||||
assert cfg.timeout_secs == 60
|
||||
assert cfg.sse_read_timeout_secs == 600
|
||||
|
||||
|
||||
# ── Invalid bodies are rejected ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
# ftp:// URL — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": "ftp://x/mcp"},
|
||||
},
|
||||
# Empty url — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": ""},
|
||||
},
|
||||
# Missing url — rejected by McpToolConfig (required field)
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http"},
|
||||
},
|
||||
# Unsupported transport — rejected because Literal["streamable_http"] constraint
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "transport": "stdio"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_create_tool_request_rejects_invalid_mcp_definition(definition):
|
||||
"""Invalid MCP definitions must raise ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateToolRequest(
|
||||
name="Bad MCP Tool",
|
||||
category="mcp",
|
||||
definition=definition,
|
||||
)
|
||||
|
||||
|
||||
# ── Round-trip compatibility: request schema ↔ validate_mcp_definition ───────
|
||||
|
||||
|
||||
def test_mcp_definition_round_trips_through_validate_mcp_definition():
|
||||
"""The dict produced by CreateToolRequest.definition.model_dump() must be
|
||||
accepted by validate_mcp_definition without raising, and the result must
|
||||
contain the expected fields. This proves the request-layer schema and the
|
||||
call-time validator agree on the stored config shape."""
|
||||
req = CreateToolRequest(
|
||||
name="Round-Trip MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": "cred-rt-456",
|
||||
"tools_filter": ["ping"],
|
||||
"timeout_secs": 45,
|
||||
"sse_read_timeout_secs": 400,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Simulate what the route does: persist definition as a plain dict
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
|
||||
# validate_mcp_definition must accept the persisted shape without raising
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["credential_uuid"] == "cred-rt-456"
|
||||
assert normalized["tools_filter"] == ["ping"]
|
||||
assert normalized["timeout_secs"] == 45
|
||||
assert normalized["sse_read_timeout_secs"] == 400
|
||||
|
||||
|
||||
def test_mcp_definition_round_trip_defaults():
|
||||
"""Round-trip with minimal body: defaults fill in correctly and
|
||||
validate_mcp_definition agrees on them."""
|
||||
req = CreateToolRequest(
|
||||
name="Minimal MCP Tool",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["tools_filter"] == []
|
||||
assert normalized["timeout_secs"] == 30
|
||||
assert normalized["sse_read_timeout_secs"] == 300
|
||||
assert normalized["credential_uuid"] is None
|
||||
# Part B: auth_header / auth_scheme must NOT be present in the normalized
|
||||
# config dict (they were dead config removed in the fix)
|
||||
assert "auth_header" not in normalized
|
||||
assert "auth_scheme" not in normalized
|
||||
|
||||
|
||||
# ── Full HTTP round-trip via ASGI test client ─────────────────────────────────
|
||||
|
||||
|
||||
async def test_post_tool_mcp_returns_200(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an MCP definition must return HTTP 200 and
|
||||
persist the definition with type='mcp'. Before Part A this always
|
||||
returned 422."""
|
||||
# Create a user and an organization, then link them so the route's
|
||||
# selected_organization_id check passes.
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id("mcp_route_test_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
# Reload the user so selected_organization_id is populated on the object.
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "HTTP Round-Trip MCP Tool",
|
||||
"description": "Testing the full route",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, (
|
||||
f"Expected 200, got {response.status_code}: {response.text}"
|
||||
)
|
||||
body = response.json()
|
||||
assert body["definition"]["type"] == "mcp"
|
||||
assert body["definition"]["config"]["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert body["category"] == "mcp"
|
||||
|
||||
|
||||
async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an ftp:// URL must return HTTP 422."""
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id(
|
||||
"mcp_route_test_user_422"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org_422", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "Bad MCP Tool",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "ftp://invalid.example.com/mcp",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from api.routes.tool import McpToolConfig, _populate_discovered_tools
|
||||
|
||||
|
||||
def test_mcp_config_accepts_discovered_tools():
|
||||
cfg = McpToolConfig(
|
||||
url="https://x/mcp",
|
||||
discovered_tools=[{"name": "echo", "description": "Echo"}],
|
||||
)
|
||||
assert cfg.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
# Defaults to [] when omitted
|
||||
assert McpToolConfig(url="https://x/mcp").discovered_tools == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"url": "https://x/mcp",
|
||||
"tools_filter": [],
|
||||
"discovered_tools": [{"name": "stale", "description": "old"}],
|
||||
},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == [
|
||||
{"name": "echo", "description": "Echo"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_non_mcp_is_noop():
|
||||
definition = {"schema_version": 1, "type": "http_api", "config": {}}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out == definition # untouched
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(side_effect=RuntimeError("connection refused")),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == []
|
||||
|
||||
|
||||
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.routes.tool import refresh_mcp_tools
|
||||
|
||||
|
||||
def _fake_user(org_id=1):
|
||||
u = MagicMock()
|
||||
u.selected_organization_id = org_id
|
||||
u.id = 1
|
||||
u.provider_id = "p1"
|
||||
return u
|
||||
|
||||
|
||||
def _mcp_tool_model(org_id=1):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "tu-mcp"
|
||||
t.name = "Mock MCP"
|
||||
t.category = "mcp"
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client,
|
||||
"update_tool",
|
||||
AsyncMock(return_value=tool),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
assert resp.error is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == []
|
||||
assert resp.error # non-empty human-readable message
|
||||
# update_tool should NOT be called when discovery returns empty
|
||||
tool_mod.db_client.update_tool.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_non_mcp_is_400(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
tool.category = "http_api"
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert ei.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_not_found_is_404(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("nope", user=_fake_user())
|
||||
assert ei.value.status_code == 404
|
||||
274
api/tests/test_mcp_tool_session.py
Normal file
274
api/tests/test_mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.mcp_tool_session import (
|
||||
McpToolSession,
|
||||
build_streamable_http_params,
|
||||
discover_mcp_tools,
|
||||
)
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_server_starts_and_serves():
|
||||
async with running_mcp_server() as base_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(base_url, timeout=5.0)
|
||||
assert resp.status_code in (400, 404, 405, 406)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_with_credential():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=cred,
|
||||
timeout_secs=30,
|
||||
sse_read_timeout_secs=300,
|
||||
)
|
||||
assert params.url == "https://acme.example.com/mcp"
|
||||
assert params.headers == {"Authorization": "Bearer abc"}
|
||||
assert params.timeout == timedelta(seconds=30)
|
||||
assert params.sse_read_timeout == timedelta(seconds=300)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_no_credential():
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert params.headers is None or params.headers == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_passes_auth_header_to_real_server():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-ok",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=cred,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__secure_mcp__add", "mcp__secure_mcp__echo"]
|
||||
result = await session.call("mcp__secure_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_auth_failure_degrades_not_raises():
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-fail",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must degrade instead of raising on 401
|
||||
try:
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_lists_and_calls_real_server():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
schemas = session.function_schemas()
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
result = await session.call("mcp__acme_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_tools_filter_applied():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=["echo"],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__acme_mcp__echo"]
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_unreachable_degrades_not_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must NOT raise
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_on_unavailable_session_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__echo", {"text": "x"})
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_function_raises():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__does_not_exist", {})
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_schemas_filter_by_raw_name():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="t-filter",
|
||||
tool_name="Mock MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
# No arg = all (backward compatible)
|
||||
all_names = sorted(s.name for s in session.function_schemas())
|
||||
assert all_names == ["mcp__mock_mcp__add", "mcp__mock_mcp__echo"]
|
||||
|
||||
# Allow only raw "echo"
|
||||
only_echo = session.function_schemas(allowed_raw_names={"echo"})
|
||||
assert [s.name for s in only_echo] == ["mcp__mock_mcp__echo"]
|
||||
|
||||
# Empty set = none (default-none semantics)
|
||||
assert session.function_schemas(allowed_raw_names=set()) == []
|
||||
|
||||
# Unknown raw name = skipped (pure intersection)
|
||||
assert session.function_schemas(allowed_raw_names={"nope"}) == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_success():
|
||||
async with running_mcp_server() as base_url:
|
||||
tools = await discover_mcp_tools(
|
||||
url=base_url,
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
names = sorted(t["name"] for t in tools)
|
||||
assert names == ["add", "echo"]
|
||||
by_name = {t["name"]: t for t in tools}
|
||||
assert by_name["echo"]["description"] # non-empty description
|
||||
assert set(by_name["echo"]) == {"name", "description"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_server_down_returns_empty():
|
||||
# Unroutable port, short timeouts: must degrade to [] (never raise).
|
||||
tools = await discover_mcp_tools(
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_agent_node_data_carries_mcp_tool_filters():
|
||||
from api.services.workflow.dto import AgentNodeData, NodeType
|
||||
from api.services.workflow.workflow_graph import Node
|
||||
|
||||
data = AgentNodeData(
|
||||
name="N1",
|
||||
tool_uuids=["tu-1"],
|
||||
mcp_tool_filters={"tu-1": ["echo"]},
|
||||
)
|
||||
assert data.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
node = Node("n1", NodeType.agentNode, data)
|
||||
assert node.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
# Absent field defaults to None (backward compatible)
|
||||
data2 = AgentNodeData(name="N2")
|
||||
assert data2.mcp_tool_filters is None
|
||||
assert Node("n2", NodeType.agentNode, data2).mcp_tool_filters is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue