mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat(mcp): generic MCP tool source with per-node function filtering (#301)
* feat(mcp): generic MCP tool source with per-node function filtering
Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.
- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles
* feat: refactor for code standardisation and documentation
---------
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
0097974444
commit
75839f9de5
40 changed files with 3028 additions and 137 deletions
|
|
@ -20,7 +20,7 @@ RUN pip install --user --no-cache-dir -r requirements.txt && \
|
|||
|
||||
# Copy and install pipecat from local submodule
|
||||
COPY pipecat /tmp/pipecat
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' && \
|
||||
RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp]' && \
|
||||
# Swap opencv-python (pulled by pipecat[webrtc]) for opencv-python-headless
|
||||
# to drop X11/Qt dependencies that otherwise require libxcb etc. in runner.
|
||||
pip uninstall -y opencv-python && \
|
||||
|
|
|
|||
64
api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py
Normal file
64
api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""add mcp in ToolCategory
|
||||
|
||||
Revision ID: 0a1b2c3d4e5f
|
||||
Revises: 4c1f1e3e8ef2
|
||||
Create Date: 2026-05-16 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0a1b2c3d4e5f"
|
||||
down_revision: Union[str, None] = "4c1f1e3e8ef2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
"mcp",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="tool_category",
|
||||
new_values=[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="tools", column_name="category"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
|
|
@ -133,6 +133,7 @@ class ToolCategory(Enum):
|
|||
CALCULATOR = "calculator" # Built-in calculator tool
|
||||
NATIVE = "native" # Built-in integrations (future: dtmf_input)
|
||||
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
|
||||
MCP = "mcp" # Customer-provided MCP server exposing a tool catalog
|
||||
|
||||
|
||||
class ToolStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ async def authenticate_mcp_request() -> UserModel:
|
|||
the `langfuse.user.id` / `langfuse.session.id` attributes make the
|
||||
span filterable in the Langfuse UI.
|
||||
"""
|
||||
headers = get_http_headers()
|
||||
# FastMCP strips Authorization by default unless explicitly included.
|
||||
# Preserve it here so Bearer API keys work for MCP tool invocations.
|
||||
headers = get_http_headers(include={"authorization"})
|
||||
api_key = headers.get("x-api-key")
|
||||
if not api_key:
|
||||
auth = headers.get("authorization", "")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""API routes for managing tools."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.db import db_client
|
||||
|
|
@ -13,9 +15,23 @@ from api.enums import PostHogEvent, ToolCategory, ToolStatus
|
|||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolConfig as SharedMcpToolConfig,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolDefinition as SharedMcpToolDefinition,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tools")
|
||||
|
||||
McpToolConfig = SharedMcpToolConfig
|
||||
McpToolDefinition = SharedMcpToolDefinition
|
||||
|
||||
|
||||
# Request/Response schemas
|
||||
class ToolParameter(BaseModel):
|
||||
|
|
@ -183,6 +199,7 @@ ToolDefinition = Annotated[
|
|||
EndCallToolDefinition,
|
||||
TransferCallToolDefinition,
|
||||
CalculatorToolDefinition,
|
||||
McpToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -248,6 +265,14 @@ class ToolResponse(BaseModel):
|
|||
from_attributes = True
|
||||
|
||||
|
||||
class McpRefreshResponse(BaseModel):
|
||||
"""Result of re-discovering an MCP server's tool catalog."""
|
||||
|
||||
tool_uuid: str
|
||||
discovered_tools: list = Field(default_factory=list)
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
|
||||
"""Build a response from a tool model."""
|
||||
created_by = None
|
||||
|
|
@ -336,6 +361,52 @@ async def list_tools(
|
|||
return [build_tool_response(tool) for tool in tools]
|
||||
|
||||
|
||||
async def _fetch_credential(credential_uuid: Optional[str], organization_id: int):
|
||||
"""Best-effort credential lookup for MCP auth. A missing/failed credential
|
||||
degrades to ``None`` (unauthenticated) rather than failing the request."""
|
||||
if not credential_uuid:
|
||||
return None
|
||||
try:
|
||||
return await db_client.get_credential_by_uuid(credential_uuid, organization_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP: credential fetch failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _populate_discovered_tools(definition: dict, *, organization_id: int) -> dict:
|
||||
"""Best-effort: for an MCP definition, connect to the server, list its
|
||||
tools, and overwrite ``config.discovered_tools``. Never raises and never
|
||||
blocks tool save — a dead server yields ``discovered_tools: []``. Non-MCP
|
||||
definitions pass through untouched."""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
return definition
|
||||
try:
|
||||
cfg = validate_mcp_definition(definition)
|
||||
except McpDefinitionError:
|
||||
return definition
|
||||
|
||||
credential = await _fetch_credential(cfg.get("credential_uuid"), organization_id)
|
||||
|
||||
# Run discovery in an isolated asyncio task so an anyio cancel-scope
|
||||
# CancelledError doesn't bleed into the parent task and corrupt the
|
||||
# subsequent DB write. _run() never raises (degrades to []).
|
||||
async def _run() -> list:
|
||||
try:
|
||||
return await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except BaseException as e: # noqa: BLE001
|
||||
logger.warning(f"MCP discovery failed; caching empty list: {e}")
|
||||
return []
|
||||
|
||||
discovered = await asyncio.ensure_future(_run())
|
||||
definition["config"]["discovered_tools"] = discovered
|
||||
return definition
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_tool(
|
||||
request: CreateToolRequest,
|
||||
|
|
@ -357,11 +428,16 @@ async def create_tool(
|
|||
|
||||
validate_category(request.category)
|
||||
|
||||
definition = await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=request.definition.model_dump(),
|
||||
definition=definition,
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
|
|
@ -410,6 +486,67 @@ async def get_tool(
|
|||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
|
||||
@router.post("/{tool_uuid}/mcp/refresh")
|
||||
async def refresh_mcp_tools(
|
||||
tool_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> McpRefreshResponse:
|
||||
"""Re-discover an MCP tool's server catalog and overwrite the cached
|
||||
``definition.config.discovered_tools``. Server down → 200 with error
|
||||
(cache not overwritten on transient failure)."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
tool = await db_client.get_tool_by_uuid(
|
||||
tool_uuid, user.selected_organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
raise HTTPException(status_code=400, detail="Tool is not an MCP tool")
|
||||
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid MCP definition: {e}")
|
||||
|
||||
credential = await _fetch_credential(
|
||||
cfg.get("credential_uuid"), user.selected_organization_id
|
||||
)
|
||||
|
||||
try:
|
||||
discovered = await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP refresh discovery failed: {e}")
|
||||
discovered = []
|
||||
|
||||
if not discovered:
|
||||
error = (
|
||||
f"Could not reach the MCP server at {cfg['url']} "
|
||||
f"(or it exposes no tools). Previously cached list retained."
|
||||
)
|
||||
# Do NOT clobber a previously-good cache with [] on a transient outage.
|
||||
return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)
|
||||
|
||||
new_def = dict(tool.definition or {})
|
||||
new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered}
|
||||
await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
definition=new_def,
|
||||
)
|
||||
return McpRefreshResponse(
|
||||
tool_uuid=tool_uuid, discovered_tools=discovered, error=None
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{tool_uuid}")
|
||||
async def update_tool(
|
||||
tool_uuid: str,
|
||||
|
|
@ -434,12 +571,21 @@ async def update_tool(
|
|||
if request.status:
|
||||
validate_status(request.status)
|
||||
|
||||
definition = (
|
||||
await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
if request.definition
|
||||
else None
|
||||
)
|
||||
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
definition=request.definition.model_dump() if request.definition else None,
|
||||
definition=definition,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
status=request.status,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
from typing import Annotated, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||
|
||||
|
|
@ -69,6 +69,7 @@ class _ExtractionNodeDataMixin(BaseModel):
|
|||
class _ToolDocumentRefsMixin(BaseModel):
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
document_uuids: Optional[List[str]] = None
|
||||
mcp_tool_filters: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
class StartCallNodeData(
|
||||
|
|
|
|||
254
api/services/workflow/mcp_tool_session.py
Normal file
254
api/services/workflow/mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Single unit that knows the MCP protocol + credentials.
|
||||
|
||||
Wraps the vendored Pipecat ``MCPClient`` for connection/session, builds
|
||||
streamable-HTTP params from a Dograh credential, exposes namespaced
|
||||
``FunctionSchema``s, and proxies tool calls. Connection failures degrade
|
||||
(``available = False``) instead of raising — the call must survive a
|
||||
dead MCP server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.services.mcp_service import MCPClient
|
||||
|
||||
from api.services.workflow.tools.mcp_tool import namespace_function_name
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.db.models import ExternalCredentialModel
|
||||
|
||||
|
||||
def build_streamable_http_params(
|
||||
*,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> StreamableHttpParameters:
|
||||
"""Build Pipecat/MCP streamable-HTTP params, injecting the auth header
|
||||
from an ExternalCredentialModel (reuses the http_api credential path)."""
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
if credential is not None:
|
||||
auth = build_auth_header(credential)
|
||||
headers = auth or None
|
||||
return StreamableHttpParameters(
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=timedelta(seconds=timeout_secs),
|
||||
sse_read_timeout=timedelta(seconds=sse_read_timeout_secs),
|
||||
)
|
||||
|
||||
|
||||
class McpToolSession:
|
||||
"""One live MCP server connection for the duration of a call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_uuid: str,
|
||||
tool_name: str,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
tools_filter: List[str],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> None:
|
||||
self._tool_uuid = tool_uuid
|
||||
self._tool_name = tool_name
|
||||
self._url = url
|
||||
self._credential = credential
|
||||
# An empty list is intentionally treated as "no filter (expose all
|
||||
# tools)" — Pipecat's MCPClient applies a filter only when this is a
|
||||
# non-empty list, so [] and None are equivalent ("all tools").
|
||||
self._tools_filter = tools_filter or None
|
||||
self._timeout_secs = timeout_secs
|
||||
self._sse_read_timeout_secs = sse_read_timeout_secs
|
||||
|
||||
self._client: Optional[MCPClient] = None
|
||||
self._session: Any = None # mcp.ClientSession (read once after start)
|
||||
self._schemas: List[FunctionSchema] = []
|
||||
# namespaced LLM name -> original MCP tool name
|
||||
self._name_map: Dict[str, str] = {}
|
||||
self.available: bool = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect, initialize, and cache the tool list. Never raises —
|
||||
on any failure the session is marked unavailable."""
|
||||
try:
|
||||
params = build_streamable_http_params(
|
||||
url=self._url,
|
||||
credential=self._credential,
|
||||
timeout_secs=self._timeout_secs,
|
||||
sse_read_timeout_secs=self._sse_read_timeout_secs,
|
||||
)
|
||||
self._client = MCPClient(params, tools_filter=self._tools_filter)
|
||||
await self._client.start()
|
||||
# Single, isolated touch of Pipecat internals (vendored submodule).
|
||||
self._session = self._client._active_session
|
||||
tools_schema = await self._client.get_tools_schema()
|
||||
|
||||
fallback = self._tool_uuid[:8] if self._tool_uuid else "server"
|
||||
for fs in tools_schema.standard_tools:
|
||||
ns_name = namespace_function_name(
|
||||
self._tool_name, fs.name, fallback=fallback
|
||||
)
|
||||
self._name_map[ns_name] = fs.name
|
||||
self._schemas.append(
|
||||
FunctionSchema(
|
||||
name=ns_name,
|
||||
description=fs.description,
|
||||
properties=fs.properties,
|
||||
required=fs.required,
|
||||
)
|
||||
)
|
||||
self.available = True
|
||||
logger.info(
|
||||
f"MCP session ready for tool '{self._tool_name}' "
|
||||
f"({self._tool_uuid}): {sorted(self._name_map)}"
|
||||
)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except asyncio.CancelledError as e:
|
||||
# Empirically, a dead/unreachable MCP server does NOT surface as a
|
||||
# plain Exception here. The real failure is httpx.ConnectError, but
|
||||
# anyio's streamablehttp_client task group, while tearing down that
|
||||
# ConnectError, re-surfaces it to our frame as an *internal*
|
||||
# cancel-scope CancelledError carrying the signature message
|
||||
# "Cancelled via cancel scope <id>". A genuine *external*
|
||||
# cancellation (call teardown / shutdown) is a bare CancelledError
|
||||
# (empty args) or one with an application-chosen message. Type, MRO,
|
||||
# context chain, and asyncio task.cancelling() are all identical
|
||||
# between the two, so the anyio scope-signature message is the only
|
||||
# reliable discriminator. Re-raise genuine external cancellation to
|
||||
# preserve structured concurrency; degrade only on the anyio
|
||||
# connect-teardown artifact.
|
||||
msg = "" if not e.args else str(e.args[0] or "")
|
||||
if not msg.startswith("Cancelled via cancel scope"):
|
||||
raise
|
||||
await self._degrade(e)
|
||||
except Exception as e: # noqa: BLE001 — see _degrade docstring
|
||||
# Defensive: if a future Pipecat/httpx version surfaces the connect
|
||||
# failure directly (e.g. httpx.ConnectError) instead of via the
|
||||
# anyio cancel-scope artifact above, still degrade gracefully.
|
||||
await self._degrade(e)
|
||||
|
||||
async def _degrade(self, e: BaseException) -> None:
|
||||
"""Mark this session unavailable and tear down any dangling client so
|
||||
start() leaves self._client either fully usable or None. The contract
|
||||
requires graceful degradation on any *connect* failure (never raising
|
||||
for a dead MCP server) while genuine external cancellation /
|
||||
KeyboardInterrupt / SystemExit are re-raised by the caller."""
|
||||
self.available = False
|
||||
self._schemas = []
|
||||
self._name_map = {}
|
||||
# Self-contained cleanup: _client.start() may have succeeded before a
|
||||
# later step (e.g. get_tools_schema()) failed, leaving an open client.
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._client = None
|
||||
self._session = None
|
||||
logger.warning(
|
||||
f"MCP session unavailable for tool '{self._tool_name}' "
|
||||
f"({self._tool_uuid}) at {self._url}: {e!r}. "
|
||||
f"Call proceeds without these tools."
|
||||
)
|
||||
|
||||
@property
|
||||
def call_timeout_secs(self) -> float:
|
||||
"""Pipecat function-call timeout for this server's tools. Slightly
|
||||
longer than the transport read timeout so a slow MCP call surfaces
|
||||
as a structured tool error (handled in the handler) rather than a
|
||||
hard pipeline timeout."""
|
||||
return float(self._sse_read_timeout_secs) + 5.0
|
||||
|
||||
def function_schemas(
|
||||
self, allowed_raw_names: Optional[Set[str]] = None
|
||||
) -> List[FunctionSchema]:
|
||||
"""Return cached FunctionSchemas, optionally filtered by raw MCP tool name.
|
||||
|
||||
``allowed_raw_names=None`` returns all schemas. An empty set returns none.
|
||||
Raw names are the pre-namespace MCP tool names (e.g. ``echo``, not
|
||||
``mcp__slug__echo``).
|
||||
"""
|
||||
if allowed_raw_names is None:
|
||||
return list(self._schemas)
|
||||
return [
|
||||
s for s in self._schemas if self._name_map.get(s.name) in allowed_raw_names
|
||||
]
|
||||
|
||||
def discovered_tools(self) -> List[Dict[str, str]]:
|
||||
"""Raw MCP tool catalog for UI/cache: ``[{name, description}]``
|
||||
using the *raw* server names (not the namespaced LLM names).
|
||||
Empty if the session is unavailable."""
|
||||
out: List[Dict[str, str]] = []
|
||||
for s in self._schemas:
|
||||
raw = self._name_map.get(s.name)
|
||||
if raw is None:
|
||||
continue
|
||||
out.append({"name": raw, "description": s.description or ""})
|
||||
return out
|
||||
|
||||
async def call(self, namespaced_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""Invoke an MCP tool by its namespaced LLM name. Returns a string
|
||||
(flattened text content). Raises if the session is unavailable so
|
||||
the caller can map it to a structured error for the LLM."""
|
||||
if not self.available or self._session is None:
|
||||
raise RuntimeError(f"MCP session unavailable for {namespaced_name}")
|
||||
original = self._name_map.get(namespaced_name)
|
||||
if original is None:
|
||||
raise RuntimeError(f"Unknown MCP function {namespaced_name}")
|
||||
result = await self._session.call_tool(original, arguments=arguments)
|
||||
text = ""
|
||||
for content in getattr(result, "content", []) or []:
|
||||
if getattr(content, "text", None):
|
||||
text += content.text
|
||||
return text or "Sorry, the MCP tool returned no content."
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session {self._tool_uuid}: {e}")
|
||||
finally:
|
||||
self._client = None
|
||||
self._session = None
|
||||
|
||||
|
||||
async def discover_mcp_tools(
|
||||
*,
|
||||
url: str,
|
||||
credential: Optional["ExternalCredentialModel"],
|
||||
timeout_secs: int,
|
||||
sse_read_timeout_secs: int,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Open an ephemeral MCP session, list its tools, close it. Returns
|
||||
``[{name, description}]`` (raw names). Never raises — on any connect
|
||||
failure returns ``[]``."""
|
||||
session = McpToolSession(
|
||||
tool_uuid="discover",
|
||||
tool_name="discover",
|
||||
url=url,
|
||||
credential=credential,
|
||||
tools_filter=[],
|
||||
timeout_secs=timeout_secs,
|
||||
sse_read_timeout_secs=sse_read_timeout_secs,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
if not session.available:
|
||||
return []
|
||||
return session.discovered_tools()
|
||||
finally:
|
||||
await session.close()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
|
|
@ -16,6 +16,7 @@ from pipecat.services.settings import LLMSettings
|
|||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import ToolCategory
|
||||
from api.services.pipecat.audio_playback import play_audio
|
||||
from api.services.workflow.disposition_mapper import apply_disposition_mapping
|
||||
from api.services.workflow.workflow_graph import Node, WorkflowGraph
|
||||
|
|
@ -34,6 +35,7 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
compose_functions_for_node,
|
||||
compose_system_prompt_for_node,
|
||||
|
|
@ -116,6 +118,9 @@ class PipecatEngine:
|
|||
# Cached organization ID (resolved lazily from workflow run)
|
||||
self._organization_id: Optional[int] = None
|
||||
|
||||
# Open MCP tool sessions for this call, keyed by tool_uuid
|
||||
self._mcp_sessions: Dict[str, McpToolSession] = {}
|
||||
|
||||
# Embeddings configuration (passed from run_pipeline.py)
|
||||
self._embeddings_api_key: Optional[str] = embeddings_api_key
|
||||
self._embeddings_model: Optional[str] = embeddings_model
|
||||
|
|
@ -178,6 +183,9 @@ class PipecatEngine:
|
|||
# Helper that encapsulates custom tool management
|
||||
self._custom_tool_manager = CustomToolManager(self)
|
||||
|
||||
# Open persistent MCP server sessions for this call (degrades on failure)
|
||||
await self._open_mcp_sessions()
|
||||
|
||||
# Helper that encapsulates context summarization
|
||||
if self._context_compaction_enabled:
|
||||
self._context_summarization_manager = ContextSummarizationManager(self)
|
||||
|
|
@ -503,7 +511,10 @@ class PipecatEngine:
|
|||
|
||||
# Register custom tool handlers for this node
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
await self._custom_tool_manager.register_handlers(
|
||||
node.tool_uuids,
|
||||
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
|
||||
)
|
||||
|
||||
# Register knowledge base retrieval handler if node has documents
|
||||
if node.document_uuids:
|
||||
|
|
@ -814,6 +825,79 @@ class PipecatEngine:
|
|||
"""Get the gathered context including extracted variables."""
|
||||
return self._gathered_context.copy()
|
||||
|
||||
async def _open_mcp_sessions(self) -> None:
|
||||
"""Connect every MCP-category tool referenced by any workflow node.
|
||||
Failures degrade (session marked unavailable); never raises."""
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
try:
|
||||
tool_uuids: set[str] = set()
|
||||
for node in self.workflow.nodes.values():
|
||||
for tu in getattr(node, "tool_uuids", None) or []:
|
||||
tool_uuids.add(tu)
|
||||
if not tool_uuids:
|
||||
return
|
||||
|
||||
organization_id = await self._get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning("Cannot open MCP sessions: organization_id missing")
|
||||
return
|
||||
|
||||
tools = await db_client.get_tools_by_uuids(
|
||||
list(tool_uuids), organization_id
|
||||
)
|
||||
for tool in tools:
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
continue
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
logger.warning(
|
||||
f"Skipping MCP tool '{tool.name}' ({tool.tool_uuid}): "
|
||||
f"invalid definition: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
credential = None
|
||||
if cfg["credential_uuid"]:
|
||||
try:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
cfg["credential_uuid"], organization_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}': credential fetch failed: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
tools_filter=cfg["tools_filter"],
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
await session.start()
|
||||
self._mcp_sessions[tool.tool_uuid] = session
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to open MCP sessions; call proceeds without MCP tools: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _close_mcp_sessions(self) -> None:
|
||||
for tool_uuid, session in list(self._mcp_sessions.items()):
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing MCP session {tool_uuid}: {e}")
|
||||
self._mcp_sessions = {}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up engine resources on disconnect."""
|
||||
# Cancel any pending timeout tasks
|
||||
|
|
@ -823,6 +907,12 @@ class PipecatEngine:
|
|||
):
|
||||
self._user_response_timeout_task.cancel()
|
||||
|
||||
# Cancel any in-flight background summarization
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
# Cancel any in-flight background summarization.
|
||||
# MCP sessions are closed in a finally block so they are guaranteed to
|
||||
# run even if the summarization cleanup raises an exception.
|
||||
try:
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
finally:
|
||||
# Close any open MCP tool sessions
|
||||
await self._close_mcp_sessions()
|
||||
|
|
|
|||
|
|
@ -117,7 +117,8 @@ async def compose_functions_for_node(
|
|||
# Custom tools
|
||||
if node.tool_uuids and custom_tool_manager:
|
||||
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
node.tool_uuids,
|
||||
mcp_tool_filters=getattr(node, "mcp_tool_filters", None),
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from api.services.workflow.tools.custom_tool import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
|
|
@ -121,11 +122,18 @@ class CustomToolManager:
|
|||
"""Get the organization ID from the engine (shared cache)."""
|
||||
return await self._engine._get_organization_id()
|
||||
|
||||
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
|
||||
async def get_tool_schemas(
|
||||
self,
|
||||
tool_uuids: list[str],
|
||||
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
|
||||
) -> list[FunctionSchema]:
|
||||
"""Fetch custom tools and convert them to function schemas.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of
|
||||
raw MCP tool names to expose. None (default) exposes all tools.
|
||||
Empty dict or entry with [] suppresses all tools for that uuid.
|
||||
|
||||
Returns:
|
||||
List of FunctionSchema objects for LLM
|
||||
|
|
@ -154,6 +162,22 @@ class CustomToolManager:
|
|||
)
|
||||
continue
|
||||
|
||||
if tool.category == ToolCategory.MCP.value:
|
||||
session = self._engine._mcp_sessions.get(tool.tool_uuid)
|
||||
if session is None or not session.available:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
|
||||
f"unavailable; skipping"
|
||||
)
|
||||
continue
|
||||
allowed = (
|
||||
None
|
||||
if mcp_tool_filters is None
|
||||
else set(mcp_tool_filters.get(tool.tool_uuid, []))
|
||||
)
|
||||
schemas.extend(session.function_schemas(allowed))
|
||||
continue
|
||||
|
||||
raw_schema = tool_to_function_schema(tool)
|
||||
function_name = raw_schema["function"]["name"]
|
||||
|
||||
|
|
@ -178,11 +202,18 @@ class CustomToolManager:
|
|||
logger.error(f"Failed to fetch custom tools: {e}")
|
||||
return []
|
||||
|
||||
async def register_handlers(self, tool_uuids: list[str]) -> None:
|
||||
async def register_handlers(
|
||||
self,
|
||||
tool_uuids: list[str],
|
||||
mcp_tool_filters: Optional[dict[str, list[str]]] = None,
|
||||
) -> None:
|
||||
"""Register custom tool execution handlers with the LLM.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to register handlers for
|
||||
mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of
|
||||
raw MCP tool names to expose. None (default) exposes all tools.
|
||||
Empty dict or entry with [] suppresses all tools for that uuid.
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
|
|
@ -203,6 +234,32 @@ class CustomToolManager:
|
|||
)
|
||||
continue
|
||||
|
||||
if tool.category == ToolCategory.MCP.value:
|
||||
session = self._engine._mcp_sessions.get(tool.tool_uuid)
|
||||
if session is None or not session.available:
|
||||
logger.warning(
|
||||
f"MCP tool '{tool.name}' ({tool.tool_uuid}) "
|
||||
f"unavailable; skipping handler registration"
|
||||
)
|
||||
continue
|
||||
allowed = (
|
||||
None
|
||||
if mcp_tool_filters is None
|
||||
else set(mcp_tool_filters.get(tool.tool_uuid, []))
|
||||
)
|
||||
mcp_schemas = session.function_schemas(allowed)
|
||||
for fs in mcp_schemas:
|
||||
self._engine.llm.register_function(
|
||||
fs.name,
|
||||
self._create_mcp_handler(session, fs.name),
|
||||
timeout_secs=session.call_timeout_secs,
|
||||
)
|
||||
logger.debug(
|
||||
f"Registered {len(mcp_schemas)} MCP "
|
||||
f"handlers for tool '{tool.name}' ({tool.tool_uuid})"
|
||||
)
|
||||
continue
|
||||
|
||||
schema = tool_to_function_schema(tool)
|
||||
function_name = schema["function"]["name"]
|
||||
|
||||
|
|
@ -335,6 +392,29 @@ class CustomToolManager:
|
|||
|
||||
return http_tool_handler
|
||||
|
||||
def _create_mcp_handler(self, session: "McpToolSession", function_name: str):
|
||||
"""Create a handler that proxies an LLM function call to a live MCP
|
||||
session. Errors are returned to the LLM as structured text so the
|
||||
agent can recover verbally; the call is never crashed."""
|
||||
|
||||
async def mcp_tool_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"MCP Tool EXECUTED: {function_name}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
try:
|
||||
result = await session.call(
|
||||
function_name, function_call_params.arguments or {}
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool '{function_name}' failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)}
|
||||
)
|
||||
|
||||
return mcp_tool_handler
|
||||
|
||||
def _create_end_call_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for an end call tool.
|
||||
|
||||
|
|
|
|||
116
api/services/workflow/tools/mcp_tool.py
Normal file
116
api/services/workflow/tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Pure helpers for MCP-category tools: definition validation and
|
||||
LLM-function-name namespacing. No I/O, no MCP protocol here."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
DEFAULT_TIMEOUT_SECS = 30
|
||||
DEFAULT_SSE_READ_TIMEOUT_SECS = 300
|
||||
|
||||
|
||||
class McpDefinitionError(ValueError):
|
||||
"""Raised when an MCP tool definition is structurally invalid."""
|
||||
|
||||
|
||||
class McpToolConfig(BaseModel):
|
||||
"""Configuration for an MCP tool definition."""
|
||||
|
||||
transport: Literal["streamable_http"] = Field(
|
||||
default="streamable_http", description="MCP transport protocol"
|
||||
)
|
||||
url: str = Field(description="MCP server URL (must be http:// or https://)")
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None, description="Reference to ExternalCredentialModel for auth"
|
||||
)
|
||||
tools_filter: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Allowlist of MCP tool names to expose (empty = all tools)",
|
||||
)
|
||||
timeout_secs: int = Field(
|
||||
default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds"
|
||||
)
|
||||
sse_read_timeout_secs: int = Field(
|
||||
default=DEFAULT_SSE_READ_TIMEOUT_SECS,
|
||||
description="SSE read timeout in seconds",
|
||||
)
|
||||
discovered_tools: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Server-managed cache of the MCP server's tool catalog "
|
||||
"[{name, description}]. Populated best-effort by the backend."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
|
||||
raise ValueError("config.url must be an http(s) URL")
|
||||
return v
|
||||
|
||||
@field_validator("tools_filter")
|
||||
@classmethod
|
||||
def validate_tools_filter(cls, v: list[str]) -> list[str]:
|
||||
if not all(isinstance(tool_name, str) for tool_name in v):
|
||||
raise ValueError("config.tools_filter must be a list of strings")
|
||||
return v
|
||||
|
||||
|
||||
class McpToolDefinition(BaseModel):
|
||||
"""Persisted MCP tool definition."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["mcp"] = Field(description="Tool type")
|
||||
config: McpToolConfig = Field(description="MCP server configuration")
|
||||
|
||||
|
||||
def _format_validation_error(error: ValidationError) -> str:
|
||||
parts: list[str] = []
|
||||
for item in error.errors():
|
||||
location = ".".join(str(part) for part in item["loc"])
|
||||
parts.append(f"{location}: {item['msg']}")
|
||||
return "; ".join(parts)
|
||||
|
||||
|
||||
def validate_mcp_definition(definition: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate a ``type: "mcp"`` ToolModel definition and return a
|
||||
normalized config dict with defaults applied.
|
||||
|
||||
Raises:
|
||||
McpDefinitionError: if the definition is missing required fields
|
||||
or uses an unsupported transport.
|
||||
"""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
raise McpDefinitionError("definition.type must be 'mcp'")
|
||||
|
||||
config = definition.get("config")
|
||||
if not isinstance(config, dict):
|
||||
raise McpDefinitionError("definition.config is required and must be an object")
|
||||
|
||||
try:
|
||||
parsed = McpToolDefinition.model_validate(definition)
|
||||
except ValidationError as e:
|
||||
raise McpDefinitionError(_format_validation_error(e)) from e
|
||||
|
||||
return parsed.config.model_dump(exclude={"discovered_tools"})
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", value.strip().lower()).strip("_")
|
||||
return slug
|
||||
|
||||
|
||||
def namespace_function_name(
|
||||
tool_name: str, mcp_tool_name: str, *, fallback: str = "server"
|
||||
) -> str:
|
||||
"""Build a collision-safe LLM function name: ``mcp__<slug>__<tool>``.
|
||||
|
||||
``slug`` is derived from the Dograh ToolModel name; if it slugifies to
|
||||
empty, ``fallback`` (e.g. first 8 chars of tool_uuid) is used instead.
|
||||
"""
|
||||
slug = _slugify(tool_name) or _slugify(fallback) or "server"
|
||||
return f"mcp__{slug}__{mcp_tool_name}"
|
||||
|
|
@ -89,6 +89,7 @@ class Node:
|
|||
self.delayed_start_duration = getattr(data, "delayed_start_duration", None)
|
||||
self.tool_uuids = getattr(data, "tool_uuids", None)
|
||||
self.document_uuids = getattr(data, "document_uuids", None)
|
||||
self.mcp_tool_filters = getattr(data, "mcp_tool_filters", None)
|
||||
self.pre_call_fetch_enabled = getattr(data, "pre_call_fetch_enabled", False)
|
||||
self.pre_call_fetch_url = getattr(data, "pre_call_fetch_url", None)
|
||||
self.pre_call_fetch_credential_uuid = getattr(
|
||||
|
|
|
|||
0
api/tests/support/__init__.py
Normal file
0
api/tests/support/__init__.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
103
api/tests/support/mcp_mock_server.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""A real FastMCP server exposing 2 tools over streamable-HTTP, run in a
|
||||
background uvicorn thread on an ephemeral port. Used to exercise the real
|
||||
MCP protocol path in tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastmcp import FastMCP
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
def _build_app(required_headers: dict[str, str] | None = None):
|
||||
mcp = FastMCP("mock-mcp")
|
||||
|
||||
@mcp.tool()
|
||||
def echo(text: str) -> str:
|
||||
"""Echo the provided text back."""
|
||||
return f"echo:{text}"
|
||||
|
||||
@mcp.tool()
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two integers."""
|
||||
return a + b
|
||||
|
||||
# FastMCP 3.x: ASGI app for streamable-HTTP transport at "/mcp".
|
||||
app = mcp.http_app()
|
||||
if not required_headers:
|
||||
return app
|
||||
|
||||
normalized = {k.lower(): v for k, v in required_headers.items()}
|
||||
|
||||
async def guarded_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
headers = {
|
||||
key.decode("latin-1").lower(): value.decode("latin-1")
|
||||
for key, value in scope.get("headers", [])
|
||||
}
|
||||
for header_name, expected_value in normalized.items():
|
||||
if headers.get(header_name) != expected_value:
|
||||
response = JSONResponse(
|
||||
{"detail": f"Missing or invalid header: {header_name}"},
|
||||
status_code=401,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
await app(scope, receive, send)
|
||||
|
||||
return guarded_app
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def running_mcp_server(
|
||||
*, required_headers: dict[str, str] | None = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Yield the base streamable-HTTP URL of a live mock MCP server."""
|
||||
port = _free_port()
|
||||
config = uvicorn.Config(
|
||||
_build_app(required_headers), host="127.0.0.1", port=port, log_level="warning"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}/mcp"
|
||||
server_ready = False
|
||||
for _ in range(50):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.get(base_url, timeout=0.5)
|
||||
server_ready = True
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(0.1)
|
||||
if not server_ready:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
raise RuntimeError(f"Mock MCP server at {base_url} failed to start within 5s")
|
||||
try:
|
||||
yield base_url
|
||||
finally:
|
||||
server.should_exit = True
|
||||
thread.join(timeout=5)
|
||||
if thread.is_alive():
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Mock MCP server thread did not terminate within 5s",
|
||||
ResourceWarning,
|
||||
)
|
||||
63
api/tests/test_mcp_auth.py
Normal file
63
api/tests/test_mcp_auth.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_bearer_authorization():
|
||||
user = MagicMock()
|
||||
user.id = 1
|
||||
user.selected_organization_id = 90
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"authorization": "Bearer secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_accepts_x_api_key():
|
||||
user = MagicMock()
|
||||
user.id = 2
|
||||
user.selected_organization_id = 91
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.auth.get_http_headers",
|
||||
return_value={"x-api-key": "secret-api-key"},
|
||||
) as get_headers,
|
||||
patch(
|
||||
"api.mcp_server.auth._handle_api_key_auth",
|
||||
AsyncMock(return_value=user),
|
||||
) as handle_auth,
|
||||
):
|
||||
authed = await authenticate_mcp_request()
|
||||
|
||||
assert authed is user
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
handle_auth.assert_awaited_once_with("secret-api-key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_mcp_request_rejects_missing_api_key():
|
||||
with patch("api.mcp_server.auth.get_http_headers", return_value={}) as get_headers:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await authenticate_mcp_request()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing API key" in str(exc_info.value.detail)
|
||||
get_headers.assert_called_once_with(include={"authorization"})
|
||||
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
181
api/tests/test_mcp_custom_tool_manager.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool():
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {"type": "mcp", "config": {"url": "https://x/mcp"}}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
reg_kwargs = {}
|
||||
|
||||
def _reg(name, fn, **kw):
|
||||
registered[name] = fn
|
||||
reg_kwargs[name] = kw
|
||||
|
||||
engine.llm.register_function = _reg
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
|
||||
try:
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert "mcp__acme_mcp__echo" in registered
|
||||
assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx(
|
||||
15.0
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
class P:
|
||||
function_name = "mcp__acme_mcp__echo"
|
||||
arguments = {"text": "yo"}
|
||||
|
||||
async def result_callback(self, r, *, properties=None):
|
||||
captured["r"] = r
|
||||
|
||||
await registered["mcp__acme_mcp__echo"](P())
|
||||
assert "echo:yo" in str(captured["r"])
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unavailable_mcp_session_contributes_nothing(monkeypatch):
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
await session.start() # degrades
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
|
||||
schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert schemas == []
|
||||
await mgr.register_handlers([tool.tool_uuid]) # must not raise
|
||||
|
||||
|
||||
def test_call_timeout_secs_is_read_timeout_plus_buffer():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-abc123",
|
||||
tool_name="Acme MCP",
|
||||
url="https://x/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert session.call_timeout_secs == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_node_mcp_filter_intersection(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool()
|
||||
session = McpToolSession(
|
||||
tool_uuid=tool.tool_uuid,
|
||||
tool_name=tool.name,
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
|
||||
engine = MagicMock()
|
||||
engine._mcp_sessions = {tool.tool_uuid: session}
|
||||
registered = {}
|
||||
engine.llm.register_function = lambda name, fn, **kw: registered.__setitem__(
|
||||
name, fn
|
||||
)
|
||||
|
||||
mgr = CustomToolManager(engine)
|
||||
mgr.get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
try:
|
||||
# Allow only raw "echo" for this node
|
||||
filters = {tool.tool_uuid: ["echo"]}
|
||||
schemas = await mgr.get_tool_schemas(
|
||||
[tool.tool_uuid], mcp_tool_filters=filters
|
||||
)
|
||||
# Check only "echo" schema returned (namespaced name depends on tool.name)
|
||||
assert len(schemas) == 1
|
||||
assert all("echo" in s.name for s in schemas)
|
||||
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters)
|
||||
assert len(registered) == 1
|
||||
assert all("echo" in k for k in registered)
|
||||
|
||||
# No filter entry for this uuid = none (default-none)
|
||||
registered.clear()
|
||||
result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert result == []
|
||||
await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={})
|
||||
assert registered == {}
|
||||
|
||||
# mcp_tool_filters=None = backward-compatible (all tools)
|
||||
registered.clear()
|
||||
all_schemas = await mgr.get_tool_schemas([tool.tool_uuid])
|
||||
assert len(all_schemas) == 2 # both echo and add
|
||||
await mgr.register_handlers([tool.tool_uuid])
|
||||
assert len(registered) == 2 # both handlers registered
|
||||
finally:
|
||||
await session.close()
|
||||
107
api/tests/test_mcp_integration.py
Normal file
107
api/tests/test_mcp_integration.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
def _mcp_tool(url: str):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8]
|
||||
t.name = "Acme MCP"
|
||||
t.category = ToolCategory.MCP.value
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": url},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
|
||||
async with running_mcp_server() as base_url:
|
||||
tool = _mcp_tool(base_url)
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
db_client, "get_credential_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
await engine._open_mcp_sessions()
|
||||
try:
|
||||
assert tool.tool_uuid in engine._mcp_sessions
|
||||
sess = engine._mcp_sessions[tool.tool_uuid]
|
||||
assert sess.available is True
|
||||
assert len(sess.function_schemas()) == 2
|
||||
finally:
|
||||
await engine._close_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_swallows_db_error(monkeypatch):
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = ["uuid-deadbeef"]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_tools_by_uuids",
|
||||
AsyncMock(side_effect=RuntimeError("db down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch):
|
||||
tool = _mcp_tool("http://127.0.0.1:1/mcp")
|
||||
tool.definition["config"]["credential_uuid"] = "cred-1234"
|
||||
|
||||
engine = PipecatEngine.__new__(PipecatEngine)
|
||||
node = MagicMock()
|
||||
node.tool_uuids = [tool.tool_uuid]
|
||||
workflow = MagicMock()
|
||||
workflow.nodes = {"n1": node}
|
||||
engine.workflow = workflow
|
||||
engine._mcp_sessions = {}
|
||||
|
||||
from api.db import db_client
|
||||
|
||||
monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]))
|
||||
monkeypatch.setattr(
|
||||
db_client,
|
||||
"get_credential_by_uuid",
|
||||
AsyncMock(side_effect=RuntimeError("cred store down")),
|
||||
)
|
||||
engine._get_organization_id = AsyncMock(return_value=42)
|
||||
|
||||
# Must NOT raise, and must skip the tool (no futile unauthenticated start)
|
||||
await engine._open_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
112
api/tests/test_mcp_tool_definition.py
Normal file
112
api/tests/test_mcp_tool_definition.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.routes.tool import McpToolConfig as RouteMcpToolConfig
|
||||
from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
namespace_function_name,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_category_exists():
|
||||
assert ToolCategory.MCP.value == "mcp"
|
||||
assert ToolCategory("mcp") is ToolCategory.MCP
|
||||
|
||||
|
||||
def test_mcp_migration_present_and_chained(monkeypatch):
|
||||
mod = importlib.import_module(
|
||||
"api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory"
|
||||
)
|
||||
assert mod.revision == "0a1b2c3d4e5f"
|
||||
assert mod.down_revision == "4c1f1e3e8ef2"
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_sync_enum_values(**kwargs):
|
||||
calls.append(kwargs)
|
||||
|
||||
monkeypatch.setattr(mod.op, "sync_enum_values", fake_sync_enum_values)
|
||||
|
||||
mod.upgrade()
|
||||
mod.downgrade()
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["enum_name"] == "tool_category"
|
||||
assert "mcp" in calls[0]["new_values"]
|
||||
assert "mcp" not in calls[1]["new_values"]
|
||||
|
||||
|
||||
def test_route_reuses_shared_mcp_models():
|
||||
assert RouteMcpToolConfig is McpToolConfig
|
||||
assert RouteMcpToolDefinition is McpToolDefinition
|
||||
|
||||
|
||||
def test_validate_mcp_definition_ok():
|
||||
cfg = validate_mcp_definition(
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-123",
|
||||
"tools_filter": ["lookup_patient"],
|
||||
"timeout_secs": 30,
|
||||
"sse_read_timeout_secs": 300,
|
||||
},
|
||||
}
|
||||
)
|
||||
assert cfg["url"] == "https://acme.example.com/mcp"
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == ["lookup_patient"]
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] == "cred-123"
|
||||
|
||||
|
||||
def test_validate_mcp_definition_defaults():
|
||||
cfg = validate_mcp_definition({"type": "mcp", "config": {"url": "https://x/mcp"}})
|
||||
assert cfg["transport"] == "streamable_http"
|
||||
assert cfg["tools_filter"] == []
|
||||
assert cfg["timeout_secs"] == 30
|
||||
assert cfg["sse_read_timeout_secs"] == 300
|
||||
assert cfg["credential_uuid"] is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
{"type": "mcp", "config": {}},
|
||||
{"type": "mcp", "config": {"url": ""}},
|
||||
{"type": "mcp", "config": {"url": "ftp://x"}},
|
||||
{"type": "mcp"},
|
||||
{"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}},
|
||||
],
|
||||
)
|
||||
def test_validate_mcp_definition_rejects(definition):
|
||||
with pytest.raises(McpDefinitionError):
|
||||
validate_mcp_definition(definition)
|
||||
|
||||
|
||||
def test_validate_mcp_definition_zero_timeout_preserved():
|
||||
cfg = validate_mcp_definition(
|
||||
{"type": "mcp", "config": {"url": "https://x/mcp", "timeout_secs": 0}}
|
||||
)
|
||||
assert cfg["timeout_secs"] == 0
|
||||
|
||||
|
||||
def test_namespace_function_name():
|
||||
assert (
|
||||
namespace_function_name("Acme MCP", "lookup_patient")
|
||||
== "mcp__acme_mcp__lookup_patient"
|
||||
)
|
||||
assert (
|
||||
namespace_function_name("", "ping", fallback="abcd1234")
|
||||
== "mcp__abcd1234__ping"
|
||||
)
|
||||
437
api/tests/test_mcp_tool_route.py
Normal file
437
api/tests/test_mcp_tool_route.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
"""Route-level tests for the MCP tool definition schema.
|
||||
|
||||
These tests exercise the Pydantic request models (CreateToolRequest /
|
||||
UpdateToolRequest) to catch schema gaps at the route/request-model layer —
|
||||
the layer where the pre-fix defect lived (HTTP 422 on every MCP tool
|
||||
creation attempt).
|
||||
|
||||
Test coverage:
|
||||
- CreateToolRequest validates a valid MCP definition (was 422 before Part A).
|
||||
- UpdateToolRequest validates a valid MCP definition.
|
||||
- Invalid MCP bodies are rejected (ftp:// url, missing url).
|
||||
- Round-trip: validated definition dict passes through validate_mcp_definition
|
||||
unchanged, proving the request schema and call-time validator agree.
|
||||
- Full HTTP round-trip via the ASGI test client (POST /api/v1/tools/).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
# ── Canonical valid MCP request body ─────────────────────────────────────────
|
||||
|
||||
VALID_MCP_DEFINITION = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://x/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── Part A regression: CreateToolRequest / UpdateToolRequest validation ───────
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_definition():
|
||||
"""CreateToolRequest must accept an MCP definition (was HTTP 422 before fix)."""
|
||||
req = CreateToolRequest(
|
||||
name="My MCP Tool",
|
||||
description="Integration via MCP",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
assert req.definition.config.transport == "streamable_http"
|
||||
assert req.definition.config.credential_uuid is None
|
||||
assert req.definition.config.tools_filter == []
|
||||
assert req.definition.config.timeout_secs == 30
|
||||
assert req.definition.config.sse_read_timeout_secs == 300
|
||||
|
||||
|
||||
def test_update_tool_request_accepts_mcp_definition():
|
||||
"""UpdateToolRequest must also accept an MCP definition."""
|
||||
req = UpdateToolRequest(
|
||||
name="Updated MCP Tool",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
assert isinstance(req.definition, McpToolDefinition)
|
||||
assert req.definition.type == "mcp"
|
||||
assert req.definition.config.url == "https://x/mcp"
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_with_all_fields():
|
||||
"""All optional MCP config fields are accepted and preserved."""
|
||||
req = CreateToolRequest(
|
||||
name="Full MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://acme.example.com/mcp",
|
||||
"credential_uuid": "cred-abc-123",
|
||||
"tools_filter": ["lookup_patient", "schedule_appointment"],
|
||||
"timeout_secs": 60,
|
||||
"sse_read_timeout_secs": 600,
|
||||
},
|
||||
},
|
||||
)
|
||||
cfg = req.definition.config # type: ignore[union-attr]
|
||||
assert cfg.url == "https://acme.example.com/mcp"
|
||||
assert cfg.credential_uuid == "cred-abc-123"
|
||||
assert cfg.tools_filter == ["lookup_patient", "schedule_appointment"]
|
||||
assert cfg.timeout_secs == 60
|
||||
assert cfg.sse_read_timeout_secs == 600
|
||||
|
||||
|
||||
# ── Invalid bodies are rejected ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"definition",
|
||||
[
|
||||
# ftp:// URL — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": "ftp://x/mcp"},
|
||||
},
|
||||
# Empty url — rejected by McpToolConfig.validate_url
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http", "url": ""},
|
||||
},
|
||||
# Missing url — rejected by McpToolConfig (required field)
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"transport": "streamable_http"},
|
||||
},
|
||||
# Unsupported transport — rejected because Literal["streamable_http"] constraint
|
||||
{
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "transport": "stdio"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_create_tool_request_rejects_invalid_mcp_definition(definition):
|
||||
"""Invalid MCP definitions must raise ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateToolRequest(
|
||||
name="Bad MCP Tool",
|
||||
category="mcp",
|
||||
definition=definition,
|
||||
)
|
||||
|
||||
|
||||
# ── Round-trip compatibility: request schema ↔ validate_mcp_definition ───────
|
||||
|
||||
|
||||
def test_mcp_definition_round_trips_through_validate_mcp_definition():
|
||||
"""The dict produced by CreateToolRequest.definition.model_dump() must be
|
||||
accepted by validate_mcp_definition without raising, and the result must
|
||||
contain the expected fields. This proves the request-layer schema and the
|
||||
call-time validator agree on the stored config shape."""
|
||||
req = CreateToolRequest(
|
||||
name="Round-Trip MCP Tool",
|
||||
category="mcp",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": "cred-rt-456",
|
||||
"tools_filter": ["ping"],
|
||||
"timeout_secs": 45,
|
||||
"sse_read_timeout_secs": 400,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Simulate what the route does: persist definition as a plain dict
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
|
||||
# validate_mcp_definition must accept the persisted shape without raising
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["credential_uuid"] == "cred-rt-456"
|
||||
assert normalized["tools_filter"] == ["ping"]
|
||||
assert normalized["timeout_secs"] == 45
|
||||
assert normalized["sse_read_timeout_secs"] == 400
|
||||
|
||||
|
||||
def test_mcp_definition_round_trip_defaults():
|
||||
"""Round-trip with minimal body: defaults fill in correctly and
|
||||
validate_mcp_definition agrees on them."""
|
||||
req = CreateToolRequest(
|
||||
name="Minimal MCP Tool",
|
||||
category="mcp",
|
||||
definition=VALID_MCP_DEFINITION,
|
||||
)
|
||||
|
||||
persisted = req.definition.model_dump() # type: ignore[union-attr]
|
||||
normalized = validate_mcp_definition(persisted)
|
||||
|
||||
assert normalized["transport"] == "streamable_http"
|
||||
assert normalized["tools_filter"] == []
|
||||
assert normalized["timeout_secs"] == 30
|
||||
assert normalized["sse_read_timeout_secs"] == 300
|
||||
assert normalized["credential_uuid"] is None
|
||||
# Part B: auth_header / auth_scheme must NOT be present in the normalized
|
||||
# config dict (they were dead config removed in the fix)
|
||||
assert "auth_header" not in normalized
|
||||
assert "auth_scheme" not in normalized
|
||||
|
||||
|
||||
# ── Full HTTP round-trip via ASGI test client ─────────────────────────────────
|
||||
|
||||
|
||||
async def test_post_tool_mcp_returns_200(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an MCP definition must return HTTP 200 and
|
||||
persist the definition with type='mcp'. Before Part A this always
|
||||
returned 422."""
|
||||
# Create a user and an organization, then link them so the route's
|
||||
# selected_organization_id check passes.
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id("mcp_route_test_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
# Reload the user so selected_organization_id is populated on the object.
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "HTTP Round-Trip MCP Tool",
|
||||
"description": "Testing the full route",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://roundtrip.example.com/mcp",
|
||||
"credential_uuid": None,
|
||||
"tools_filter": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, (
|
||||
f"Expected 200, got {response.status_code}: {response.text}"
|
||||
)
|
||||
body = response.json()
|
||||
assert body["definition"]["type"] == "mcp"
|
||||
assert body["definition"]["config"]["url"] == "https://roundtrip.example.com/mcp"
|
||||
assert body["category"] == "mcp"
|
||||
|
||||
|
||||
async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_session):
|
||||
"""POST /api/v1/tools/ with an ftp:// URL must return HTTP 422."""
|
||||
user, _ = await db_session.get_or_create_user_by_provider_id(
|
||||
"mcp_route_test_user_422"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"mcp_route_test_org_422", user.id
|
||||
)
|
||||
await db_session.update_user_selected_organization(user.id, org.id)
|
||||
user = await db_session.get_user_by_id(user.id)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/tools/",
|
||||
json={
|
||||
"name": "Bad MCP Tool",
|
||||
"category": "mcp",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"transport": "streamable_http",
|
||||
"url": "ftp://invalid.example.com/mcp",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from api.routes.tool import McpToolConfig, _populate_discovered_tools
|
||||
|
||||
|
||||
def test_mcp_config_accepts_discovered_tools():
|
||||
cfg = McpToolConfig(
|
||||
url="https://x/mcp",
|
||||
discovered_tools=[{"name": "echo", "description": "Echo"}],
|
||||
)
|
||||
assert cfg.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
# Defaults to [] when omitted
|
||||
assert McpToolConfig(url="https://x/mcp").discovered_tools == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {
|
||||
"url": "https://x/mcp",
|
||||
"tools_filter": [],
|
||||
"discovered_tools": [{"name": "stale", "description": "old"}],
|
||||
},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == [
|
||||
{"name": "echo", "description": "Echo"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_non_mcp_is_noop():
|
||||
definition = {"schema_version": 1, "type": "http_api", "config": {}}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out == definition # untouched
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(side_effect=RuntimeError("connection refused")),
|
||||
)
|
||||
definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
out = await _populate_discovered_tools(definition, organization_id=1)
|
||||
assert out["config"]["discovered_tools"] == []
|
||||
|
||||
|
||||
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.routes.tool import refresh_mcp_tools
|
||||
|
||||
|
||||
def _fake_user(org_id=1):
|
||||
u = MagicMock()
|
||||
u.selected_organization_id = org_id
|
||||
u.id = 1
|
||||
u.provider_id = "p1"
|
||||
return u
|
||||
|
||||
|
||||
def _mcp_tool_model(org_id=1):
|
||||
t = MagicMock()
|
||||
t.tool_uuid = "tu-mcp"
|
||||
t.name = "Mock MCP"
|
||||
t.category = "mcp"
|
||||
t.definition = {
|
||||
"schema_version": 1,
|
||||
"type": "mcp",
|
||||
"config": {"url": "https://x/mcp", "tools_filter": []},
|
||||
}
|
||||
return t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client,
|
||||
"update_tool",
|
||||
AsyncMock(return_value=tool),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == [{"name": "echo", "description": "Echo"}]
|
||||
assert resp.error is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == []
|
||||
assert resp.error # non-empty human-readable message
|
||||
# update_tool should NOT be called when discovery returns empty
|
||||
tool_mod.db_client.update_tool.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_non_mcp_is_400(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
tool.category = "http_api"
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert ei.value.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_not_found_is_404(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("nope", user=_fake_user())
|
||||
assert ei.value.status_code == 404
|
||||
274
api/tests/test_mcp_tool_session.py
Normal file
274
api/tests/test_mcp_tool_session.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.mcp_tool_session import (
|
||||
McpToolSession,
|
||||
build_streamable_http_params,
|
||||
discover_mcp_tools,
|
||||
)
|
||||
from api.tests.support.mcp_mock_server import running_mcp_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_server_starts_and_serves():
|
||||
async with running_mcp_server() as base_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(base_url, timeout=5.0)
|
||||
assert resp.status_code in (400, 404, 405, 406)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_with_credential():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=cred,
|
||||
timeout_secs=30,
|
||||
sse_read_timeout_secs=300,
|
||||
)
|
||||
assert params.url == "https://acme.example.com/mcp"
|
||||
assert params.headers == {"Authorization": "Bearer abc"}
|
||||
assert params.timeout == timedelta(seconds=30)
|
||||
assert params.sse_read_timeout == timedelta(seconds=300)
|
||||
|
||||
|
||||
def test_build_streamable_http_params_no_credential():
|
||||
params = build_streamable_http_params(
|
||||
url="https://acme.example.com/mcp",
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
assert params.headers is None or params.headers == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_passes_auth_header_to_real_server():
|
||||
cred = MagicMock()
|
||||
cred.credential_type = "bearer_token"
|
||||
cred.credential_data = {"token": "abc"}
|
||||
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-ok",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=cred,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__secure_mcp__add", "mcp__secure_mcp__echo"]
|
||||
result = await session.call("mcp__secure_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_auth_failure_degrades_not_raises():
|
||||
async with running_mcp_server(
|
||||
required_headers={"Authorization": "Bearer abc"}
|
||||
) as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-auth-fail",
|
||||
tool_name="Secure MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must degrade instead of raising on 401
|
||||
try:
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_start_lists_and_calls_real_server():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
assert session.available is True
|
||||
schemas = session.function_schemas()
|
||||
names = sorted(s.name for s in schemas)
|
||||
assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"]
|
||||
result = await session.call("mcp__acme_mcp__echo", {"text": "hi"})
|
||||
assert "echo:hi" in result
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_tools_filter_applied():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=["echo"],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=20,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
names = sorted(s.name for s in session.function_schemas())
|
||||
assert names == ["mcp__acme_mcp__echo"]
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_unreachable_degrades_not_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start() # must NOT raise
|
||||
assert session.available is False
|
||||
assert session.function_schemas() == []
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_on_unavailable_session_raises():
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=2,
|
||||
sse_read_timeout_secs=2,
|
||||
)
|
||||
await session.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__echo", {"text": "x"})
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_function_raises():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="uuid-1234abcd",
|
||||
tool_name="Acme MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.call("mcp__acme_mcp__does_not_exist", {})
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_schemas_filter_by_raw_name():
|
||||
async with running_mcp_server() as base_url:
|
||||
session = McpToolSession(
|
||||
tool_uuid="t-filter",
|
||||
tool_name="Mock MCP",
|
||||
url=base_url,
|
||||
credential=None,
|
||||
tools_filter=[],
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
await session.start()
|
||||
try:
|
||||
# No arg = all (backward compatible)
|
||||
all_names = sorted(s.name for s in session.function_schemas())
|
||||
assert all_names == ["mcp__mock_mcp__add", "mcp__mock_mcp__echo"]
|
||||
|
||||
# Allow only raw "echo"
|
||||
only_echo = session.function_schemas(allowed_raw_names={"echo"})
|
||||
assert [s.name for s in only_echo] == ["mcp__mock_mcp__echo"]
|
||||
|
||||
# Empty set = none (default-none semantics)
|
||||
assert session.function_schemas(allowed_raw_names=set()) == []
|
||||
|
||||
# Unknown raw name = skipped (pure intersection)
|
||||
assert session.function_schemas(allowed_raw_names={"nope"}) == []
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_success():
|
||||
async with running_mcp_server() as base_url:
|
||||
tools = await discover_mcp_tools(
|
||||
url=base_url,
|
||||
credential=None,
|
||||
timeout_secs=10,
|
||||
sse_read_timeout_secs=10,
|
||||
)
|
||||
names = sorted(t["name"] for t in tools)
|
||||
assert names == ["add", "echo"]
|
||||
by_name = {t["name"]: t for t in tools}
|
||||
assert by_name["echo"]["description"] # non-empty description
|
||||
assert set(by_name["echo"]) == {"name", "description"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_mcp_tools_server_down_returns_empty():
|
||||
# Unroutable port, short timeouts: must degrade to [] (never raise).
|
||||
tools = await discover_mcp_tools(
|
||||
url="http://127.0.0.1:1/mcp",
|
||||
credential=None,
|
||||
timeout_secs=1,
|
||||
sse_read_timeout_secs=1,
|
||||
)
|
||||
assert tools == []
|
||||
|
||||
|
||||
def test_agent_node_data_carries_mcp_tool_filters():
|
||||
from api.services.workflow.dto import AgentNodeData, NodeType
|
||||
from api.services.workflow.workflow_graph import Node
|
||||
|
||||
data = AgentNodeData(
|
||||
name="N1",
|
||||
tool_uuids=["tu-1"],
|
||||
mcp_tool_filters={"tu-1": ["echo"]},
|
||||
)
|
||||
assert data.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
node = Node("n1", NodeType.agentNode, data)
|
||||
assert node.mcp_tool_filters == {"tu-1": ["echo"]}
|
||||
|
||||
# Absent field defaults to None (backward compatible)
|
||||
data2 = AgentNodeData(name="N2")
|
||||
assert data2.mcp_tool_filters is None
|
||||
assert Node("n2", NodeType.agentNode, data2).mcp_tool_filters is None
|
||||
|
|
@ -71,7 +71,8 @@
|
|||
{
|
||||
"group": "Custom Tools",
|
||||
"pages": [
|
||||
"voice-agent/tools/http-api"
|
||||
"voice-agent/tools/http-api",
|
||||
"voice-agent/tools/mcp-tool"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -308,4 +309,4 @@
|
|||
"linkedin": "https://linkedin.com/company/dograh"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ HTTP API tools allow you to attach extrernal REST API calls directly to workflow
|
|||
title="YouTube video player"
|
||||
frameBorder="0"
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
|
||||
referrerPolicy="strict-origin-when-cross-origin"
|
||||
referrerPolicy="strict-origin-when-cross-origin"
|
||||
allowFullScreen
|
||||
></iframe>
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ title: "Tools"
|
|||
description: "Extend your voice agent's capabilities by giving it tools to perform actions during live conversations."
|
||||
---
|
||||
|
||||
Tools let your AI agent take actions during a conversation — transfer calls, end calls, or call external APIs — based on the context of the conversation and your prompt instructions.
|
||||
Tools let your AI agent take actions during a conversation — transfer calls, end calls, call external APIs, or invoke remote MCP servers — based on the context of the conversation and your prompt instructions.
|
||||
|
||||
When a tool is attached to a workflow node, the LLM decides **when** to invoke it and **what parameters** to pass, based on the user's spoken intent and your node-level instructions.
|
||||
|
||||
|
|
@ -23,6 +23,7 @@ Pre-configured tools that handle common telephony operations out of the box:
|
|||
Tools you define to integrate with any external system:
|
||||
|
||||
- [**HTTP API**](/voice-agent/tools/http-api) — Call any REST API endpoint during a conversation (e.g., CRM updates, data lookups, triggering automations)
|
||||
- [**MCP Tool**](/voice-agent/tools/mcp-tool) — Connect an external MCP server and expose its remote tools to the LLM during a conversation
|
||||
|
||||
## How Tools Work
|
||||
|
||||
|
|
|
|||
90
docs/voice-agent/tools/mcp-tool.mdx
Normal file
90
docs/voice-agent/tools/mcp-tool.mdx
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
---
|
||||
title: "MCP Tool"
|
||||
description: "Connect an external MCP server to your voice agent so the LLM can call remote tools during a live conversation."
|
||||
---
|
||||
|
||||
<Note>
|
||||
This page is about using an external MCP server as a **tool inside a voice agent**. If you want Claude, Cursor, or another coding agent to control Dograh itself over MCP, see [Dograh MCP Server](/integrations/mcp).
|
||||
</Note>
|
||||
|
||||
MCP tools let a Dograh voice agent call tools exposed by a remote [Model Context Protocol](https://modelcontextprotocol.io/) server during a live conversation. Dograh discovers the remote tool catalog, turns those tools into LLM-callable functions, and forwards invocations to the MCP server over authenticated Streamable HTTP.
|
||||
|
||||
## What to configure
|
||||
|
||||
An MCP tool in Dograh has four important pieces:
|
||||
|
||||
- **Name**: the server label shown in Dograh
|
||||
- **Description**: tells the LLM when this MCP server is relevant
|
||||
- **URL**: the remote MCP server endpoint (`http://` or `https://`)
|
||||
- **Credential**: the auth Dograh should send when connecting to that server
|
||||
|
||||
You can also set a **tool filter** to allow only specific remote MCP tools to be exposed.
|
||||
|
||||
## Authentication
|
||||
|
||||
Most hosted MCP servers expect:
|
||||
|
||||
```http
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
So before creating the MCP tool, create a credential in Dograh with:
|
||||
|
||||
- **Credential Type**: `Bearer Token`
|
||||
- **Token**: the access token issued by your MCP server
|
||||
|
||||
Then select that credential on the MCP tool.
|
||||
|
||||
<Note>
|
||||
If the remote MCP server documentation says to use Bearer auth, choose **Bearer Token** in the credential dialog. Dograh will translate that into the exact `Authorization: Bearer <token>` header on the MCP connection.
|
||||
</Note>
|
||||
|
||||
Dograh also supports other credential types, but **Bearer Token** is the default thing to try for third-party MCP servers unless their docs say otherwise.
|
||||
|
||||
## How it works
|
||||
|
||||
The runtime path is:
|
||||
|
||||
1. When you save or refresh the MCP tool, Dograh opens a short-lived authenticated MCP session and fetches the remote tool catalog.
|
||||
2. Dograh stores that catalog as the tool's discovered tools so the UI can show you which remote functions exist.
|
||||
3. When a call starts, Dograh opens one live MCP session per attached MCP server and reuses your selected credential for that session.
|
||||
4. For each node, Dograh exposes only the MCP tools allowed by the server-level filter and the node-level selection.
|
||||
5. Dograh namespaces those remote tools into ordinary LLM function definitions so they can safely coexist with HTTP API tools, call transfer, end call, and other tools.
|
||||
6. During the conversation, the LLM sees only the tool name, description, and argument schema. It does **not** see the secret.
|
||||
7. When the LLM calls one of those tools, Dograh forwards the invocation to the MCP server over the active authenticated session, receives the result, and feeds that result back into the agent turn.
|
||||
|
||||
In short: **Dograh handles discovery, authentication, session management, tool registration, and result plumbing; the LLM only decides when to call the tool and with which arguments.**
|
||||
|
||||
## Creating an MCP tool
|
||||
|
||||
1. Go to **Tools** and create a new tool.
|
||||
2. Choose **MCP Server**.
|
||||
3. Enter a clear name and a description that explains when the LLM should use this server.
|
||||
4. Paste the MCP server URL.
|
||||
5. Select the credential. In most cases this should be a **Bearer Token** credential.
|
||||
6. Save the tool and confirm that Dograh discovered the remote tools.
|
||||
|
||||
If the server exposes many tools, use filtering to keep only the ones your agent actually needs.
|
||||
|
||||
## Attaching it to a node
|
||||
|
||||
After the MCP tool is created:
|
||||
|
||||
1. Open the workflow node where the tool should be available.
|
||||
2. Add the MCP tool from the node's tool list.
|
||||
3. Select only the remote MCP functions that should be callable on that node.
|
||||
4. In the node prompt, tell the LLM exactly when to use those functions.
|
||||
|
||||
The tighter the node-level selection and prompt guidance, the more reliable MCP tool usage becomes.
|
||||
|
||||
## Best practices
|
||||
|
||||
- Use one MCP server per logical integration when possible.
|
||||
- Keep the tool description explicit about **when** the LLM should use that server.
|
||||
- Expose only the minimum remote functions needed for each node.
|
||||
- Prefer a **Bearer Token** credential unless the MCP server specifies another auth scheme.
|
||||
- Test discovery first, then test a real phone/web call to confirm the LLM invokes the right MCP function with the right arguments.
|
||||
|
||||
<Note>
|
||||
If the remote MCP server is temporarily unavailable, Dograh degrades gracefully and the call can continue without those MCP tools rather than crashing the entire conversation.
|
||||
</Note>
|
||||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 13e98d0d94aa5db3185e36ba411bae0ffb443b7b
|
||||
Subproject commit ce4ee2d6fc0d0982ad03a1468a517cc5e568aaa9
|
||||
|
|
@ -51,7 +51,7 @@ fi
|
|||
|
||||
# Install pipecat in editable mode with all extras
|
||||
echo "Installing pipecat dependencies..."
|
||||
pip install -e ./pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]
|
||||
pip install -e ./pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp]
|
||||
|
||||
if [ "$DEV_MODE" -eq 1 ]; then
|
||||
echo "Installing pipecat dev dependencies..."
|
||||
|
|
|
|||
|
|
@ -9,16 +9,25 @@ import {
|
|||
listRecordingsApiV1WorkflowRecordingsGet,
|
||||
updateToolApiV1ToolsToolUuidPut,
|
||||
} from "@/client/sdk.gen";
|
||||
import type { RecordingResponseSchema, ToolResponse, TransferCallConfig as APITransferCallConfig } from "@/client/types.gen";
|
||||
import type { EndCallConfig } from "@/client/types.gen";
|
||||
import type {
|
||||
EndCallConfig,
|
||||
HttpApiToolDefinition,
|
||||
RecordingResponseSchema,
|
||||
ToolResponse,
|
||||
TransferCallConfig as APITransferCallConfig,
|
||||
UpdateToolRequest,
|
||||
} from "@/client/types.gen";
|
||||
import {
|
||||
CredentialSelector,
|
||||
type HttpMethod,
|
||||
type KeyValueItem,
|
||||
type ParameterType,
|
||||
type PresetToolParameter,
|
||||
type ToolParameter,
|
||||
validateUrl,
|
||||
} from "@/components/http";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
|
|
@ -26,35 +35,33 @@ import {
|
|||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { TOOL_DOCUMENTATION_URLS } from "@/constants/documentation";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
import {
|
||||
createMcpDefinition,
|
||||
DEFAULT_END_CALL_REASON_DESCRIPTION,
|
||||
type EndCallMessageType,
|
||||
getCategoryConfig,
|
||||
getToolTypeLabel,
|
||||
MCP_URL_PATTERN,
|
||||
renderToolIcon,
|
||||
type ToolCategory,
|
||||
} from "../config";
|
||||
import { BuiltinToolConfig, EndCallToolConfig, HttpApiToolConfig, TransferCallToolConfig } from "./components";
|
||||
|
||||
// Extended HttpApiConfig with parameters (until client types are regenerated)
|
||||
interface HttpApiConfigWithParams {
|
||||
method?: string;
|
||||
url?: string;
|
||||
headers?: Record<string, string>;
|
||||
credential_uuid?: string;
|
||||
parameters?: ToolParameter[];
|
||||
preset_parameters?: Array<{
|
||||
name?: string;
|
||||
type?: PresetToolParameter["type"];
|
||||
value_template?: string;
|
||||
required?: boolean;
|
||||
}>;
|
||||
timeout_ms?: number;
|
||||
customMessage?: string;
|
||||
function normalizeParameterType(value: string | null | undefined): ParameterType {
|
||||
switch (value) {
|
||||
case "number":
|
||||
case "boolean":
|
||||
return value;
|
||||
default:
|
||||
return "string";
|
||||
}
|
||||
}
|
||||
|
||||
export default function ToolDetailPage() {
|
||||
|
|
@ -108,6 +115,11 @@ export default function ToolDetailPage() {
|
|||
const [customMessageType, setCustomMessageType] = useState<'text' | 'audio'>('text');
|
||||
const [customMessageRecordingId, setCustomMessageRecordingId] = useState("");
|
||||
|
||||
// MCP form state
|
||||
const [mcpUrl, setMcpUrl] = useState("");
|
||||
const [mcpCredentialUuid, setMcpCredentialUuid] = useState("");
|
||||
const [mcpToolsFilter, setMcpToolsFilter] = useState("");
|
||||
|
||||
// Org-level recordings for audio dropdowns
|
||||
const [recordings, setRecordings] = useState<RecordingResponseSchema[]>([]);
|
||||
|
||||
|
|
@ -155,8 +167,7 @@ export default function ToolDetailPage() {
|
|||
if (config) {
|
||||
setEndCallMessageType(config.messageType || "none");
|
||||
setCustomMessage(config.customMessage || "");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
setAudioRecordingId((config as any).audioRecordingId || "");
|
||||
setAudioRecordingId(config.audioRecordingId || "");
|
||||
setEndCallReason(config.endCallReason ?? false);
|
||||
setEndCallReasonDescription(config.endCallReasonDescription || "");
|
||||
} else {
|
||||
|
|
@ -173,8 +184,7 @@ export default function ToolDetailPage() {
|
|||
setTransferDestination(config.destination || "");
|
||||
setTransferMessageType(config.messageType || "none");
|
||||
setCustomMessage(config.customMessage || "");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
setTransferAudioRecordingId((config as any).audioRecordingId || "");
|
||||
setTransferAudioRecordingId(config.audioRecordingId || "");
|
||||
setTransferTimeout(config.timeout ?? 30);
|
||||
} else {
|
||||
setTransferDestination("");
|
||||
|
|
@ -183,19 +193,35 @@ export default function ToolDetailPage() {
|
|||
setTransferAudioRecordingId("");
|
||||
setTransferTimeout(30);
|
||||
}
|
||||
} else if (tool.category === "mcp") {
|
||||
// Populate MCP specific fields
|
||||
const config = tool.definition?.config as
|
||||
| { url?: string; credential_uuid?: string | null; tools_filter?: string[] }
|
||||
| undefined;
|
||||
if (config) {
|
||||
setMcpUrl(config.url || "");
|
||||
setMcpCredentialUuid(config.credential_uuid || "");
|
||||
setMcpToolsFilter(
|
||||
Array.isArray(config.tools_filter)
|
||||
? config.tools_filter.join(", ")
|
||||
: ""
|
||||
);
|
||||
} else {
|
||||
setMcpUrl("");
|
||||
setMcpCredentialUuid("");
|
||||
setMcpToolsFilter("");
|
||||
}
|
||||
} else {
|
||||
// Populate HTTP API specific fields
|
||||
const config = tool.definition?.config as HttpApiConfigWithParams | undefined;
|
||||
const config = tool.definition?.config as HttpApiToolDefinition["config"] | undefined;
|
||||
if (config) {
|
||||
setHttpMethod((config.method as HttpMethod) || "POST");
|
||||
setUrl(config.url || "");
|
||||
setCredentialUuid(config.credential_uuid || "");
|
||||
setTimeoutMs(config.timeout_ms || 5000);
|
||||
setCustomMessage(config.customMessage || "");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
setCustomMessageType((config as any).customMessageType || "text");
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
setCustomMessageRecordingId((config as any).customMessageRecordingId || "");
|
||||
setCustomMessageType(config.customMessageType || "text");
|
||||
setCustomMessageRecordingId(config.customMessageRecordingId || "");
|
||||
|
||||
// Convert headers object to array
|
||||
if (config.headers) {
|
||||
|
|
@ -212,9 +238,9 @@ export default function ToolDetailPage() {
|
|||
// Load parameters
|
||||
if (config.parameters && Array.isArray(config.parameters)) {
|
||||
setParameters(
|
||||
config.parameters.map((p: ToolParameter) => ({
|
||||
config.parameters.map((p) => ({
|
||||
name: p.name || "",
|
||||
type: p.type || "string",
|
||||
type: normalizeParameterType(p.type),
|
||||
description: p.description || "",
|
||||
required: p.required ?? true,
|
||||
}))
|
||||
|
|
@ -227,7 +253,7 @@ export default function ToolDetailPage() {
|
|||
setPresetParameters(
|
||||
config.preset_parameters.map((p) => ({
|
||||
name: p.name || "",
|
||||
type: p.type || "string",
|
||||
type: normalizeParameterType(p.type),
|
||||
valueTemplate: p.value_template || "",
|
||||
required: p.required ?? true,
|
||||
}))
|
||||
|
|
@ -275,6 +301,16 @@ export default function ToolDetailPage() {
|
|||
setError("Please enter a valid phone number (E.164 format) or SIP endpoint (e.g., PJSIP/1234)");
|
||||
return;
|
||||
}
|
||||
} else if (tool.category === "mcp") {
|
||||
// Validate MCP server URL (must be http(s))
|
||||
if (!mcpUrl.trim()) {
|
||||
setError("Please enter the MCP server URL");
|
||||
return;
|
||||
}
|
||||
if (!MCP_URL_PATTERN.test(mcpUrl.trim())) {
|
||||
setError("MCP server URL must start with http:// or https://");
|
||||
return;
|
||||
}
|
||||
} else if (tool.category !== "end_call") {
|
||||
// Validate URL for HTTP API tools
|
||||
const urlValidation = validateUrl(url);
|
||||
|
|
@ -305,7 +341,7 @@ export default function ToolDetailPage() {
|
|||
setSaveSuccess(false);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
let requestBody;
|
||||
let requestBody: UpdateToolRequest;
|
||||
|
||||
if (tool.category === "calculator") {
|
||||
// Built-in tool - only name/description, no config
|
||||
|
|
@ -351,6 +387,12 @@ export default function ToolDetailPage() {
|
|||
},
|
||||
},
|
||||
};
|
||||
} else if (tool.category === "mcp") {
|
||||
requestBody = {
|
||||
name,
|
||||
description: description || undefined,
|
||||
definition: createMcpDefinition(mcpUrl, mcpCredentialUuid, mcpToolsFilter),
|
||||
};
|
||||
} else {
|
||||
// Build HTTP API request body
|
||||
const headersObject: Record<string, string> = {};
|
||||
|
|
@ -399,8 +441,7 @@ export default function ToolDetailPage() {
|
|||
|
||||
const response = await updateToolApiV1ToolsToolUuidPut({
|
||||
path: { tool_uuid: toolUuid },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
body: requestBody as any,
|
||||
body: requestBody,
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
|
|
@ -510,6 +551,7 @@ const data = await response.json();`;
|
|||
const isEndCallTool = tool.category === "end_call";
|
||||
const isTransferCallTool = tool.category === "transfer_call";
|
||||
const isBuiltinTool = tool.category === "calculator";
|
||||
const isMcpTool = tool.category === "mcp";
|
||||
const categoryConfig = getCategoryConfig(tool.category as ToolCategory);
|
||||
|
||||
return (
|
||||
|
|
@ -545,7 +587,7 @@ const data = await response.json();`;
|
|||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{!isEndCallTool && !isTransferCallTool && !isBuiltinTool && (
|
||||
{!isEndCallTool && !isTransferCallTool && !isBuiltinTool && !isMcpTool && (
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setShowCodeDialog(true)}
|
||||
|
|
@ -613,6 +655,79 @@ const data = await response.json();`;
|
|||
timeout={transferTimeout}
|
||||
onTimeoutChange={setTransferTimeout}
|
||||
/>
|
||||
) : isMcpTool ? (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>MCP Server Configuration</CardTitle>
|
||||
<CardDescription>
|
||||
Configure the MCP server endpoint. Its tools become available to the agent.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-6">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="mcp-name">Tool Name</Label>
|
||||
<Input
|
||||
id="mcp-name"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="e.g., Customer MCP Server"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="mcp-description">Description</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Provide a description which makes it easy for LLM to understand what this tool does
|
||||
</p>
|
||||
<Textarea
|
||||
id="mcp-description"
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="What does this MCP server provide?"
|
||||
rows={3}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="mcp-url">MCP Server URL</Label>
|
||||
<Input
|
||||
id="mcp-url"
|
||||
value={mcpUrl}
|
||||
onChange={(e) => setMcpUrl(e.target.value)}
|
||||
placeholder="https://your-mcp-server.example.com/mcp"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Transport</Label>
|
||||
<Input
|
||||
value="Streamable HTTP"
|
||||
disabled
|
||||
readOnly
|
||||
/>
|
||||
</div>
|
||||
|
||||
<CredentialSelector
|
||||
value={mcpCredentialUuid}
|
||||
onChange={setMcpCredentialUuid}
|
||||
label="Credential (Optional)"
|
||||
description="Select a credential for authenticating with the MCP server, or leave empty for no auth."
|
||||
/>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="mcp-tools-filter">Tools Filter (Optional)</Label>
|
||||
<Input
|
||||
id="mcp-tools-filter"
|
||||
value={mcpToolsFilter}
|
||||
onChange={(e) => setMcpToolsFilter(e.target.value)}
|
||||
placeholder="e.g., tool_one, tool_two"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Comma-separated list of tool names to allow. Leave empty to expose all tools from the server.
|
||||
</p>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
<HttpApiToolConfig
|
||||
name={name}
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ import type {
|
|||
EndCallConfig,
|
||||
EndCallToolDefinition,
|
||||
HttpApiToolDefinition,
|
||||
McpToolDefinition,
|
||||
TransferCallConfig,
|
||||
TransferCallToolDefinition,
|
||||
} from "@/client/types.gen";
|
||||
|
||||
export type ToolCategory = "http_api" | "end_call" | "transfer_call" | "calculator" | "native" | "integration";
|
||||
export type ToolCategory = "http_api" | "end_call" | "transfer_call" | "calculator" | "native" | "integration" | "mcp";
|
||||
|
||||
export type EndCallMessageType = "none" | "custom" | "audio";
|
||||
|
||||
|
|
@ -75,6 +76,14 @@ export const TOOL_CATEGORIES: ToolCategoryConfig[] = [
|
|||
description: "Perform arithmetic calculations (supports +, -, *, /, **, %, and parentheses)",
|
||||
},
|
||||
},
|
||||
{
|
||||
value: "mcp",
|
||||
label: "MCP Server",
|
||||
description: "Connect a customer MCP server; its tools become available to the agent",
|
||||
icon: Puzzle,
|
||||
iconName: "puzzle",
|
||||
iconColor: "#8B5CF6",
|
||||
},
|
||||
{
|
||||
value: "native",
|
||||
label: "Native (Coming Soon)",
|
||||
|
|
@ -128,6 +137,8 @@ export function getToolTypeLabel(category: string): string {
|
|||
return "Native Tool";
|
||||
case "integration":
|
||||
return "Integration Tool";
|
||||
case "mcp":
|
||||
return "MCP Server Tool";
|
||||
default:
|
||||
return "Tool";
|
||||
}
|
||||
|
|
@ -149,7 +160,12 @@ export const DEFAULT_TRANSFER_CALL_CONFIG: TransferCallConfig = {
|
|||
timeout: 30,
|
||||
};
|
||||
|
||||
export type ToolDefinition = HttpApiToolDefinition | EndCallToolDefinition | TransferCallToolDefinition | CalculatorToolDefinition;
|
||||
export type ToolDefinition =
|
||||
| HttpApiToolDefinition
|
||||
| EndCallToolDefinition
|
||||
| TransferCallToolDefinition
|
||||
| CalculatorToolDefinition
|
||||
| McpToolDefinition;
|
||||
|
||||
export function createEndCallDefinition(config: EndCallConfig): EndCallToolDefinition {
|
||||
return {
|
||||
|
|
@ -185,6 +201,28 @@ export function createCalculatorDefinition(): CalculatorToolDefinition {
|
|||
};
|
||||
}
|
||||
|
||||
export const MCP_URL_PATTERN = /^https?:\/\//i;
|
||||
|
||||
export function createMcpDefinition(
|
||||
url: string,
|
||||
credentialUuid: string,
|
||||
toolsFilterCsv: string,
|
||||
): McpToolDefinition {
|
||||
return {
|
||||
schema_version: 1,
|
||||
type: "mcp" as const,
|
||||
config: {
|
||||
transport: "streamable_http" as const,
|
||||
url: url.trim(),
|
||||
credential_uuid: credentialUuid || null,
|
||||
tools_filter: toolsFilterCsv
|
||||
.split(",")
|
||||
.map((s) => s.trim())
|
||||
.filter((s) => s.length > 0),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function createToolDefinition(category: ToolCategory): ToolDefinition {
|
||||
switch (category) {
|
||||
case "end_call":
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ import {
|
|||
listToolsApiV1ToolsGet,
|
||||
unarchiveToolApiV1ToolsToolUuidUnarchivePost,
|
||||
} from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import type { CreateToolRequest, ToolResponse } from "@/client/types.gen";
|
||||
import { CredentialSelector } from "@/components/http";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
|
|
@ -41,8 +42,10 @@ import { Skeleton } from "@/components/ui/skeleton";
|
|||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
import {
|
||||
createMcpDefinition,
|
||||
createToolDefinition,
|
||||
getCategoryConfig,
|
||||
MCP_URL_PATTERN,
|
||||
renderToolIcon,
|
||||
TOOL_CATEGORIES,
|
||||
type ToolCategory,
|
||||
|
|
@ -63,6 +66,11 @@ export default function ToolsPage() {
|
|||
const [error, setError] = useState<string | null>(null);
|
||||
const [createError, setCreateError] = useState<string | null>(null);
|
||||
|
||||
// MCP-specific create dialog state
|
||||
const [mcpUrl, setMcpUrl] = useState("");
|
||||
const [mcpCredentialUuid, setMcpCredentialUuid] = useState("");
|
||||
const [mcpToolsFilter, setMcpToolsFilter] = useState("");
|
||||
|
||||
// Redirect if not authenticated
|
||||
useEffect(() => {
|
||||
if (!loading && !user) {
|
||||
|
|
@ -108,21 +116,38 @@ export default function ToolsPage() {
|
|||
return;
|
||||
}
|
||||
|
||||
if (newToolCategory === "mcp" && !mcpUrl.trim()) {
|
||||
setCreateError("Please enter the MCP server URL");
|
||||
return;
|
||||
}
|
||||
|
||||
if (newToolCategory === "mcp" && !MCP_URL_PATTERN.test(mcpUrl.trim())) {
|
||||
setCreateError("MCP server URL must start with http:// or https://");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setIsCreating(true);
|
||||
setCreateError(null);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
const categoryConfig = getCategoryConfig(newToolCategory);
|
||||
|
||||
const definition = newToolCategory === "mcp"
|
||||
? createMcpDefinition(mcpUrl, mcpCredentialUuid, mcpToolsFilter)
|
||||
: createToolDefinition(newToolCategory);
|
||||
|
||||
const requestBody: CreateToolRequest = {
|
||||
name: newToolName,
|
||||
description: newToolDescription || undefined,
|
||||
category: newToolCategory,
|
||||
icon: categoryConfig?.iconName || "globe",
|
||||
icon_color: categoryConfig?.iconColor || "#3B82F6",
|
||||
definition,
|
||||
};
|
||||
|
||||
const response = await createToolApiV1ToolsPost({
|
||||
body: {
|
||||
name: newToolName,
|
||||
description: newToolDescription || undefined,
|
||||
category: newToolCategory,
|
||||
icon: categoryConfig?.iconName || "globe",
|
||||
icon_color: categoryConfig?.iconColor || "#3B82F6",
|
||||
definition: createToolDefinition(newToolCategory),
|
||||
},
|
||||
body: requestBody,
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
|
|
@ -139,6 +164,9 @@ export default function ToolsPage() {
|
|||
setNewToolName("");
|
||||
setNewToolDescription("");
|
||||
setNewToolCategory("http_api");
|
||||
setMcpUrl("");
|
||||
setMcpCredentialUuid("");
|
||||
setMcpToolsFilter("");
|
||||
// Navigate to the new tool's detail page
|
||||
router.push(`/tools/${response.data.tool_uuid}`);
|
||||
}
|
||||
|
|
@ -233,6 +261,8 @@ export default function ToolsPage() {
|
|||
return <Badge variant="secondary">Native</Badge>;
|
||||
case "integration":
|
||||
return <Badge variant="outline">Integration</Badge>;
|
||||
case "mcp":
|
||||
return <Badge variant="outline">MCP</Badge>;
|
||||
default:
|
||||
return <Badge variant="outline">{category}</Badge>;
|
||||
}
|
||||
|
|
@ -465,7 +495,14 @@ export default function ToolsPage() {
|
|||
{/* Create Tool Dialog */}
|
||||
<Dialog open={isCreateDialogOpen} onOpenChange={(open) => {
|
||||
setIsCreateDialogOpen(open);
|
||||
if (open) setCreateError(null);
|
||||
if (open) {
|
||||
setCreateError(null);
|
||||
} else {
|
||||
// Reset MCP fields when dialog is closed without creating
|
||||
setMcpUrl("");
|
||||
setMcpCredentialUuid("");
|
||||
setMcpToolsFilter("");
|
||||
}
|
||||
}}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
|
|
@ -482,6 +519,7 @@ export default function ToolsPage() {
|
|||
onValueChange={(v) => {
|
||||
const category = v as ToolCategory;
|
||||
setNewToolCategory(category);
|
||||
setCreateError(null);
|
||||
const categoryConfig = getCategoryConfig(category);
|
||||
if (categoryConfig?.autoFill) {
|
||||
setNewToolName(categoryConfig.autoFill.name);
|
||||
|
|
@ -532,6 +570,46 @@ export default function ToolsPage() {
|
|||
placeholder="What does this tool do?"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{newToolCategory === "mcp" && (
|
||||
<>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="mcp-url">MCP Server URL</Label>
|
||||
<Input
|
||||
id="mcp-url"
|
||||
value={mcpUrl}
|
||||
onChange={(e) => setMcpUrl(e.target.value)}
|
||||
placeholder="https://your-mcp-server.example.com/mcp"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label>Transport</Label>
|
||||
<Input
|
||||
value="Streamable HTTP"
|
||||
disabled
|
||||
readOnly
|
||||
/>
|
||||
</div>
|
||||
<CredentialSelector
|
||||
value={mcpCredentialUuid}
|
||||
onChange={setMcpCredentialUuid}
|
||||
label="Credential (Optional)"
|
||||
description="Select a credential for authenticating with the MCP server, or leave empty for no auth."
|
||||
/>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="mcp-tools-filter">Tools Filter (Optional)</Label>
|
||||
<Input
|
||||
id="mcp-tools-filter"
|
||||
value={mcpToolsFilter}
|
||||
onChange={(e) => setMcpToolsFilter(e.target.value)}
|
||||
placeholder="e.g., tool_one, tool_two"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Comma-separated list of tool names to allow. Leave empty to expose all tools from the server.
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
{createError && (
|
||||
<div className="p-3 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive text-sm">
|
||||
|
|
|
|||
|
|
@ -325,14 +325,33 @@ function RenderWorkflow({ initialWorkflowName, workflowId, workflowUuid, initial
|
|||
await saveWorkflowConfigurations(workflowConfigurations, newName);
|
||||
}, [saveWorkflowConfigurations, workflowConfigurations]);
|
||||
|
||||
const updateTool = useCallback(
|
||||
(toolUuid: string, updater: (tool: ToolResponse) => ToolResponse) => {
|
||||
setTools((prev) =>
|
||||
prev?.map((tool) =>
|
||||
tool.tool_uuid === toolUuid ? updater(tool) : tool,
|
||||
),
|
||||
);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Memoize the context value to prevent unnecessary re-renders
|
||||
const workflowContextValue = useMemo(() => ({
|
||||
saveWorkflow: guardedSaveWorkflow,
|
||||
documents,
|
||||
tools,
|
||||
updateTool,
|
||||
recordings,
|
||||
readOnly: isViewingHistoricalVersion,
|
||||
}), [guardedSaveWorkflow, documents, tools, recordings, isViewingHistoricalVersion]);
|
||||
}), [
|
||||
guardedSaveWorkflow,
|
||||
documents,
|
||||
tools,
|
||||
updateTool,
|
||||
recordings,
|
||||
isViewingHistoricalVersion,
|
||||
]);
|
||||
|
||||
return (
|
||||
<WorkflowProvider value={workflowContextValue}>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@ interface WorkflowContextType {
|
|||
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
|
||||
documents?: DocumentResponseSchema[];
|
||||
tools?: ToolResponse[];
|
||||
updateTool?: (
|
||||
toolUuid: string,
|
||||
updater: (tool: ToolResponse) => ToolResponse,
|
||||
) => void;
|
||||
recordings?: RecordingResponseSchema[];
|
||||
readOnly?: boolean;
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -949,7 +949,9 @@ export type CreateToolRequest = {
|
|||
type: 'transfer_call';
|
||||
} & TransferCallToolDefinition) | ({
|
||||
type: 'calculator';
|
||||
} & CalculatorToolDefinition);
|
||||
} & CalculatorToolDefinition) | ({
|
||||
type: 'mcp';
|
||||
} & McpToolDefinition);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -2097,6 +2099,102 @@ export type MpsCreditsResponse = {
|
|||
total_quota: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* McpRefreshResponse
|
||||
*
|
||||
* Result of re-discovering an MCP server's tool catalog.
|
||||
*/
|
||||
export type McpRefreshResponse = {
|
||||
/**
|
||||
* Tool Uuid
|
||||
*/
|
||||
tool_uuid: string;
|
||||
/**
|
||||
* Discovered Tools
|
||||
*/
|
||||
discovered_tools?: Array<unknown>;
|
||||
/**
|
||||
* Error
|
||||
*/
|
||||
error?: string | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* McpToolConfig
|
||||
*
|
||||
* Configuration for an MCP tool definition.
|
||||
*/
|
||||
export type McpToolConfig = {
|
||||
/**
|
||||
* Transport
|
||||
*
|
||||
* MCP transport protocol
|
||||
*/
|
||||
transport?: 'streamable_http';
|
||||
/**
|
||||
* Url
|
||||
*
|
||||
* MCP server URL (must be http:// or https://)
|
||||
*/
|
||||
url: string;
|
||||
/**
|
||||
* Credential Uuid
|
||||
*
|
||||
* Reference to ExternalCredentialModel for auth
|
||||
*/
|
||||
credential_uuid?: string | null;
|
||||
/**
|
||||
* Tools Filter
|
||||
*
|
||||
* Allowlist of MCP tool names to expose (empty = all tools)
|
||||
*/
|
||||
tools_filter?: Array<string>;
|
||||
/**
|
||||
* Timeout Secs
|
||||
*
|
||||
* Connection timeout in seconds
|
||||
*/
|
||||
timeout_secs?: number;
|
||||
/**
|
||||
* Sse Read Timeout Secs
|
||||
*
|
||||
* SSE read timeout in seconds
|
||||
*/
|
||||
sse_read_timeout_secs?: number;
|
||||
/**
|
||||
* Discovered Tools
|
||||
*
|
||||
* Server-managed cache of the MCP server's tool catalog [{name, description}]. Populated best-effort by the backend.
|
||||
*/
|
||||
discovered_tools?: Array<{
|
||||
[key: string]: unknown;
|
||||
}>;
|
||||
};
|
||||
|
||||
/**
|
||||
* McpToolDefinition
|
||||
*
|
||||
* Persisted MCP tool definition.
|
||||
*/
|
||||
export type McpToolDefinition = {
|
||||
/**
|
||||
* Schema Version
|
||||
*
|
||||
* Schema version
|
||||
*/
|
||||
schema_version?: number;
|
||||
/**
|
||||
* Type
|
||||
*
|
||||
* Tool type
|
||||
*/
|
||||
type: 'mcp';
|
||||
/**
|
||||
* MCP server configuration
|
||||
*/
|
||||
config: McpToolConfig;
|
||||
};
|
||||
|
||||
/**
|
||||
* NodeCategory
|
||||
*
|
||||
|
|
@ -3842,7 +3940,9 @@ export type UpdateToolRequest = {
|
|||
type: 'transfer_call';
|
||||
} & TransferCallToolDefinition) | ({
|
||||
type: 'calculator';
|
||||
} & CalculatorToolDefinition) | null;
|
||||
} & CalculatorToolDefinition) | ({
|
||||
type: 'mcp';
|
||||
} & McpToolDefinition) | null;
|
||||
/**
|
||||
* Status
|
||||
*/
|
||||
|
|
@ -7652,6 +7752,50 @@ export type UpdateToolApiV1ToolsToolUuidPutResponses = {
|
|||
|
||||
export type UpdateToolApiV1ToolsToolUuidPutResponse = UpdateToolApiV1ToolsToolUuidPutResponses[keyof UpdateToolApiV1ToolsToolUuidPutResponses];
|
||||
|
||||
export type RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
/**
|
||||
* Authorization
|
||||
*/
|
||||
authorization?: string | null;
|
||||
/**
|
||||
* X-Api-Key
|
||||
*/
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path: {
|
||||
/**
|
||||
* Tool Uuid
|
||||
*/
|
||||
tool_uuid: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/tools/{tool_uuid}/mcp/refresh';
|
||||
};
|
||||
|
||||
export type RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostError = RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostErrors[keyof RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostErrors];
|
||||
|
||||
export type RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: McpRefreshResponse;
|
||||
};
|
||||
|
||||
export type RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostResponse = RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostResponses[keyof RefreshMcpToolsApiV1ToolsToolUuidMcpRefreshPostResponses];
|
||||
|
||||
export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@ import { Badge } from "@/components/ui/badge";
|
|||
interface ToolBadgesProps {
|
||||
toolUuids: string[];
|
||||
onStaleUuidsDetected?: (staleUuids: string[]) => void;
|
||||
mcpToolFilters?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
export function ToolBadges({ toolUuids, onStaleUuidsDetected }: ToolBadgesProps) {
|
||||
export function ToolBadges({ toolUuids, onStaleUuidsDetected, mcpToolFilters }: ToolBadgesProps) {
|
||||
const { tools } = useWorkflow();
|
||||
const [selectedTools, setSelectedTools] = useState<ToolResponse[]>([]);
|
||||
|
||||
|
|
@ -50,15 +51,29 @@ export function ToolBadges({ toolUuids, onStaleUuidsDetected }: ToolBadgesProps)
|
|||
|
||||
return (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedTools.map((tool) => (
|
||||
<Badge
|
||||
key={tool.tool_uuid}
|
||||
variant="outline"
|
||||
className="text-xs"
|
||||
>
|
||||
{tool.name}
|
||||
</Badge>
|
||||
))}
|
||||
{selectedTools.map((tool) => {
|
||||
const isMcp = tool.category === "mcp";
|
||||
const enabledFns = isMcp ? (mcpToolFilters?.[tool.tool_uuid] ?? []) : [];
|
||||
|
||||
if (isMcp && enabledFns.length > 0) {
|
||||
return enabledFns.map((fn) => (
|
||||
<Badge
|
||||
key={`${tool.tool_uuid}-${fn}`}
|
||||
variant="outline"
|
||||
className="text-xs flex items-center gap-1.5"
|
||||
>
|
||||
<span className="h-1.5 w-1.5 rounded-full bg-green-500 shrink-0" />
|
||||
{fn}
|
||||
</Badge>
|
||||
));
|
||||
}
|
||||
|
||||
return (
|
||||
<Badge key={tool.tool_uuid} variant="outline" className="text-xs">
|
||||
{tool.name}
|
||||
</Badge>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
"use client";
|
||||
|
||||
import { ExternalLink } from "lucide-react";
|
||||
import { ExternalLink, RefreshCw } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
|
||||
import { renderToolIcon } from "@/app/tools/config";
|
||||
import { useWorkflowOptional } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { TOOLS_INTRODUCTION_DOC_URL } from "@/constants/documentation";
|
||||
|
||||
import { type McpDiscoveredTool, refreshMcpTools } from "./mcpRefresh";
|
||||
|
||||
interface ToolSelectorProps {
|
||||
value: string[];
|
||||
onChange: (uuids: string[]) => void;
|
||||
|
|
@ -18,6 +23,46 @@ interface ToolSelectorProps {
|
|||
label?: string;
|
||||
description?: string;
|
||||
showLabel?: boolean;
|
||||
mcpToolFilters?: Record<string, string[]>;
|
||||
onMcpToolFiltersChange?: (next: Record<string, string[]>) => void;
|
||||
}
|
||||
|
||||
function isMcp(tool: ToolResponse): boolean {
|
||||
return tool.category === "mcp";
|
||||
}
|
||||
|
||||
function discoveredOf(tool: ToolResponse): McpDiscoveredTool[] {
|
||||
const def = (tool.definition ?? {}) as {
|
||||
config?: { discovered_tools?: McpDiscoveredTool[] };
|
||||
};
|
||||
return def.config?.discovered_tools ?? [];
|
||||
}
|
||||
|
||||
function withDiscoveredTools(
|
||||
tool: ToolResponse,
|
||||
discoveredTools: McpDiscoveredTool[],
|
||||
): ToolResponse {
|
||||
const definition =
|
||||
tool.definition && typeof tool.definition === "object"
|
||||
? tool.definition
|
||||
: {};
|
||||
const config =
|
||||
"config" in definition &&
|
||||
definition.config &&
|
||||
typeof definition.config === "object"
|
||||
? definition.config
|
||||
: {};
|
||||
|
||||
return {
|
||||
...tool,
|
||||
definition: {
|
||||
...definition,
|
||||
config: {
|
||||
...config,
|
||||
discovered_tools: discoveredTools,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function ToolSelector({
|
||||
|
|
@ -28,18 +73,64 @@ export function ToolSelector({
|
|||
label = "Tools",
|
||||
description = "Select tools that the agent can use during the conversation.",
|
||||
showLabel = true,
|
||||
mcpToolFilters = {},
|
||||
onMcpToolFiltersChange = () => {},
|
||||
}: ToolSelectorProps) {
|
||||
// Filter to only show active tools
|
||||
const activeTools = tools.filter((tool) => tool.status === "active");
|
||||
const workflow = useWorkflowOptional();
|
||||
const activeTools = tools.filter((t) => t.status === "active");
|
||||
const httpTools = activeTools.filter((t) => !isMcp(t));
|
||||
const mcpTools = activeTools.filter(isMcp);
|
||||
|
||||
const handleToggle = (toolUuid: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
onChange([...value, toolUuid]);
|
||||
} else {
|
||||
onChange(value.filter((id) => id !== toolUuid));
|
||||
}
|
||||
const [refreshing, setRefreshing] = useState<Record<string, boolean>>({});
|
||||
const [refreshError, setRefreshError] = useState<Record<string, string>>({});
|
||||
|
||||
const httpHandleToggle = (toolUuid: string, checked: boolean) => {
|
||||
if (checked) onChange([...value, toolUuid]);
|
||||
else onChange(value.filter((id) => id !== toolUuid));
|
||||
};
|
||||
|
||||
const mcpFnToggle = (toolUuid: string, fnName: string, checked: boolean) => {
|
||||
const current = mcpToolFilters[toolUuid] ?? [];
|
||||
const nextFns = checked
|
||||
? Array.from(new Set([...current, fnName]))
|
||||
: current.filter((n) => n !== fnName);
|
||||
|
||||
const nextFilters = { ...mcpToolFilters };
|
||||
if (nextFns.length > 0) nextFilters[toolUuid] = nextFns;
|
||||
else delete nextFilters[toolUuid];
|
||||
onMcpToolFiltersChange(nextFilters);
|
||||
|
||||
const hasUuid = value.includes(toolUuid);
|
||||
if (nextFns.length > 0 && !hasUuid) onChange([...value, toolUuid]);
|
||||
else if (nextFns.length === 0 && hasUuid)
|
||||
onChange(value.filter((id) => id !== toolUuid));
|
||||
};
|
||||
|
||||
const doRefresh = async (toolUuid: string) => {
|
||||
setRefreshing((r) => ({ ...r, [toolUuid]: true }));
|
||||
setRefreshError((e) => {
|
||||
const n = { ...e };
|
||||
delete n[toolUuid];
|
||||
return n;
|
||||
});
|
||||
const res = await refreshMcpTools(toolUuid);
|
||||
setRefreshing((r) => ({ ...r, [toolUuid]: false }));
|
||||
if (res.error && res.discovered_tools.length === 0) {
|
||||
setRefreshError((e) => ({ ...e, [toolUuid]: res.error as string }));
|
||||
return;
|
||||
}
|
||||
workflow?.updateTool?.(toolUuid, (tool) =>
|
||||
withDiscoveredTools(tool, res.discovered_tools),
|
||||
);
|
||||
};
|
||||
|
||||
const selectedCount =
|
||||
httpTools.filter((t) => value.includes(t.tool_uuid)).length +
|
||||
mcpTools.reduce(
|
||||
(acc, t) => acc + (mcpToolFilters[t.tool_uuid]?.length ?? 0),
|
||||
0,
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
{showLabel && (
|
||||
|
|
@ -48,7 +139,14 @@ export function ToolSelector({
|
|||
{description && (
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
{description}{" "}
|
||||
<a href={TOOLS_INTRODUCTION_DOC_URL} target="_blank" rel="noopener noreferrer" className="underline">Learn more</a>
|
||||
<a
|
||||
href={TOOLS_INTRODUCTION_DOC_URL}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
Learn more
|
||||
</a>
|
||||
</Label>
|
||||
)}
|
||||
</>
|
||||
|
|
@ -67,45 +165,178 @@ export function ToolSelector({
|
|||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="border rounded-md divide-y">
|
||||
{activeTools.map((tool) => {
|
||||
const isSelected = value.includes(tool.tool_uuid);
|
||||
return (
|
||||
<label
|
||||
key={tool.tool_uuid}
|
||||
className={`flex items-center gap-3 p-3 cursor-pointer hover:bg-muted/50 ${
|
||||
disabled ? "opacity-50 cursor-not-allowed" : ""
|
||||
}`}
|
||||
>
|
||||
<Checkbox
|
||||
checked={isSelected}
|
||||
disabled={disabled}
|
||||
onCheckedChange={(checked) => {
|
||||
handleToggle(tool.tool_uuid, checked === true);
|
||||
}}
|
||||
/>
|
||||
<div
|
||||
className="w-6 h-6 rounded flex items-center justify-center shrink-0"
|
||||
style={{
|
||||
backgroundColor: tool.icon_color || "#3B82F6",
|
||||
}}
|
||||
>
|
||||
{renderToolIcon(tool.category, "h-3 w-3 text-white")}
|
||||
<Tabs defaultValue="http">
|
||||
<TabsList>
|
||||
<TabsTrigger value="http">
|
||||
HTTP & Tools ({httpTools.length})
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="mcp">
|
||||
MCP ({mcpTools.length})
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="http">
|
||||
<div className="border rounded-md divide-y">
|
||||
{httpTools.length === 0 && (
|
||||
<div className="p-3 text-sm text-muted-foreground">
|
||||
No HTTP/native tools.
|
||||
</div>
|
||||
<div className="flex flex-col min-w-0 flex-1">
|
||||
<span className="text-sm font-medium truncate">
|
||||
{tool.name}
|
||||
</span>
|
||||
{tool.description && (
|
||||
<span className="text-xs text-muted-foreground break-words">
|
||||
{tool.description}
|
||||
</span>
|
||||
)}
|
||||
)}
|
||||
{httpTools.map((tool) => {
|
||||
const isSelected = value.includes(tool.tool_uuid);
|
||||
return (
|
||||
<label
|
||||
key={tool.tool_uuid}
|
||||
className={`flex items-center gap-3 p-3 cursor-pointer hover:bg-muted/50 ${
|
||||
disabled ? "opacity-50 cursor-not-allowed" : ""
|
||||
}`}
|
||||
>
|
||||
<Checkbox
|
||||
checked={isSelected}
|
||||
disabled={disabled}
|
||||
onCheckedChange={(c) =>
|
||||
httpHandleToggle(tool.tool_uuid, c === true)
|
||||
}
|
||||
/>
|
||||
<div
|
||||
className="w-6 h-6 rounded flex items-center justify-center shrink-0"
|
||||
style={{
|
||||
backgroundColor: tool.icon_color || "#3B82F6",
|
||||
}}
|
||||
>
|
||||
{renderToolIcon(tool.category, "h-3 w-3 text-white")}
|
||||
</div>
|
||||
<div className="flex flex-col min-w-0 flex-1">
|
||||
<span className="text-sm font-medium truncate">
|
||||
{tool.name}
|
||||
</span>
|
||||
{tool.description && (
|
||||
<span className="text-xs text-muted-foreground break-words">
|
||||
{tool.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</label>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="mcp">
|
||||
<div className="border rounded-md divide-y">
|
||||
{mcpTools.length === 0 && (
|
||||
<div className="p-3 text-sm text-muted-foreground">
|
||||
No MCP tools.
|
||||
</div>
|
||||
</label>
|
||||
);
|
||||
})}
|
||||
<div className="p-2 bg-muted/30">
|
||||
)}
|
||||
{mcpTools.map((tool) => {
|
||||
const fns = discoveredOf(tool);
|
||||
const selected = mcpToolFilters[tool.tool_uuid] ?? [];
|
||||
const busy = !!refreshing[tool.tool_uuid];
|
||||
const err = refreshError[tool.tool_uuid];
|
||||
return (
|
||||
<details key={tool.tool_uuid} className="p-3">
|
||||
<summary className="flex items-center gap-3 cursor-pointer list-none">
|
||||
<div
|
||||
className="w-6 h-6 rounded flex items-center justify-center shrink-0"
|
||||
style={{
|
||||
backgroundColor: tool.icon_color || "#8B5CF6",
|
||||
}}
|
||||
>
|
||||
{renderToolIcon(tool.category, "h-3 w-3 text-white")}
|
||||
</div>
|
||||
<div className="flex flex-col min-w-0 flex-1">
|
||||
<span className="text-sm font-medium truncate">
|
||||
{tool.name}
|
||||
</span>
|
||||
{tool.description && (
|
||||
<span className="text-xs text-muted-foreground break-words">
|
||||
{tool.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<span className="text-xs text-muted-foreground shrink-0">
|
||||
{selected.length}/{fns.length} tools
|
||||
</span>
|
||||
</summary>
|
||||
|
||||
<div className="mt-3 pl-9 grid gap-2">
|
||||
<div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={busy}
|
||||
onClick={() => doRefresh(tool.tool_uuid)}
|
||||
>
|
||||
<RefreshCw
|
||||
className={`h-3 w-3 mr-2 ${busy ? "animate-spin" : ""}`}
|
||||
/>
|
||||
Refresh tools
|
||||
</Button>
|
||||
</div>
|
||||
{err && (
|
||||
<p className="text-xs text-destructive">{err}</p>
|
||||
)}
|
||||
{fns.length === 0 && !err && (
|
||||
<p className="text-xs text-muted-foreground">
|
||||
No tools discovered — Refresh.
|
||||
</p>
|
||||
)}
|
||||
{fns.map((fn) => {
|
||||
const checked = selected.includes(fn.name);
|
||||
return (
|
||||
<label
|
||||
key={fn.name}
|
||||
className="flex items-start gap-3 cursor-pointer"
|
||||
>
|
||||
<Checkbox
|
||||
checked={checked}
|
||||
disabled={disabled}
|
||||
onCheckedChange={(c) =>
|
||||
mcpFnToggle(tool.tool_uuid, fn.name, c === true)
|
||||
}
|
||||
/>
|
||||
<div className="flex flex-col min-w-0 flex-1">
|
||||
<span className="text-sm font-medium">
|
||||
{fn.name}
|
||||
</span>
|
||||
{fn.description && (
|
||||
<span className="text-xs text-muted-foreground break-words">
|
||||
{fn.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</label>
|
||||
);
|
||||
})}
|
||||
{selected
|
||||
.filter((n) => !fns.some((f) => f.name === n))
|
||||
.map((n) => (
|
||||
<label
|
||||
key={`stale-${n}`}
|
||||
className="flex items-start gap-3 cursor-pointer opacity-60"
|
||||
>
|
||||
<Checkbox
|
||||
checked
|
||||
disabled={disabled}
|
||||
onCheckedChange={() =>
|
||||
mcpFnToggle(tool.tool_uuid, n, false)
|
||||
}
|
||||
/>
|
||||
<span className="text-sm line-through">
|
||||
{n} (unavailable)
|
||||
</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</details>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<div className="mt-2 p-2 bg-muted/30 rounded-md">
|
||||
<Link
|
||||
href="/tools"
|
||||
target="_blank"
|
||||
|
|
@ -115,12 +346,12 @@ export function ToolSelector({
|
|||
Manage Tools
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</Tabs>
|
||||
)}
|
||||
|
||||
{value.length > 0 && (
|
||||
{selectedCount > 0 && (
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{value.length} tool{value.length !== 1 ? "s" : ""} selected
|
||||
{selectedCount} tool{selectedCount !== 1 ? "s" : ""} selected
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
|
|
|||
68
ui/src/components/flow/mcpRefresh.ts
Normal file
68
ui/src/components/flow/mcpRefresh.ts
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import { refreshMcpToolsApiV1ToolsToolUuidMcpRefreshPost } from "@/client/sdk.gen";
|
||||
import type { McpRefreshResponse } from "@/client/types.gen";
|
||||
|
||||
export interface McpDiscoveredTool {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface McpRefreshResult {
|
||||
tool_uuid: string;
|
||||
discovered_tools: McpDiscoveredTool[];
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
function normalizeDiscoveredTools(
|
||||
discoveredTools: McpRefreshResponse["discovered_tools"],
|
||||
): McpDiscoveredTool[] {
|
||||
if (!Array.isArray(discoveredTools)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return discoveredTools.flatMap((tool) => {
|
||||
if (!tool || typeof tool !== "object") {
|
||||
return [];
|
||||
}
|
||||
|
||||
const name = "name" in tool ? tool.name : undefined;
|
||||
if (typeof name !== "string" || !name.trim()) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const description =
|
||||
"description" in tool && typeof tool.description === "string"
|
||||
? tool.description
|
||||
: "";
|
||||
|
||||
return [{ name, description }];
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Re-discover an MCP tool's server catalog.
|
||||
* Uses the shared generated `client` (auth bearer is injected by interceptor).
|
||||
*/
|
||||
export async function refreshMcpTools(
|
||||
toolUuid: string,
|
||||
): Promise<McpRefreshResult> {
|
||||
const { data, error } = await refreshMcpToolsApiV1ToolsToolUuidMcpRefreshPost({
|
||||
path: {
|
||||
tool_uuid: toolUuid,
|
||||
},
|
||||
});
|
||||
if (error || !data) {
|
||||
return {
|
||||
tool_uuid: toolUuid,
|
||||
discovered_tools: [],
|
||||
error:
|
||||
typeof error === "string"
|
||||
? error
|
||||
: "Refresh request failed. Check the MCP server and try again.",
|
||||
};
|
||||
}
|
||||
return {
|
||||
tool_uuid: data.tool_uuid,
|
||||
discovered_tools: normalizeDiscoveredTools(data.discovered_tools),
|
||||
error: data.error ?? null,
|
||||
};
|
||||
}
|
||||
|
|
@ -205,6 +205,7 @@ function CanvasPreview({
|
|||
<ToolBadges
|
||||
toolUuids={data.tool_uuids}
|
||||
onStaleUuidsDetected={onStaleTools}
|
||||
mcpToolFilters={data.mcp_tool_filters}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
|
@ -396,14 +397,22 @@ export const GenericNode = memo(({ data, selected, id, type }: GenericNodeProps)
|
|||
const spec = bySpecName.get(type);
|
||||
|
||||
// ── Form state ─────────────────────────────────────────────────────
|
||||
const [values, setValues] = useState<Record<string, unknown>>(() =>
|
||||
spec ? seedValues(data, spec) : {},
|
||||
// mcp_tool_filters is not a spec property, so seedValues won't carry it;
|
||||
// seed merges it back in alongside the spec-derived values.
|
||||
const seed = useCallback(
|
||||
() =>
|
||||
spec
|
||||
? { ...seedValues(data, spec), mcp_tool_filters: data.mcp_tool_filters }
|
||||
: {},
|
||||
[data, spec],
|
||||
);
|
||||
|
||||
const [values, setValues] = useState<Record<string, unknown>>(seed);
|
||||
|
||||
// Re-seed once the spec arrives (initial fetch race).
|
||||
useEffect(() => {
|
||||
if (spec && Object.keys(values).length === 0) {
|
||||
setValues(seedValues(data, spec));
|
||||
setValues(seed());
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [spec]);
|
||||
|
|
@ -464,7 +473,11 @@ export const GenericNode = memo(({ data, selected, id, type }: GenericNodeProps)
|
|||
const isDirty = useMemo(() => {
|
||||
if (!spec) return false;
|
||||
const baseline = seedValues(data, spec);
|
||||
return propertyNames.some((n) => values[n] !== baseline[n]);
|
||||
if (propertyNames.some((n) => values[n] !== baseline[n])) return true;
|
||||
return (
|
||||
JSON.stringify(values.mcp_tool_filters ?? {}) !==
|
||||
JSON.stringify(data.mcp_tool_filters ?? {})
|
||||
);
|
||||
}, [values, data, spec, propertyNames]);
|
||||
|
||||
const handleSave = async () => {
|
||||
|
|
@ -478,12 +491,12 @@ export const GenericNode = memo(({ data, selected, id, type }: GenericNodeProps)
|
|||
};
|
||||
|
||||
const handleOpenChange = (newOpen: boolean) => {
|
||||
if (newOpen && spec) setValues(seedValues(data, spec));
|
||||
if (newOpen && spec) setValues(seed());
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (open && spec) setValues(seedValues(data, spec));
|
||||
if (open && spec) setValues(seed());
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [data, open]);
|
||||
|
||||
|
|
@ -562,6 +575,18 @@ export const GenericNode = memo(({ data, selected, id, type }: GenericNodeProps)
|
|||
tools: tools ?? [],
|
||||
documents: documents ?? [],
|
||||
recordings: recordings ?? [],
|
||||
mcpToolFilters:
|
||||
(values.mcp_tool_filters as
|
||||
| Record<string, string[]>
|
||||
| undefined) ?? {},
|
||||
onMcpToolFiltersChange: (next) =>
|
||||
setValues((prev) => ({
|
||||
...prev,
|
||||
mcp_tool_filters:
|
||||
Object.keys(next).length > 0
|
||||
? next
|
||||
: undefined,
|
||||
})),
|
||||
}}
|
||||
/>
|
||||
{type === "trigger" && (
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ export interface RendererContext {
|
|||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
recordings: RecordingResponseSchema[];
|
||||
/** Per-node MCP function allowlist (sibling of tool_uuids on node data). */
|
||||
mcpToolFilters?: Record<string, string[]>;
|
||||
/** Persist a new mcp_tool_filters object onto the node form values. */
|
||||
onMcpToolFiltersChange?: (next: Record<string, string[]>) => void;
|
||||
}
|
||||
|
||||
export interface PropertyInputProps {
|
||||
|
|
@ -83,6 +87,10 @@ export function PropertyInput({ spec, value, onChange, context }: PropertyInputP
|
|||
value={value}
|
||||
onChange={onChange}
|
||||
tools={context.tools}
|
||||
mcpToolFilters={context.mcpToolFilters ?? {}}
|
||||
onMcpToolFiltersChange={
|
||||
context.onMcpToolFiltersChange ?? (() => {})
|
||||
}
|
||||
/>
|
||||
);
|
||||
case "document_refs":
|
||||
|
|
@ -401,7 +409,13 @@ function ToolRefsWidget({
|
|||
value,
|
||||
onChange,
|
||||
tools,
|
||||
}: WidgetProps & { tools: ToolResponse[] }) {
|
||||
mcpToolFilters,
|
||||
onMcpToolFiltersChange,
|
||||
}: WidgetProps & {
|
||||
tools: ToolResponse[];
|
||||
mcpToolFilters: Record<string, string[]>;
|
||||
onMcpToolFiltersChange: (next: Record<string, string[]>) => void;
|
||||
}) {
|
||||
return (
|
||||
<ToolSelector
|
||||
value={(value as string[] | undefined) ?? []}
|
||||
|
|
@ -409,6 +423,8 @@ function ToolRefsWidget({
|
|||
tools={tools}
|
||||
label={spec.display_name}
|
||||
description={spec.description}
|
||||
mcpToolFilters={mcpToolFilters}
|
||||
onMcpToolFiltersChange={onMcpToolFiltersChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ export type FlowNodeData = {
|
|||
qa_sample_rate?: number;
|
||||
// Tools - array of tool UUIDs that can be invoked by this node
|
||||
tool_uuids?: string[];
|
||||
// Per-node MCP function allowlist: { toolUuid: [raw MCP tool name, ...] }.
|
||||
// Default-none: a toolUuid absent here (or mapped to []) exposes zero
|
||||
// functions of that MCP server on this node.
|
||||
mcp_tool_filters?: Record<string, string[]>;
|
||||
// Documents - array of knowledge base document UUIDs that can be referenced by this node
|
||||
document_uuids?: string[];
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue