diff --git a/api/Dockerfile b/api/Dockerfile index bcda1a6..b871000 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -20,7 +20,7 @@ RUN pip install --user --no-cache-dir -r requirements.txt && \ # Copy and install pipecat from local submodule COPY pipecat /tmp/pipecat -RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' && \ +RUN pip install --user --no-cache-dir '/tmp/pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp]' && \ # Swap opencv-python (pulled by pipecat[webrtc]) for opencv-python-headless # to drop X11/Qt dependencies that otherwise require libxcb etc. in runner. pip uninstall -y opencv-python && \ diff --git a/api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py b/api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py new file mode 100644 index 0000000..34b006a --- /dev/null +++ b/api/alembic/versions/0a1b2c3d4e5f_add_mcp_in_toolcategory.py @@ -0,0 +1,64 @@ +"""add mcp in ToolCategory + +Revision ID: 0a1b2c3d4e5f +Revises: 4c1f1e3e8ef2 +Create Date: 2026-05-16 00:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +from alembic_postgresql_enum import TableReference + +# revision identifiers, used by Alembic. +revision: str = "0a1b2c3d4e5f" +down_revision: Union[str, None] = "4c1f1e3e8ef2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="tool_category", + new_values=[ + "http_api", + "end_call", + "transfer_call", + "calculator", + "native", + "integration", + "mcp", + ], + affected_columns=[ + TableReference( + table_schema="public", table_name="tools", column_name="category" + ) + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="tool_category", + new_values=[ + "http_api", + "end_call", + "transfer_call", + "calculator", + "native", + "integration", + ], + affected_columns=[ + TableReference( + table_schema="public", table_name="tools", column_name="category" + ) + ], + enum_values_to_rename=[], + ) diff --git a/api/enums.py b/api/enums.py index b7655b1..538f5fd 100644 --- a/api/enums.py +++ b/api/enums.py @@ -133,6 +133,7 @@ class ToolCategory(Enum): CALCULATOR = "calculator" # Built-in calculator tool NATIVE = "native" # Built-in integrations (future: dtmf_input) INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.) + MCP = "mcp" # Customer-provided MCP server exposing a tool catalog class ToolStatus(Enum): diff --git a/api/mcp_server/auth.py b/api/mcp_server/auth.py index 33a4a46..6c19d61 100644 --- a/api/mcp_server/auth.py +++ b/api/mcp_server/auth.py @@ -18,7 +18,9 @@ async def authenticate_mcp_request() -> UserModel: the `langfuse.user.id` / `langfuse.session.id` attributes make the span filterable in the Langfuse UI. """ - headers = get_http_headers() + # FastMCP strips Authorization by default unless explicitly included. + # Preserve it here so Bearer API keys work for MCP tool invocations. + headers = get_http_headers(include={"authorization"}) api_key = headers.get("x-api-key") if not api_key: auth = headers.get("authorization", "") diff --git a/api/routes/tool.py b/api/routes/tool.py index ce3d217..b7fa97e 100644 --- a/api/routes/tool.py +++ b/api/routes/tool.py @@ -1,10 +1,12 @@ """API routes for managing tools.""" +import asyncio import re from datetime import datetime from typing import Annotated, Any, Dict, List, Literal, Optional, Union from fastapi import APIRouter, Depends, HTTPException +from loguru import logger from pydantic import BaseModel, Field, field_validator from api.db import db_client @@ -13,9 +15,23 @@ from api.enums import PostHogEvent, ToolCategory, ToolStatus from api.sdk_expose import sdk_expose from api.services.auth.depends import get_user from api.services.posthog_client import capture_event +from api.services.workflow.mcp_tool_session import discover_mcp_tools +from api.services.workflow.tools.mcp_tool import ( + McpDefinitionError, + validate_mcp_definition, +) +from api.services.workflow.tools.mcp_tool import ( + McpToolConfig as SharedMcpToolConfig, +) +from api.services.workflow.tools.mcp_tool import ( + McpToolDefinition as SharedMcpToolDefinition, +) router = APIRouter(prefix="/tools") +McpToolConfig = SharedMcpToolConfig +McpToolDefinition = SharedMcpToolDefinition + # Request/Response schemas class ToolParameter(BaseModel): @@ -183,6 +199,7 @@ ToolDefinition = Annotated[ EndCallToolDefinition, TransferCallToolDefinition, CalculatorToolDefinition, + McpToolDefinition, ], Field(discriminator="type"), ] @@ -248,6 +265,14 @@ class ToolResponse(BaseModel): from_attributes = True +class McpRefreshResponse(BaseModel): + """Result of re-discovering an MCP server's tool catalog.""" + + tool_uuid: str + discovered_tools: list = Field(default_factory=list) + error: Optional[str] = None + + def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse: """Build a response from a tool model.""" created_by = None @@ -336,6 +361,52 @@ async def list_tools( return [build_tool_response(tool) for tool in tools] +async def _fetch_credential(credential_uuid: Optional[str], organization_id: int): + """Best-effort credential lookup for MCP auth. A missing/failed credential + degrades to ``None`` (unauthenticated) rather than failing the request.""" + if not credential_uuid: + return None + try: + return await db_client.get_credential_by_uuid(credential_uuid, organization_id) + except Exception as e: # noqa: BLE001 + logger.warning(f"MCP: credential fetch failed: {e}") + return None + + +async def _populate_discovered_tools(definition: dict, *, organization_id: int) -> dict: + """Best-effort: for an MCP definition, connect to the server, list its + tools, and overwrite ``config.discovered_tools``. Never raises and never + blocks tool save — a dead server yields ``discovered_tools: []``. Non-MCP + definitions pass through untouched.""" + if not isinstance(definition, dict) or definition.get("type") != "mcp": + return definition + try: + cfg = validate_mcp_definition(definition) + except McpDefinitionError: + return definition + + credential = await _fetch_credential(cfg.get("credential_uuid"), organization_id) + + # Run discovery in an isolated asyncio task so an anyio cancel-scope + # CancelledError doesn't bleed into the parent task and corrupt the + # subsequent DB write. _run() never raises (degrades to []). + async def _run() -> list: + try: + return await discover_mcp_tools( + url=cfg["url"], + credential=credential, + timeout_secs=cfg["timeout_secs"], + sse_read_timeout_secs=cfg["sse_read_timeout_secs"], + ) + except BaseException as e: # noqa: BLE001 + logger.warning(f"MCP discovery failed; caching empty list: {e}") + return [] + + discovered = await asyncio.ensure_future(_run()) + definition["config"]["discovered_tools"] = discovered + return definition + + @router.post("/") async def create_tool( request: CreateToolRequest, @@ -357,11 +428,16 @@ async def create_tool( validate_category(request.category) + definition = await _populate_discovered_tools( + request.definition.model_dump(), + organization_id=user.selected_organization_id, + ) + tool = await db_client.create_tool( organization_id=user.selected_organization_id, user_id=user.id, name=request.name, - definition=request.definition.model_dump(), + definition=definition, category=request.category, description=request.description, icon=request.icon, @@ -410,6 +486,67 @@ async def get_tool( return build_tool_response(tool, include_created_by=True) +@router.post("/{tool_uuid}/mcp/refresh") +async def refresh_mcp_tools( + tool_uuid: str, + user: UserModel = Depends(get_user), +) -> McpRefreshResponse: + """Re-discover an MCP tool's server catalog and overwrite the cached + ``definition.config.discovered_tools``. Server down → 200 with error + (cache not overwritten on transient failure).""" + if not user.selected_organization_id: + raise HTTPException( + status_code=400, detail="No organization selected for the user" + ) + + tool = await db_client.get_tool_by_uuid( + tool_uuid, user.selected_organization_id, include_archived=True + ) + if not tool: + raise HTTPException(status_code=404, detail="Tool not found") + if tool.category != ToolCategory.MCP.value: + raise HTTPException(status_code=400, detail="Tool is not an MCP tool") + + try: + cfg = validate_mcp_definition(tool.definition) + except McpDefinitionError as e: + raise HTTPException(status_code=400, detail=f"Invalid MCP definition: {e}") + + credential = await _fetch_credential( + cfg.get("credential_uuid"), user.selected_organization_id + ) + + try: + discovered = await discover_mcp_tools( + url=cfg["url"], + credential=credential, + timeout_secs=cfg["timeout_secs"], + sse_read_timeout_secs=cfg["sse_read_timeout_secs"], + ) + except Exception as e: # noqa: BLE001 + logger.warning(f"MCP refresh discovery failed: {e}") + discovered = [] + + if not discovered: + error = ( + f"Could not reach the MCP server at {cfg['url']} " + f"(or it exposes no tools). Previously cached list retained." + ) + # Do NOT clobber a previously-good cache with [] on a transient outage. + return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error) + + new_def = dict(tool.definition or {}) + new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered} + await db_client.update_tool( + tool_uuid=tool_uuid, + organization_id=user.selected_organization_id, + definition=new_def, + ) + return McpRefreshResponse( + tool_uuid=tool_uuid, discovered_tools=discovered, error=None + ) + + @router.put("/{tool_uuid}") async def update_tool( tool_uuid: str, @@ -434,12 +571,21 @@ async def update_tool( if request.status: validate_status(request.status) + definition = ( + await _populate_discovered_tools( + request.definition.model_dump(), + organization_id=user.selected_organization_id, + ) + if request.definition + else None + ) + tool = await db_client.update_tool( tool_uuid=tool_uuid, organization_id=user.selected_organization_id, name=request.name, description=request.description, - definition=request.definition.model_dump() if request.definition else None, + definition=definition, icon=request.icon, icon_color=request.icon_color, status=request.status, diff --git a/api/services/workflow/dto.py b/api/services/workflow/dto.py index dc26843..795b28b 100644 --- a/api/services/workflow/dto.py +++ b/api/services/workflow/dto.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Annotated, List, Literal, Optional, Union +from typing import Annotated, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, ValidationError, model_validator @@ -69,6 +69,7 @@ class _ExtractionNodeDataMixin(BaseModel): class _ToolDocumentRefsMixin(BaseModel): tool_uuids: Optional[List[str]] = None document_uuids: Optional[List[str]] = None + mcp_tool_filters: Optional[Dict[str, List[str]]] = None class StartCallNodeData( diff --git a/api/services/workflow/mcp_tool_session.py b/api/services/workflow/mcp_tool_session.py new file mode 100644 index 0000000..0caa1b7 --- /dev/null +++ b/api/services/workflow/mcp_tool_session.py @@ -0,0 +1,254 @@ +"""Single unit that knows the MCP protocol + credentials. + +Wraps the vendored Pipecat ``MCPClient`` for connection/session, builds +streamable-HTTP params from a Dograh credential, exposes namespaced +``FunctionSchema``s, and proxies tool calls. Connection failures degrade +(``available = False``) instead of raising — the call must survive a +dead MCP server. +""" + +from __future__ import annotations + +import asyncio +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set + +from loguru import logger +from mcp.client.session_group import StreamableHttpParameters +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.services.mcp_service import MCPClient + +from api.services.workflow.tools.mcp_tool import namespace_function_name +from api.utils.credential_auth import build_auth_header + +if TYPE_CHECKING: + from api.db.models import ExternalCredentialModel + + +def build_streamable_http_params( + *, + url: str, + credential: Optional["ExternalCredentialModel"], + timeout_secs: int, + sse_read_timeout_secs: int, +) -> StreamableHttpParameters: + """Build Pipecat/MCP streamable-HTTP params, injecting the auth header + from an ExternalCredentialModel (reuses the http_api credential path).""" + headers: Optional[Dict[str, str]] = None + if credential is not None: + auth = build_auth_header(credential) + headers = auth or None + return StreamableHttpParameters( + url=url, + headers=headers, + timeout=timedelta(seconds=timeout_secs), + sse_read_timeout=timedelta(seconds=sse_read_timeout_secs), + ) + + +class McpToolSession: + """One live MCP server connection for the duration of a call.""" + + def __init__( + self, + *, + tool_uuid: str, + tool_name: str, + url: str, + credential: Optional["ExternalCredentialModel"], + tools_filter: List[str], + timeout_secs: int, + sse_read_timeout_secs: int, + ) -> None: + self._tool_uuid = tool_uuid + self._tool_name = tool_name + self._url = url + self._credential = credential + # An empty list is intentionally treated as "no filter (expose all + # tools)" — Pipecat's MCPClient applies a filter only when this is a + # non-empty list, so [] and None are equivalent ("all tools"). + self._tools_filter = tools_filter or None + self._timeout_secs = timeout_secs + self._sse_read_timeout_secs = sse_read_timeout_secs + + self._client: Optional[MCPClient] = None + self._session: Any = None # mcp.ClientSession (read once after start) + self._schemas: List[FunctionSchema] = [] + # namespaced LLM name -> original MCP tool name + self._name_map: Dict[str, str] = {} + self.available: bool = False + + async def start(self) -> None: + """Connect, initialize, and cache the tool list. Never raises — + on any failure the session is marked unavailable.""" + try: + params = build_streamable_http_params( + url=self._url, + credential=self._credential, + timeout_secs=self._timeout_secs, + sse_read_timeout_secs=self._sse_read_timeout_secs, + ) + self._client = MCPClient(params, tools_filter=self._tools_filter) + await self._client.start() + # Single, isolated touch of Pipecat internals (vendored submodule). + self._session = self._client._active_session + tools_schema = await self._client.get_tools_schema() + + fallback = self._tool_uuid[:8] if self._tool_uuid else "server" + for fs in tools_schema.standard_tools: + ns_name = namespace_function_name( + self._tool_name, fs.name, fallback=fallback + ) + self._name_map[ns_name] = fs.name + self._schemas.append( + FunctionSchema( + name=ns_name, + description=fs.description, + properties=fs.properties, + required=fs.required, + ) + ) + self.available = True + logger.info( + f"MCP session ready for tool '{self._tool_name}' " + f"({self._tool_uuid}): {sorted(self._name_map)}" + ) + except (KeyboardInterrupt, SystemExit): + raise + except asyncio.CancelledError as e: + # Empirically, a dead/unreachable MCP server does NOT surface as a + # plain Exception here. The real failure is httpx.ConnectError, but + # anyio's streamablehttp_client task group, while tearing down that + # ConnectError, re-surfaces it to our frame as an *internal* + # cancel-scope CancelledError carrying the signature message + # "Cancelled via cancel scope ". A genuine *external* + # cancellation (call teardown / shutdown) is a bare CancelledError + # (empty args) or one with an application-chosen message. Type, MRO, + # context chain, and asyncio task.cancelling() are all identical + # between the two, so the anyio scope-signature message is the only + # reliable discriminator. Re-raise genuine external cancellation to + # preserve structured concurrency; degrade only on the anyio + # connect-teardown artifact. + msg = "" if not e.args else str(e.args[0] or "") + if not msg.startswith("Cancelled via cancel scope"): + raise + await self._degrade(e) + except Exception as e: # noqa: BLE001 — see _degrade docstring + # Defensive: if a future Pipecat/httpx version surfaces the connect + # failure directly (e.g. httpx.ConnectError) instead of via the + # anyio cancel-scope artifact above, still degrade gracefully. + await self._degrade(e) + + async def _degrade(self, e: BaseException) -> None: + """Mark this session unavailable and tear down any dangling client so + start() leaves self._client either fully usable or None. The contract + requires graceful degradation on any *connect* failure (never raising + for a dead MCP server) while genuine external cancellation / + KeyboardInterrupt / SystemExit are re-raised by the caller.""" + self.available = False + self._schemas = [] + self._name_map = {} + # Self-contained cleanup: _client.start() may have succeeded before a + # later step (e.g. get_tools_schema()) failed, leaving an open client. + if self._client is not None: + try: + await self._client.close() + except Exception: + pass + finally: + self._client = None + self._session = None + logger.warning( + f"MCP session unavailable for tool '{self._tool_name}' " + f"({self._tool_uuid}) at {self._url}: {e!r}. " + f"Call proceeds without these tools." + ) + + @property + def call_timeout_secs(self) -> float: + """Pipecat function-call timeout for this server's tools. Slightly + longer than the transport read timeout so a slow MCP call surfaces + as a structured tool error (handled in the handler) rather than a + hard pipeline timeout.""" + return float(self._sse_read_timeout_secs) + 5.0 + + def function_schemas( + self, allowed_raw_names: Optional[Set[str]] = None + ) -> List[FunctionSchema]: + """Return cached FunctionSchemas, optionally filtered by raw MCP tool name. + + ``allowed_raw_names=None`` returns all schemas. An empty set returns none. + Raw names are the pre-namespace MCP tool names (e.g. ``echo``, not + ``mcp__slug__echo``). + """ + if allowed_raw_names is None: + return list(self._schemas) + return [ + s for s in self._schemas if self._name_map.get(s.name) in allowed_raw_names + ] + + def discovered_tools(self) -> List[Dict[str, str]]: + """Raw MCP tool catalog for UI/cache: ``[{name, description}]`` + using the *raw* server names (not the namespaced LLM names). + Empty if the session is unavailable.""" + out: List[Dict[str, str]] = [] + for s in self._schemas: + raw = self._name_map.get(s.name) + if raw is None: + continue + out.append({"name": raw, "description": s.description or ""}) + return out + + async def call(self, namespaced_name: str, arguments: Dict[str, Any]) -> str: + """Invoke an MCP tool by its namespaced LLM name. Returns a string + (flattened text content). Raises if the session is unavailable so + the caller can map it to a structured error for the LLM.""" + if not self.available or self._session is None: + raise RuntimeError(f"MCP session unavailable for {namespaced_name}") + original = self._name_map.get(namespaced_name) + if original is None: + raise RuntimeError(f"Unknown MCP function {namespaced_name}") + result = await self._session.call_tool(original, arguments=arguments) + text = "" + for content in getattr(result, "content", []) or []: + if getattr(content, "text", None): + text += content.text + return text or "Sorry, the MCP tool returned no content." + + async def close(self) -> None: + if self._client is not None: + try: + await self._client.close() + except Exception as e: + logger.warning(f"Error closing MCP session {self._tool_uuid}: {e}") + finally: + self._client = None + self._session = None + + +async def discover_mcp_tools( + *, + url: str, + credential: Optional["ExternalCredentialModel"], + timeout_secs: int, + sse_read_timeout_secs: int, +) -> List[Dict[str, str]]: + """Open an ephemeral MCP session, list its tools, close it. Returns + ``[{name, description}]`` (raw names). Never raises — on any connect + failure returns ``[]``.""" + session = McpToolSession( + tool_uuid="discover", + tool_name="discover", + url=url, + credential=credential, + tools_filter=[], + timeout_secs=timeout_secs, + sse_read_timeout_secs=sse_read_timeout_secs, + ) + await session.start() + try: + if not session.available: + return [] + return session.discovered_tools() + finally: + await session.close() diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index d73d9be..b98d4ba 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.frames.frames import ( @@ -16,6 +16,7 @@ from pipecat.services.settings import LLMSettings from pipecat.utils.enums import EndTaskReason from api.db import db_client +from api.enums import ToolCategory from api.services.pipecat.audio_playback import play_audio from api.services.workflow.disposition_mapper import apply_disposition_mapping from api.services.workflow.workflow_graph import Node, WorkflowGraph @@ -34,6 +35,7 @@ import asyncio from loguru import logger from api.services.workflow import pipecat_engine_callbacks as engine_callbacks +from api.services.workflow.mcp_tool_session import McpToolSession from api.services.workflow.pipecat_engine_context_composer import ( compose_functions_for_node, compose_system_prompt_for_node, @@ -116,6 +118,9 @@ class PipecatEngine: # Cached organization ID (resolved lazily from workflow run) self._organization_id: Optional[int] = None + # Open MCP tool sessions for this call, keyed by tool_uuid + self._mcp_sessions: Dict[str, McpToolSession] = {} + # Embeddings configuration (passed from run_pipeline.py) self._embeddings_api_key: Optional[str] = embeddings_api_key self._embeddings_model: Optional[str] = embeddings_model @@ -178,6 +183,9 @@ class PipecatEngine: # Helper that encapsulates custom tool management self._custom_tool_manager = CustomToolManager(self) + # Open persistent MCP server sessions for this call (degrades on failure) + await self._open_mcp_sessions() + # Helper that encapsulates context summarization if self._context_compaction_enabled: self._context_summarization_manager = ContextSummarizationManager(self) @@ -503,7 +511,10 @@ class PipecatEngine: # Register custom tool handlers for this node if node.tool_uuids and self._custom_tool_manager: - await self._custom_tool_manager.register_handlers(node.tool_uuids) + await self._custom_tool_manager.register_handlers( + node.tool_uuids, + mcp_tool_filters=getattr(node, "mcp_tool_filters", None), + ) # Register knowledge base retrieval handler if node has documents if node.document_uuids: @@ -814,6 +825,79 @@ class PipecatEngine: """Get the gathered context including extracted variables.""" return self._gathered_context.copy() + async def _open_mcp_sessions(self) -> None: + """Connect every MCP-category tool referenced by any workflow node. + Failures degrade (session marked unavailable); never raises.""" + from api.services.workflow.tools.mcp_tool import ( + McpDefinitionError, + validate_mcp_definition, + ) + + try: + tool_uuids: set[str] = set() + for node in self.workflow.nodes.values(): + for tu in getattr(node, "tool_uuids", None) or []: + tool_uuids.add(tu) + if not tool_uuids: + return + + organization_id = await self._get_organization_id() + if not organization_id: + logger.warning("Cannot open MCP sessions: organization_id missing") + return + + tools = await db_client.get_tools_by_uuids( + list(tool_uuids), organization_id + ) + for tool in tools: + if tool.category != ToolCategory.MCP.value: + continue + try: + cfg = validate_mcp_definition(tool.definition) + except McpDefinitionError as e: + logger.warning( + f"Skipping MCP tool '{tool.name}' ({tool.tool_uuid}): " + f"invalid definition: {e}" + ) + continue + + credential = None + if cfg["credential_uuid"]: + try: + credential = await db_client.get_credential_by_uuid( + cfg["credential_uuid"], organization_id + ) + except Exception as e: + logger.warning( + f"MCP tool '{tool.name}': credential fetch failed: {e}" + ) + continue + + session = McpToolSession( + tool_uuid=tool.tool_uuid, + tool_name=tool.name, + url=cfg["url"], + credential=credential, + tools_filter=cfg["tools_filter"], + timeout_secs=cfg["timeout_secs"], + sse_read_timeout_secs=cfg["sse_read_timeout_secs"], + ) + await session.start() + self._mcp_sessions[tool.tool_uuid] = session + except Exception as e: + logger.warning( + f"Failed to open MCP sessions; call proceeds without MCP tools: {e}", + exc_info=True, + ) + + async def _close_mcp_sessions(self) -> None: + for tool_uuid, session in list(self._mcp_sessions.items()): + try: + await session.close() + except Exception as e: + logger.warning(f"Error closing MCP session {tool_uuid}: {e}") + self._mcp_sessions = {} + async def cleanup(self): """Clean up engine resources on disconnect.""" # Cancel any pending timeout tasks @@ -823,6 +907,12 @@ class PipecatEngine: ): self._user_response_timeout_task.cancel() - # Cancel any in-flight background summarization - if self._context_summarization_manager: - await self._context_summarization_manager.cleanup() + # Cancel any in-flight background summarization. + # MCP sessions are closed in a finally block so they are guaranteed to + # run even if the summarization cleanup raises an exception. + try: + if self._context_summarization_manager: + await self._context_summarization_manager.cleanup() + finally: + # Close any open MCP tool sessions + await self._close_mcp_sessions() diff --git a/api/services/workflow/pipecat_engine_context_composer.py b/api/services/workflow/pipecat_engine_context_composer.py index 03c253a..41b62c0 100644 --- a/api/services/workflow/pipecat_engine_context_composer.py +++ b/api/services/workflow/pipecat_engine_context_composer.py @@ -117,7 +117,8 @@ async def compose_functions_for_node( # Custom tools if node.tool_uuids and custom_tool_manager: custom_tool_schemas = await custom_tool_manager.get_tool_schemas( - node.tool_uuids + node.tool_uuids, + mcp_tool_filters=getattr(node, "mcp_tool_filters", None), ) functions.extend(custom_tool_schemas) diff --git a/api/services/workflow/pipecat_engine_custom_tools.py b/api/services/workflow/pipecat_engine_custom_tools.py index 90eed7e..c40c664 100644 --- a/api/services/workflow/pipecat_engine_custom_tools.py +++ b/api/services/workflow/pipecat_engine_custom_tools.py @@ -34,6 +34,7 @@ from api.services.workflow.tools.custom_tool import ( ) if TYPE_CHECKING: + from api.services.workflow.mcp_tool_session import McpToolSession from api.services.workflow.pipecat_engine import PipecatEngine @@ -121,11 +122,18 @@ class CustomToolManager: """Get the organization ID from the engine (shared cache).""" return await self._engine._get_organization_id() - async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]: + async def get_tool_schemas( + self, + tool_uuids: list[str], + mcp_tool_filters: Optional[dict[str, list[str]]] = None, + ) -> list[FunctionSchema]: """Fetch custom tools and convert them to function schemas. Args: tool_uuids: List of tool UUIDs to fetch + mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of + raw MCP tool names to expose. None (default) exposes all tools. + Empty dict or entry with [] suppresses all tools for that uuid. Returns: List of FunctionSchema objects for LLM @@ -154,6 +162,22 @@ class CustomToolManager: ) continue + if tool.category == ToolCategory.MCP.value: + session = self._engine._mcp_sessions.get(tool.tool_uuid) + if session is None or not session.available: + logger.warning( + f"MCP tool '{tool.name}' ({tool.tool_uuid}) " + f"unavailable; skipping" + ) + continue + allowed = ( + None + if mcp_tool_filters is None + else set(mcp_tool_filters.get(tool.tool_uuid, [])) + ) + schemas.extend(session.function_schemas(allowed)) + continue + raw_schema = tool_to_function_schema(tool) function_name = raw_schema["function"]["name"] @@ -178,11 +202,18 @@ class CustomToolManager: logger.error(f"Failed to fetch custom tools: {e}") return [] - async def register_handlers(self, tool_uuids: list[str]) -> None: + async def register_handlers( + self, + tool_uuids: list[str], + mcp_tool_filters: Optional[dict[str, list[str]]] = None, + ) -> None: """Register custom tool execution handlers with the LLM. Args: tool_uuids: List of tool UUIDs to register handlers for + mcp_tool_filters: Optional per-node filter mapping tool_uuid → list of + raw MCP tool names to expose. None (default) exposes all tools. + Empty dict or entry with [] suppresses all tools for that uuid. """ organization_id = await self.get_organization_id() if not organization_id: @@ -203,6 +234,32 @@ class CustomToolManager: ) continue + if tool.category == ToolCategory.MCP.value: + session = self._engine._mcp_sessions.get(tool.tool_uuid) + if session is None or not session.available: + logger.warning( + f"MCP tool '{tool.name}' ({tool.tool_uuid}) " + f"unavailable; skipping handler registration" + ) + continue + allowed = ( + None + if mcp_tool_filters is None + else set(mcp_tool_filters.get(tool.tool_uuid, [])) + ) + mcp_schemas = session.function_schemas(allowed) + for fs in mcp_schemas: + self._engine.llm.register_function( + fs.name, + self._create_mcp_handler(session, fs.name), + timeout_secs=session.call_timeout_secs, + ) + logger.debug( + f"Registered {len(mcp_schemas)} MCP " + f"handlers for tool '{tool.name}' ({tool.tool_uuid})" + ) + continue + schema = tool_to_function_schema(tool) function_name = schema["function"]["name"] @@ -335,6 +392,29 @@ class CustomToolManager: return http_tool_handler + def _create_mcp_handler(self, session: "McpToolSession", function_name: str): + """Create a handler that proxies an LLM function call to a live MCP + session. Errors are returned to the LLM as structured text so the + agent can recover verbally; the call is never crashed.""" + + async def mcp_tool_handler( + function_call_params: FunctionCallParams, + ) -> None: + logger.info(f"MCP Tool EXECUTED: {function_name}") + logger.info(f"Arguments: {function_call_params.arguments}") + try: + result = await session.call( + function_name, function_call_params.arguments or {} + ) + await function_call_params.result_callback(result) + except Exception as e: + logger.error(f"MCP tool '{function_name}' failed: {e}") + await function_call_params.result_callback( + {"status": "error", "error": str(e)} + ) + + return mcp_tool_handler + def _create_end_call_handler(self, tool: Any, function_name: str): """Create a handler function for an end call tool. diff --git a/api/services/workflow/tools/mcp_tool.py b/api/services/workflow/tools/mcp_tool.py new file mode 100644 index 0000000..26dac2a --- /dev/null +++ b/api/services/workflow/tools/mcp_tool.py @@ -0,0 +1,116 @@ +"""Pure helpers for MCP-category tools: definition validation and +LLM-function-name namespacing. No I/O, no MCP protocol here.""" + +from __future__ import annotations + +import re +from typing import Any, Dict, Literal, Optional + +from pydantic import BaseModel, Field, ValidationError, field_validator + +DEFAULT_TIMEOUT_SECS = 30 +DEFAULT_SSE_READ_TIMEOUT_SECS = 300 + + +class McpDefinitionError(ValueError): + """Raised when an MCP tool definition is structurally invalid.""" + + +class McpToolConfig(BaseModel): + """Configuration for an MCP tool definition.""" + + transport: Literal["streamable_http"] = Field( + default="streamable_http", description="MCP transport protocol" + ) + url: str = Field(description="MCP server URL (must be http:// or https://)") + credential_uuid: Optional[str] = Field( + default=None, description="Reference to ExternalCredentialModel for auth" + ) + tools_filter: list[str] = Field( + default_factory=list, + description="Allowlist of MCP tool names to expose (empty = all tools)", + ) + timeout_secs: int = Field( + default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds" + ) + sse_read_timeout_secs: int = Field( + default=DEFAULT_SSE_READ_TIMEOUT_SECS, + description="SSE read timeout in seconds", + ) + discovered_tools: list[dict[str, Any]] = Field( + default_factory=list, + description=( + "Server-managed cache of the MCP server's tool catalog " + "[{name, description}]. Populated best-effort by the backend." + ), + ) + + @field_validator("url") + @classmethod + def validate_url(cls, v: str) -> str: + if not isinstance(v, str) or not v.startswith(("http://", "https://")): + raise ValueError("config.url must be an http(s) URL") + return v + + @field_validator("tools_filter") + @classmethod + def validate_tools_filter(cls, v: list[str]) -> list[str]: + if not all(isinstance(tool_name, str) for tool_name in v): + raise ValueError("config.tools_filter must be a list of strings") + return v + + +class McpToolDefinition(BaseModel): + """Persisted MCP tool definition.""" + + schema_version: int = Field(default=1, description="Schema version") + type: Literal["mcp"] = Field(description="Tool type") + config: McpToolConfig = Field(description="MCP server configuration") + + +def _format_validation_error(error: ValidationError) -> str: + parts: list[str] = [] + for item in error.errors(): + location = ".".join(str(part) for part in item["loc"]) + parts.append(f"{location}: {item['msg']}") + return "; ".join(parts) + + +def validate_mcp_definition(definition: Dict[str, Any]) -> Dict[str, Any]: + """Validate a ``type: "mcp"`` ToolModel definition and return a + normalized config dict with defaults applied. + + Raises: + McpDefinitionError: if the definition is missing required fields + or uses an unsupported transport. + """ + if not isinstance(definition, dict) or definition.get("type") != "mcp": + raise McpDefinitionError("definition.type must be 'mcp'") + + config = definition.get("config") + if not isinstance(config, dict): + raise McpDefinitionError("definition.config is required and must be an object") + + try: + parsed = McpToolDefinition.model_validate(definition) + except ValidationError as e: + raise McpDefinitionError(_format_validation_error(e)) from e + + return parsed.config.model_dump(exclude={"discovered_tools"}) + + +def _slugify(value: str) -> str: + slug = re.sub(r"[^a-z0-9]+", "_", value.strip().lower()).strip("_") + return slug + + +def namespace_function_name( + tool_name: str, mcp_tool_name: str, *, fallback: str = "server" +) -> str: + """Build a collision-safe LLM function name: ``mcp____``. + + ``slug`` is derived from the Dograh ToolModel name; if it slugifies to + empty, ``fallback`` (e.g. first 8 chars of tool_uuid) is used instead. + """ + slug = _slugify(tool_name) or _slugify(fallback) or "server" + return f"mcp__{slug}__{mcp_tool_name}" diff --git a/api/services/workflow/workflow_graph.py b/api/services/workflow/workflow_graph.py index fc2be79..2a2fa50 100644 --- a/api/services/workflow/workflow_graph.py +++ b/api/services/workflow/workflow_graph.py @@ -89,6 +89,7 @@ class Node: self.delayed_start_duration = getattr(data, "delayed_start_duration", None) self.tool_uuids = getattr(data, "tool_uuids", None) self.document_uuids = getattr(data, "document_uuids", None) + self.mcp_tool_filters = getattr(data, "mcp_tool_filters", None) self.pre_call_fetch_enabled = getattr(data, "pre_call_fetch_enabled", False) self.pre_call_fetch_url = getattr(data, "pre_call_fetch_url", None) self.pre_call_fetch_credential_uuid = getattr( diff --git a/api/tests/support/__init__.py b/api/tests/support/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/support/mcp_mock_server.py b/api/tests/support/mcp_mock_server.py new file mode 100644 index 0000000..09a3c8b --- /dev/null +++ b/api/tests/support/mcp_mock_server.py @@ -0,0 +1,103 @@ +"""A real FastMCP server exposing 2 tools over streamable-HTTP, run in a +background uvicorn thread on an ephemeral port. Used to exercise the real +MCP protocol path in tests. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import socket +import threading +from typing import AsyncIterator + +import httpx +import uvicorn +from fastmcp import FastMCP +from starlette.responses import JSONResponse + + +def _build_app(required_headers: dict[str, str] | None = None): + mcp = FastMCP("mock-mcp") + + @mcp.tool() + def echo(text: str) -> str: + """Echo the provided text back.""" + return f"echo:{text}" + + @mcp.tool() + def add(a: int, b: int) -> int: + """Add two integers.""" + return a + b + + # FastMCP 3.x: ASGI app for streamable-HTTP transport at "/mcp". + app = mcp.http_app() + if not required_headers: + return app + + normalized = {k.lower(): v for k, v in required_headers.items()} + + async def guarded_app(scope, receive, send): + if scope["type"] == "http": + headers = { + key.decode("latin-1").lower(): value.decode("latin-1") + for key, value in scope.get("headers", []) + } + for header_name, expected_value in normalized.items(): + if headers.get(header_name) != expected_value: + response = JSONResponse( + {"detail": f"Missing or invalid header: {header_name}"}, + status_code=401, + ) + await response(scope, receive, send) + return + await app(scope, receive, send) + + return guarded_app + + +def _free_port() -> int: + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@contextlib.asynccontextmanager +async def running_mcp_server( + *, required_headers: dict[str, str] | None = None +) -> AsyncIterator[str]: + """Yield the base streamable-HTTP URL of a live mock MCP server.""" + port = _free_port() + config = uvicorn.Config( + _build_app(required_headers), host="127.0.0.1", port=port, log_level="warning" + ) + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + base_url = f"http://127.0.0.1:{port}/mcp" + server_ready = False + for _ in range(50): + try: + async with httpx.AsyncClient() as client: + await client.get(base_url, timeout=0.5) + server_ready = True + break + except Exception: + await asyncio.sleep(0.1) + if not server_ready: + server.should_exit = True + thread.join(timeout=5) + raise RuntimeError(f"Mock MCP server at {base_url} failed to start within 5s") + try: + yield base_url + finally: + server.should_exit = True + thread.join(timeout=5) + if thread.is_alive(): + import warnings + + warnings.warn( + "Mock MCP server thread did not terminate within 5s", + ResourceWarning, + ) diff --git a/api/tests/test_mcp_auth.py b/api/tests/test_mcp_auth.py new file mode 100644 index 0000000..c6c10c1 --- /dev/null +++ b/api/tests/test_mcp_auth.py @@ -0,0 +1,63 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from api.mcp_server.auth import authenticate_mcp_request + + +@pytest.mark.asyncio +async def test_authenticate_mcp_request_accepts_bearer_authorization(): + user = MagicMock() + user.id = 1 + user.selected_organization_id = 90 + + with ( + patch( + "api.mcp_server.auth.get_http_headers", + return_value={"authorization": "Bearer secret-api-key"}, + ) as get_headers, + patch( + "api.mcp_server.auth._handle_api_key_auth", + AsyncMock(return_value=user), + ) as handle_auth, + ): + authed = await authenticate_mcp_request() + + assert authed is user + get_headers.assert_called_once_with(include={"authorization"}) + handle_auth.assert_awaited_once_with("secret-api-key") + + +@pytest.mark.asyncio +async def test_authenticate_mcp_request_accepts_x_api_key(): + user = MagicMock() + user.id = 2 + user.selected_organization_id = 91 + + with ( + patch( + "api.mcp_server.auth.get_http_headers", + return_value={"x-api-key": "secret-api-key"}, + ) as get_headers, + patch( + "api.mcp_server.auth._handle_api_key_auth", + AsyncMock(return_value=user), + ) as handle_auth, + ): + authed = await authenticate_mcp_request() + + assert authed is user + get_headers.assert_called_once_with(include={"authorization"}) + handle_auth.assert_awaited_once_with("secret-api-key") + + +@pytest.mark.asyncio +async def test_authenticate_mcp_request_rejects_missing_api_key(): + with patch("api.mcp_server.auth.get_http_headers", return_value={}) as get_headers: + with pytest.raises(HTTPException) as exc_info: + await authenticate_mcp_request() + + assert exc_info.value.status_code == 401 + assert "Missing API key" in str(exc_info.value.detail) + get_headers.assert_called_once_with(include={"authorization"}) diff --git a/api/tests/test_mcp_custom_tool_manager.py b/api/tests/test_mcp_custom_tool_manager.py new file mode 100644 index 0000000..b40776a --- /dev/null +++ b/api/tests/test_mcp_custom_tool_manager.py @@ -0,0 +1,181 @@ +import uuid +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from api.enums import ToolCategory +from api.services.workflow.mcp_tool_session import McpToolSession +from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager +from api.tests.support.mcp_mock_server import running_mcp_server + + +def _mcp_tool(): + t = MagicMock() + t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8] + t.name = "Acme MCP" + t.category = ToolCategory.MCP.value + t.definition = {"type": "mcp", "config": {"url": "https://x/mcp"}} + return t + + +@pytest.mark.asyncio +async def test_get_tool_schemas_and_handler_for_mcp(monkeypatch): + async with running_mcp_server() as base_url: + tool = _mcp_tool() + session = McpToolSession( + tool_uuid=tool.tool_uuid, + tool_name=tool.name, + url=base_url, + credential=None, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=10, + ) + await session.start() + + engine = MagicMock() + engine._mcp_sessions = {tool.tool_uuid: session} + registered = {} + reg_kwargs = {} + + def _reg(name, fn, **kw): + registered[name] = fn + reg_kwargs[name] = kw + + engine.llm.register_function = _reg + + mgr = CustomToolManager(engine) + mgr.get_organization_id = AsyncMock(return_value=42) + + from api.db import db_client + + monkeypatch.setattr( + db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]) + ) + + try: + schemas = await mgr.get_tool_schemas([tool.tool_uuid]) + names = sorted(s.name for s in schemas) + assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"] + + await mgr.register_handlers([tool.tool_uuid]) + assert "mcp__acme_mcp__echo" in registered + assert reg_kwargs["mcp__acme_mcp__echo"]["timeout_secs"] == pytest.approx( + 15.0 + ) + + captured = {} + + class P: + function_name = "mcp__acme_mcp__echo" + arguments = {"text": "yo"} + + async def result_callback(self, r, *, properties=None): + captured["r"] = r + + await registered["mcp__acme_mcp__echo"](P()) + assert "echo:yo" in str(captured["r"]) + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_unavailable_mcp_session_contributes_nothing(monkeypatch): + tool = _mcp_tool() + session = McpToolSession( + tool_uuid=tool.tool_uuid, + tool_name=tool.name, + url="http://127.0.0.1:1/mcp", + credential=None, + tools_filter=[], + timeout_secs=1, + sse_read_timeout_secs=1, + ) + await session.start() # degrades + + engine = MagicMock() + engine._mcp_sessions = {tool.tool_uuid: session} + mgr = CustomToolManager(engine) + mgr.get_organization_id = AsyncMock(return_value=42) + + from api.db import db_client + + monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])) + + schemas = await mgr.get_tool_schemas([tool.tool_uuid]) + assert schemas == [] + await mgr.register_handlers([tool.tool_uuid]) # must not raise + + +def test_call_timeout_secs_is_read_timeout_plus_buffer(): + session = McpToolSession( + tool_uuid="uuid-abc123", + tool_name="Acme MCP", + url="https://x/mcp", + credential=None, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=20, + ) + assert session.call_timeout_secs == 25.0 + + +@pytest.mark.asyncio +async def test_per_node_mcp_filter_intersection(monkeypatch): + async with running_mcp_server() as base_url: + tool = _mcp_tool() + session = McpToolSession( + tool_uuid=tool.tool_uuid, + tool_name=tool.name, + url=base_url, + credential=None, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=10, + ) + await session.start() + + engine = MagicMock() + engine._mcp_sessions = {tool.tool_uuid: session} + registered = {} + engine.llm.register_function = lambda name, fn, **kw: registered.__setitem__( + name, fn + ) + + mgr = CustomToolManager(engine) + mgr.get_organization_id = AsyncMock(return_value=42) + + from api.db import db_client + + monkeypatch.setattr( + db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]) + ) + try: + # Allow only raw "echo" for this node + filters = {tool.tool_uuid: ["echo"]} + schemas = await mgr.get_tool_schemas( + [tool.tool_uuid], mcp_tool_filters=filters + ) + # Check only "echo" schema returned (namespaced name depends on tool.name) + assert len(schemas) == 1 + assert all("echo" in s.name for s in schemas) + + await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters=filters) + assert len(registered) == 1 + assert all("echo" in k for k in registered) + + # No filter entry for this uuid = none (default-none) + registered.clear() + result = await mgr.get_tool_schemas([tool.tool_uuid], mcp_tool_filters={}) + assert result == [] + await mgr.register_handlers([tool.tool_uuid], mcp_tool_filters={}) + assert registered == {} + + # mcp_tool_filters=None = backward-compatible (all tools) + registered.clear() + all_schemas = await mgr.get_tool_schemas([tool.tool_uuid]) + assert len(all_schemas) == 2 # both echo and add + await mgr.register_handlers([tool.tool_uuid]) + assert len(registered) == 2 # both handlers registered + finally: + await session.close() diff --git a/api/tests/test_mcp_integration.py b/api/tests/test_mcp_integration.py new file mode 100644 index 0000000..4cf01f0 --- /dev/null +++ b/api/tests/test_mcp_integration.py @@ -0,0 +1,107 @@ +import uuid +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from api.enums import ToolCategory +from api.services.workflow.pipecat_engine import PipecatEngine +from api.tests.support.mcp_mock_server import running_mcp_server + + +def _mcp_tool(url: str): + t = MagicMock() + t.tool_uuid = "uuid-" + uuid.uuid4().hex[:8] + t.name = "Acme MCP" + t.category = ToolCategory.MCP.value + t.definition = { + "schema_version": 1, + "type": "mcp", + "config": {"transport": "streamable_http", "url": url}, + } + return t + + +@pytest.mark.asyncio +async def test_engine_opens_and_closes_mcp_sessions(monkeypatch): + async with running_mcp_server() as base_url: + tool = _mcp_tool(base_url) + + engine = PipecatEngine.__new__(PipecatEngine) + node = MagicMock() + node.tool_uuids = [tool.tool_uuid] + workflow = MagicMock() + workflow.nodes = {"n1": node} + engine.workflow = workflow + engine._mcp_sessions = {} + + from api.db import db_client + + monkeypatch.setattr( + db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool]) + ) + monkeypatch.setattr( + db_client, "get_credential_by_uuid", AsyncMock(return_value=None) + ) + engine._get_organization_id = AsyncMock(return_value=42) + + await engine._open_mcp_sessions() + try: + assert tool.tool_uuid in engine._mcp_sessions + sess = engine._mcp_sessions[tool.tool_uuid] + assert sess.available is True + assert len(sess.function_schemas()) == 2 + finally: + await engine._close_mcp_sessions() + assert engine._mcp_sessions == {} + + +@pytest.mark.asyncio +async def test_open_mcp_sessions_swallows_db_error(monkeypatch): + engine = PipecatEngine.__new__(PipecatEngine) + node = MagicMock() + node.tool_uuids = ["uuid-deadbeef"] + workflow = MagicMock() + workflow.nodes = {"n1": node} + engine.workflow = workflow + engine._mcp_sessions = {} + + from api.db import db_client + + monkeypatch.setattr( + db_client, + "get_tools_by_uuids", + AsyncMock(side_effect=RuntimeError("db down")), + ) + engine._get_organization_id = AsyncMock(return_value=42) + + # Must NOT raise + await engine._open_mcp_sessions() + assert engine._mcp_sessions == {} + + +@pytest.mark.asyncio +async def test_open_mcp_sessions_skips_tool_when_credential_fetch_fails(monkeypatch): + tool = _mcp_tool("http://127.0.0.1:1/mcp") + tool.definition["config"]["credential_uuid"] = "cred-1234" + + engine = PipecatEngine.__new__(PipecatEngine) + node = MagicMock() + node.tool_uuids = [tool.tool_uuid] + workflow = MagicMock() + workflow.nodes = {"n1": node} + engine.workflow = workflow + engine._mcp_sessions = {} + + from api.db import db_client + + monkeypatch.setattr(db_client, "get_tools_by_uuids", AsyncMock(return_value=[tool])) + monkeypatch.setattr( + db_client, + "get_credential_by_uuid", + AsyncMock(side_effect=RuntimeError("cred store down")), + ) + engine._get_organization_id = AsyncMock(return_value=42) + + # Must NOT raise, and must skip the tool (no futile unauthenticated start) + await engine._open_mcp_sessions() + assert engine._mcp_sessions == {} diff --git a/api/tests/test_mcp_tool_definition.py b/api/tests/test_mcp_tool_definition.py new file mode 100644 index 0000000..5158c3a --- /dev/null +++ b/api/tests/test_mcp_tool_definition.py @@ -0,0 +1,112 @@ +import importlib + +import pytest + +from api.enums import ToolCategory +from api.routes.tool import McpToolConfig as RouteMcpToolConfig +from api.routes.tool import McpToolDefinition as RouteMcpToolDefinition +from api.services.workflow.tools.mcp_tool import ( + McpDefinitionError, + McpToolConfig, + McpToolDefinition, + namespace_function_name, + validate_mcp_definition, +) + + +def test_mcp_category_exists(): + assert ToolCategory.MCP.value == "mcp" + assert ToolCategory("mcp") is ToolCategory.MCP + + +def test_mcp_migration_present_and_chained(monkeypatch): + mod = importlib.import_module( + "api.alembic.versions.0a1b2c3d4e5f_add_mcp_in_toolcategory" + ) + assert mod.revision == "0a1b2c3d4e5f" + assert mod.down_revision == "4c1f1e3e8ef2" + + calls = [] + + def fake_sync_enum_values(**kwargs): + calls.append(kwargs) + + monkeypatch.setattr(mod.op, "sync_enum_values", fake_sync_enum_values) + + mod.upgrade() + mod.downgrade() + + assert len(calls) == 2 + assert calls[0]["enum_name"] == "tool_category" + assert "mcp" in calls[0]["new_values"] + assert "mcp" not in calls[1]["new_values"] + + +def test_route_reuses_shared_mcp_models(): + assert RouteMcpToolConfig is McpToolConfig + assert RouteMcpToolDefinition is McpToolDefinition + + +def test_validate_mcp_definition_ok(): + cfg = validate_mcp_definition( + { + "schema_version": 1, + "type": "mcp", + "config": { + "transport": "streamable_http", + "url": "https://acme.example.com/mcp", + "credential_uuid": "cred-123", + "tools_filter": ["lookup_patient"], + "timeout_secs": 30, + "sse_read_timeout_secs": 300, + }, + } + ) + assert cfg["url"] == "https://acme.example.com/mcp" + assert cfg["transport"] == "streamable_http" + assert cfg["tools_filter"] == ["lookup_patient"] + assert cfg["timeout_secs"] == 30 + assert cfg["sse_read_timeout_secs"] == 300 + assert cfg["credential_uuid"] == "cred-123" + + +def test_validate_mcp_definition_defaults(): + cfg = validate_mcp_definition({"type": "mcp", "config": {"url": "https://x/mcp"}}) + assert cfg["transport"] == "streamable_http" + assert cfg["tools_filter"] == [] + assert cfg["timeout_secs"] == 30 + assert cfg["sse_read_timeout_secs"] == 300 + assert cfg["credential_uuid"] is None + + +@pytest.mark.parametrize( + "definition", + [ + {"type": "mcp", "config": {}}, + {"type": "mcp", "config": {"url": ""}}, + {"type": "mcp", "config": {"url": "ftp://x"}}, + {"type": "mcp"}, + {"type": "mcp", "config": {"url": "https://x", "transport": "stdio"}}, + ], +) +def test_validate_mcp_definition_rejects(definition): + with pytest.raises(McpDefinitionError): + validate_mcp_definition(definition) + + +def test_validate_mcp_definition_zero_timeout_preserved(): + cfg = validate_mcp_definition( + {"type": "mcp", "config": {"url": "https://x/mcp", "timeout_secs": 0}} + ) + assert cfg["timeout_secs"] == 0 + + +def test_namespace_function_name(): + assert ( + namespace_function_name("Acme MCP", "lookup_patient") + == "mcp__acme_mcp__lookup_patient" + ) + assert ( + namespace_function_name("", "ping", fallback="abcd1234") + == "mcp__abcd1234__ping" + ) diff --git a/api/tests/test_mcp_tool_route.py b/api/tests/test_mcp_tool_route.py new file mode 100644 index 0000000..a16f75c --- /dev/null +++ b/api/tests/test_mcp_tool_route.py @@ -0,0 +1,437 @@ +"""Route-level tests for the MCP tool definition schema. + +These tests exercise the Pydantic request models (CreateToolRequest / +UpdateToolRequest) to catch schema gaps at the route/request-model layer — +the layer where the pre-fix defect lived (HTTP 422 on every MCP tool +creation attempt). + +Test coverage: +- CreateToolRequest validates a valid MCP definition (was 422 before Part A). +- UpdateToolRequest validates a valid MCP definition. +- Invalid MCP bodies are rejected (ftp:// url, missing url). +- Round-trip: validated definition dict passes through validate_mcp_definition + unchanged, proving the request schema and call-time validator agree. +- Full HTTP round-trip via the ASGI test client (POST /api/v1/tools/). +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest +from api.services.workflow.tools.mcp_tool import ( + validate_mcp_definition, +) + +# ── Canonical valid MCP request body ───────────────────────────────────────── + +VALID_MCP_DEFINITION = { + "schema_version": 1, + "type": "mcp", + "config": { + "transport": "streamable_http", + "url": "https://x/mcp", + "credential_uuid": None, + "tools_filter": [], + }, +} + + +# ── Part A regression: CreateToolRequest / UpdateToolRequest validation ─────── + + +def test_create_tool_request_accepts_mcp_definition(): + """CreateToolRequest must accept an MCP definition (was HTTP 422 before fix).""" + req = CreateToolRequest( + name="My MCP Tool", + description="Integration via MCP", + category="mcp", + definition=VALID_MCP_DEFINITION, + ) + assert isinstance(req.definition, McpToolDefinition) + assert req.definition.type == "mcp" + assert req.definition.config.url == "https://x/mcp" + assert req.definition.config.transport == "streamable_http" + assert req.definition.config.credential_uuid is None + assert req.definition.config.tools_filter == [] + assert req.definition.config.timeout_secs == 30 + assert req.definition.config.sse_read_timeout_secs == 300 + + +def test_update_tool_request_accepts_mcp_definition(): + """UpdateToolRequest must also accept an MCP definition.""" + req = UpdateToolRequest( + name="Updated MCP Tool", + definition=VALID_MCP_DEFINITION, + ) + assert isinstance(req.definition, McpToolDefinition) + assert req.definition.type == "mcp" + assert req.definition.config.url == "https://x/mcp" + + +def test_create_tool_request_accepts_mcp_with_all_fields(): + """All optional MCP config fields are accepted and preserved.""" + req = CreateToolRequest( + name="Full MCP Tool", + category="mcp", + definition={ + "schema_version": 1, + "type": "mcp", + "config": { + "transport": "streamable_http", + "url": "https://acme.example.com/mcp", + "credential_uuid": "cred-abc-123", + "tools_filter": ["lookup_patient", "schedule_appointment"], + "timeout_secs": 60, + "sse_read_timeout_secs": 600, + }, + }, + ) + cfg = req.definition.config # type: ignore[union-attr] + assert cfg.url == "https://acme.example.com/mcp" + assert cfg.credential_uuid == "cred-abc-123" + assert cfg.tools_filter == ["lookup_patient", "schedule_appointment"] + assert cfg.timeout_secs == 60 + assert cfg.sse_read_timeout_secs == 600 + + +# ── Invalid bodies are rejected ─────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "definition", + [ + # ftp:// URL — rejected by McpToolConfig.validate_url + { + "schema_version": 1, + "type": "mcp", + "config": {"transport": "streamable_http", "url": "ftp://x/mcp"}, + }, + # Empty url — rejected by McpToolConfig.validate_url + { + "schema_version": 1, + "type": "mcp", + "config": {"transport": "streamable_http", "url": ""}, + }, + # Missing url — rejected by McpToolConfig (required field) + { + "schema_version": 1, + "type": "mcp", + "config": {"transport": "streamable_http"}, + }, + # Unsupported transport — rejected because Literal["streamable_http"] constraint + { + "schema_version": 1, + "type": "mcp", + "config": {"url": "https://x/mcp", "transport": "stdio"}, + }, + ], +) +def test_create_tool_request_rejects_invalid_mcp_definition(definition): + """Invalid MCP definitions must raise ValidationError.""" + with pytest.raises(ValidationError): + CreateToolRequest( + name="Bad MCP Tool", + category="mcp", + definition=definition, + ) + + +# ── Round-trip compatibility: request schema ↔ validate_mcp_definition ─────── + + +def test_mcp_definition_round_trips_through_validate_mcp_definition(): + """The dict produced by CreateToolRequest.definition.model_dump() must be + accepted by validate_mcp_definition without raising, and the result must + contain the expected fields. This proves the request-layer schema and the + call-time validator agree on the stored config shape.""" + req = CreateToolRequest( + name="Round-Trip MCP Tool", + category="mcp", + definition={ + "schema_version": 1, + "type": "mcp", + "config": { + "transport": "streamable_http", + "url": "https://roundtrip.example.com/mcp", + "credential_uuid": "cred-rt-456", + "tools_filter": ["ping"], + "timeout_secs": 45, + "sse_read_timeout_secs": 400, + }, + }, + ) + + # Simulate what the route does: persist definition as a plain dict + persisted = req.definition.model_dump() # type: ignore[union-attr] + + # validate_mcp_definition must accept the persisted shape without raising + normalized = validate_mcp_definition(persisted) + + assert normalized["url"] == "https://roundtrip.example.com/mcp" + assert normalized["transport"] == "streamable_http" + assert normalized["credential_uuid"] == "cred-rt-456" + assert normalized["tools_filter"] == ["ping"] + assert normalized["timeout_secs"] == 45 + assert normalized["sse_read_timeout_secs"] == 400 + + +def test_mcp_definition_round_trip_defaults(): + """Round-trip with minimal body: defaults fill in correctly and + validate_mcp_definition agrees on them.""" + req = CreateToolRequest( + name="Minimal MCP Tool", + category="mcp", + definition=VALID_MCP_DEFINITION, + ) + + persisted = req.definition.model_dump() # type: ignore[union-attr] + normalized = validate_mcp_definition(persisted) + + assert normalized["transport"] == "streamable_http" + assert normalized["tools_filter"] == [] + assert normalized["timeout_secs"] == 30 + assert normalized["sse_read_timeout_secs"] == 300 + assert normalized["credential_uuid"] is None + # Part B: auth_header / auth_scheme must NOT be present in the normalized + # config dict (they were dead config removed in the fix) + assert "auth_header" not in normalized + assert "auth_scheme" not in normalized + + +# ── Full HTTP round-trip via ASGI test client ───────────────────────────────── + + +async def test_post_tool_mcp_returns_200(test_client_factory, db_session): + """POST /api/v1/tools/ with an MCP definition must return HTTP 200 and + persist the definition with type='mcp'. Before Part A this always + returned 422.""" + # Create a user and an organization, then link them so the route's + # selected_organization_id check passes. + user, _ = await db_session.get_or_create_user_by_provider_id("mcp_route_test_user") + org, _ = await db_session.get_or_create_organization_by_provider_id( + "mcp_route_test_org", user.id + ) + await db_session.update_user_selected_organization(user.id, org.id) + # Reload the user so selected_organization_id is populated on the object. + user = await db_session.get_user_by_id(user.id) + + async with test_client_factory(user) as client: + response = await client.post( + "/api/v1/tools/", + json={ + "name": "HTTP Round-Trip MCP Tool", + "description": "Testing the full route", + "category": "mcp", + "definition": { + "schema_version": 1, + "type": "mcp", + "config": { + "transport": "streamable_http", + "url": "https://roundtrip.example.com/mcp", + "credential_uuid": None, + "tools_filter": [], + }, + }, + }, + ) + + assert response.status_code == 200, ( + f"Expected 200, got {response.status_code}: {response.text}" + ) + body = response.json() + assert body["definition"]["type"] == "mcp" + assert body["definition"]["config"]["url"] == "https://roundtrip.example.com/mcp" + assert body["category"] == "mcp" + + +async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_session): + """POST /api/v1/tools/ with an ftp:// URL must return HTTP 422.""" + user, _ = await db_session.get_or_create_user_by_provider_id( + "mcp_route_test_user_422" + ) + org, _ = await db_session.get_or_create_organization_by_provider_id( + "mcp_route_test_org_422", user.id + ) + await db_session.update_user_selected_organization(user.id, org.id) + user = await db_session.get_user_by_id(user.id) + + async with test_client_factory(user) as client: + response = await client.post( + "/api/v1/tools/", + json={ + "name": "Bad MCP Tool", + "category": "mcp", + "definition": { + "schema_version": 1, + "type": "mcp", + "config": { + "transport": "streamable_http", + "url": "ftp://invalid.example.com/mcp", + }, + }, + }, + ) + + assert response.status_code == 422 + + +# ── Task 6: discovered_tools field and _populate_discovered_tools helper ────── + +from unittest.mock import AsyncMock, MagicMock + +from api.routes.tool import McpToolConfig, _populate_discovered_tools + + +def test_mcp_config_accepts_discovered_tools(): + cfg = McpToolConfig( + url="https://x/mcp", + discovered_tools=[{"name": "echo", "description": "Echo"}], + ) + assert cfg.discovered_tools == [{"name": "echo", "description": "Echo"}] + # Defaults to [] when omitted + assert McpToolConfig(url="https://x/mcp").discovered_tools == [] + + +@pytest.mark.asyncio +async def test_populate_discovered_tools_overwrites_cache(monkeypatch): + import api.routes.tool as tool_mod + + monkeypatch.setattr( + tool_mod, + "discover_mcp_tools", + AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]), + ) + definition = { + "schema_version": 1, + "type": "mcp", + "config": { + "url": "https://x/mcp", + "tools_filter": [], + "discovered_tools": [{"name": "stale", "description": "old"}], + }, + } + out = await _populate_discovered_tools(definition, organization_id=1) + assert out["config"]["discovered_tools"] == [ + {"name": "echo", "description": "Echo"} + ] + + +@pytest.mark.asyncio +async def test_populate_discovered_tools_non_mcp_is_noop(): + definition = {"schema_version": 1, "type": "http_api", "config": {}} + out = await _populate_discovered_tools(definition, organization_id=1) + assert out == definition # untouched + + +@pytest.mark.asyncio +async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch): + import api.routes.tool as tool_mod + + monkeypatch.setattr( + tool_mod, + "discover_mcp_tools", + AsyncMock(side_effect=RuntimeError("connection refused")), + ) + definition = { + "schema_version": 1, + "type": "mcp", + "config": {"url": "https://x/mcp", "tools_filter": []}, + } + out = await _populate_discovered_tools(definition, organization_id=1) + assert out["config"]["discovered_tools"] == [] + + +# ── Task 7: POST /{tool_uuid}/mcp/refresh ───────────────────────────────────── + +from fastapi import HTTPException + +from api.routes.tool import refresh_mcp_tools + + +def _fake_user(org_id=1): + u = MagicMock() + u.selected_organization_id = org_id + u.id = 1 + u.provider_id = "p1" + return u + + +def _mcp_tool_model(org_id=1): + t = MagicMock() + t.tool_uuid = "tu-mcp" + t.name = "Mock MCP" + t.category = "mcp" + t.definition = { + "schema_version": 1, + "type": "mcp", + "config": {"url": "https://x/mcp", "tools_filter": []}, + } + return t + + +@pytest.mark.asyncio +async def test_refresh_success(monkeypatch): + import api.routes.tool as tool_mod + + tool = _mcp_tool_model() + monkeypatch.setattr( + tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool) + ) + monkeypatch.setattr( + tool_mod.db_client, + "update_tool", + AsyncMock(return_value=tool), + ) + monkeypatch.setattr( + tool_mod, + "discover_mcp_tools", + AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]), + ) + resp = await refresh_mcp_tools("tu-mcp", user=_fake_user()) + assert resp.discovered_tools == [{"name": "echo", "description": "Echo"}] + assert resp.error is None + + +@pytest.mark.asyncio +async def test_refresh_server_down_returns_200_with_error(monkeypatch): + import api.routes.tool as tool_mod + + tool = _mcp_tool_model() + monkeypatch.setattr( + tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool) + ) + monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool)) + monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[])) + resp = await refresh_mcp_tools("tu-mcp", user=_fake_user()) + assert resp.discovered_tools == [] + assert resp.error # non-empty human-readable message + # update_tool should NOT be called when discovery returns empty + tool_mod.db_client.update_tool.assert_not_called() + + +@pytest.mark.asyncio +async def test_refresh_non_mcp_is_400(monkeypatch): + import api.routes.tool as tool_mod + + tool = _mcp_tool_model() + tool.category = "http_api" + monkeypatch.setattr( + tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool) + ) + with pytest.raises(HTTPException) as ei: + await refresh_mcp_tools("tu-mcp", user=_fake_user()) + assert ei.value.status_code == 400 + + +@pytest.mark.asyncio +async def test_refresh_not_found_is_404(monkeypatch): + import api.routes.tool as tool_mod + + monkeypatch.setattr( + tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None) + ) + with pytest.raises(HTTPException) as ei: + await refresh_mcp_tools("nope", user=_fake_user()) + assert ei.value.status_code == 404 diff --git a/api/tests/test_mcp_tool_session.py b/api/tests/test_mcp_tool_session.py new file mode 100644 index 0000000..eab3d5c --- /dev/null +++ b/api/tests/test_mcp_tool_session.py @@ -0,0 +1,274 @@ +from datetime import timedelta +from unittest.mock import MagicMock + +import httpx +import pytest + +from api.services.workflow.mcp_tool_session import ( + McpToolSession, + build_streamable_http_params, + discover_mcp_tools, +) +from api.tests.support.mcp_mock_server import running_mcp_server + + +@pytest.mark.asyncio +async def test_mock_server_starts_and_serves(): + async with running_mcp_server() as base_url: + async with httpx.AsyncClient() as client: + resp = await client.get(base_url, timeout=5.0) + assert resp.status_code in (400, 404, 405, 406) + + +def test_build_streamable_http_params_with_credential(): + cred = MagicMock() + cred.credential_type = "bearer_token" + cred.credential_data = {"token": "abc"} + params = build_streamable_http_params( + url="https://acme.example.com/mcp", + credential=cred, + timeout_secs=30, + sse_read_timeout_secs=300, + ) + assert params.url == "https://acme.example.com/mcp" + assert params.headers == {"Authorization": "Bearer abc"} + assert params.timeout == timedelta(seconds=30) + assert params.sse_read_timeout == timedelta(seconds=300) + + +def test_build_streamable_http_params_no_credential(): + params = build_streamable_http_params( + url="https://acme.example.com/mcp", + credential=None, + timeout_secs=10, + sse_read_timeout_secs=20, + ) + assert params.headers is None or params.headers == {} + + +@pytest.mark.asyncio +async def test_session_start_passes_auth_header_to_real_server(): + cred = MagicMock() + cred.credential_type = "bearer_token" + cred.credential_data = {"token": "abc"} + + async with running_mcp_server( + required_headers={"Authorization": "Bearer abc"} + ) as base_url: + session = McpToolSession( + tool_uuid="uuid-auth-ok", + tool_name="Secure MCP", + url=base_url, + credential=cred, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=20, + ) + await session.start() + try: + assert session.available is True + names = sorted(s.name for s in session.function_schemas()) + assert names == ["mcp__secure_mcp__add", "mcp__secure_mcp__echo"] + result = await session.call("mcp__secure_mcp__echo", {"text": "hi"}) + assert "echo:hi" in result + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_session_auth_failure_degrades_not_raises(): + async with running_mcp_server( + required_headers={"Authorization": "Bearer abc"} + ) as base_url: + session = McpToolSession( + tool_uuid="uuid-auth-fail", + tool_name="Secure MCP", + url=base_url, + credential=None, + tools_filter=[], + timeout_secs=2, + sse_read_timeout_secs=2, + ) + await session.start() # must degrade instead of raising on 401 + try: + assert session.available is False + assert session.function_schemas() == [] + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_session_start_lists_and_calls_real_server(): + async with running_mcp_server() as base_url: + session = McpToolSession( + tool_uuid="uuid-1234abcd", + tool_name="Acme MCP", + url=base_url, + credential=None, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=20, + ) + await session.start() + try: + assert session.available is True + schemas = session.function_schemas() + names = sorted(s.name for s in schemas) + assert names == ["mcp__acme_mcp__add", "mcp__acme_mcp__echo"] + result = await session.call("mcp__acme_mcp__echo", {"text": "hi"}) + assert "echo:hi" in result + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_session_tools_filter_applied(): + async with running_mcp_server() as base_url: + session = McpToolSession( + tool_uuid="uuid-1234abcd", + tool_name="Acme MCP", + url=base_url, + credential=None, + tools_filter=["echo"], + timeout_secs=10, + sse_read_timeout_secs=20, + ) + await session.start() + try: + names = sorted(s.name for s in session.function_schemas()) + assert names == ["mcp__acme_mcp__echo"] + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_session_unreachable_degrades_not_raises(): + session = McpToolSession( + tool_uuid="uuid-1234abcd", + tool_name="Acme MCP", + url="http://127.0.0.1:1/mcp", + credential=None, + tools_filter=[], + timeout_secs=2, + sse_read_timeout_secs=2, + ) + await session.start() # must NOT raise + assert session.available is False + assert session.function_schemas() == [] + await session.close() + + +@pytest.mark.asyncio +async def test_call_on_unavailable_session_raises(): + session = McpToolSession( + tool_uuid="uuid-1234abcd", + tool_name="Acme MCP", + url="http://127.0.0.1:1/mcp", + credential=None, + tools_filter=[], + timeout_secs=2, + sse_read_timeout_secs=2, + ) + await session.start() + with pytest.raises(RuntimeError): + await session.call("mcp__acme_mcp__echo", {"text": "x"}) + await session.close() + + +@pytest.mark.asyncio +async def test_call_unknown_function_raises(): + async with running_mcp_server() as base_url: + session = McpToolSession( + tool_uuid="uuid-1234abcd", + tool_name="Acme MCP", + url=base_url, + credential=None, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=10, + ) + await session.start() + try: + with pytest.raises(RuntimeError): + await session.call("mcp__acme_mcp__does_not_exist", {}) + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_function_schemas_filter_by_raw_name(): + async with running_mcp_server() as base_url: + session = McpToolSession( + tool_uuid="t-filter", + tool_name="Mock MCP", + url=base_url, + credential=None, + tools_filter=[], + timeout_secs=10, + sse_read_timeout_secs=10, + ) + await session.start() + try: + # No arg = all (backward compatible) + all_names = sorted(s.name for s in session.function_schemas()) + assert all_names == ["mcp__mock_mcp__add", "mcp__mock_mcp__echo"] + + # Allow only raw "echo" + only_echo = session.function_schemas(allowed_raw_names={"echo"}) + assert [s.name for s in only_echo] == ["mcp__mock_mcp__echo"] + + # Empty set = none (default-none semantics) + assert session.function_schemas(allowed_raw_names=set()) == [] + + # Unknown raw name = skipped (pure intersection) + assert session.function_schemas(allowed_raw_names={"nope"}) == [] + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_discover_mcp_tools_success(): + async with running_mcp_server() as base_url: + tools = await discover_mcp_tools( + url=base_url, + credential=None, + timeout_secs=10, + sse_read_timeout_secs=10, + ) + names = sorted(t["name"] for t in tools) + assert names == ["add", "echo"] + by_name = {t["name"]: t for t in tools} + assert by_name["echo"]["description"] # non-empty description + assert set(by_name["echo"]) == {"name", "description"} + + +@pytest.mark.asyncio +async def test_discover_mcp_tools_server_down_returns_empty(): + # Unroutable port, short timeouts: must degrade to [] (never raise). + tools = await discover_mcp_tools( + url="http://127.0.0.1:1/mcp", + credential=None, + timeout_secs=1, + sse_read_timeout_secs=1, + ) + assert tools == [] + + +def test_agent_node_data_carries_mcp_tool_filters(): + from api.services.workflow.dto import AgentNodeData, NodeType + from api.services.workflow.workflow_graph import Node + + data = AgentNodeData( + name="N1", + tool_uuids=["tu-1"], + mcp_tool_filters={"tu-1": ["echo"]}, + ) + assert data.mcp_tool_filters == {"tu-1": ["echo"]} + + node = Node("n1", NodeType.agentNode, data) + assert node.mcp_tool_filters == {"tu-1": ["echo"]} + + # Absent field defaults to None (backward compatible) + data2 = AgentNodeData(name="N2") + assert data2.mcp_tool_filters is None + assert Node("n2", NodeType.agentNode, data2).mcp_tool_filters is None diff --git a/docs/docs.json b/docs/docs.json index 09b55f5..13fb7f9 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -71,7 +71,8 @@ { "group": "Custom Tools", "pages": [ - "voice-agent/tools/http-api" + "voice-agent/tools/http-api", + "voice-agent/tools/mcp-tool" ] } ] @@ -308,4 +309,4 @@ "linkedin": "https://linkedin.com/company/dograh" } } -} \ No newline at end of file +} diff --git a/docs/voice-agent/tools/http-api.mdx b/docs/voice-agent/tools/http-api.mdx index 255c778..4b38633 100644 --- a/docs/voice-agent/tools/http-api.mdx +++ b/docs/voice-agent/tools/http-api.mdx @@ -14,7 +14,7 @@ HTTP API tools allow you to attach extrernal REST API calls directly to workflow title="YouTube video player" frameBorder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" - referrerPolicy="strict-origin-when-cross-origin" + referrerPolicy="strict-origin-when-cross-origin" allowFullScreen > diff --git a/docs/voice-agent/tools/introduction.mdx b/docs/voice-agent/tools/introduction.mdx index 1e10df5..2c6d035 100644 --- a/docs/voice-agent/tools/introduction.mdx +++ b/docs/voice-agent/tools/introduction.mdx @@ -3,7 +3,7 @@ title: "Tools" description: "Extend your voice agent's capabilities by giving it tools to perform actions during live conversations." --- -Tools let your AI agent take actions during a conversation — transfer calls, end calls, or call external APIs — based on the context of the conversation and your prompt instructions. +Tools let your AI agent take actions during a conversation — transfer calls, end calls, call external APIs, or invoke remote MCP servers — based on the context of the conversation and your prompt instructions. When a tool is attached to a workflow node, the LLM decides **when** to invoke it and **what parameters** to pass, based on the user's spoken intent and your node-level instructions. @@ -23,6 +23,7 @@ Pre-configured tools that handle common telephony operations out of the box: Tools you define to integrate with any external system: - [**HTTP API**](/voice-agent/tools/http-api) — Call any REST API endpoint during a conversation (e.g., CRM updates, data lookups, triggering automations) +- [**MCP Tool**](/voice-agent/tools/mcp-tool) — Connect an external MCP server and expose its remote tools to the LLM during a conversation ## How Tools Work diff --git a/docs/voice-agent/tools/mcp-tool.mdx b/docs/voice-agent/tools/mcp-tool.mdx new file mode 100644 index 0000000..3b9013e --- /dev/null +++ b/docs/voice-agent/tools/mcp-tool.mdx @@ -0,0 +1,90 @@ +--- +title: "MCP Tool" +description: "Connect an external MCP server to your voice agent so the LLM can call remote tools during a live conversation." +--- + + +This page is about using an external MCP server as a **tool inside a voice agent**. If you want Claude, Cursor, or another coding agent to control Dograh itself over MCP, see [Dograh MCP Server](/integrations/mcp). + + +MCP tools let a Dograh voice agent call tools exposed by a remote [Model Context Protocol](https://modelcontextprotocol.io/) server during a live conversation. Dograh discovers the remote tool catalog, turns those tools into LLM-callable functions, and forwards invocations to the MCP server over authenticated Streamable HTTP. + +## What to configure + +An MCP tool in Dograh has four important pieces: + +- **Name**: the server label shown in Dograh +- **Description**: tells the LLM when this MCP server is relevant +- **URL**: the remote MCP server endpoint (`http://` or `https://`) +- **Credential**: the auth Dograh should send when connecting to that server + +You can also set a **tool filter** to allow only specific remote MCP tools to be exposed. + +## Authentication + +Most hosted MCP servers expect: + +```http +Authorization: Bearer +``` + +So before creating the MCP tool, create a credential in Dograh with: + +- **Credential Type**: `Bearer Token` +- **Token**: the access token issued by your MCP server + +Then select that credential on the MCP tool. + + +If the remote MCP server documentation says to use Bearer auth, choose **Bearer Token** in the credential dialog. Dograh will translate that into the exact `Authorization: Bearer ` header on the MCP connection. + + +Dograh also supports other credential types, but **Bearer Token** is the default thing to try for third-party MCP servers unless their docs say otherwise. + +## How it works + +The runtime path is: + +1. When you save or refresh the MCP tool, Dograh opens a short-lived authenticated MCP session and fetches the remote tool catalog. +2. Dograh stores that catalog as the tool's discovered tools so the UI can show you which remote functions exist. +3. When a call starts, Dograh opens one live MCP session per attached MCP server and reuses your selected credential for that session. +4. For each node, Dograh exposes only the MCP tools allowed by the server-level filter and the node-level selection. +5. Dograh namespaces those remote tools into ordinary LLM function definitions so they can safely coexist with HTTP API tools, call transfer, end call, and other tools. +6. During the conversation, the LLM sees only the tool name, description, and argument schema. It does **not** see the secret. +7. When the LLM calls one of those tools, Dograh forwards the invocation to the MCP server over the active authenticated session, receives the result, and feeds that result back into the agent turn. + +In short: **Dograh handles discovery, authentication, session management, tool registration, and result plumbing; the LLM only decides when to call the tool and with which arguments.** + +## Creating an MCP tool + +1. Go to **Tools** and create a new tool. +2. Choose **MCP Server**. +3. Enter a clear name and a description that explains when the LLM should use this server. +4. Paste the MCP server URL. +5. Select the credential. In most cases this should be a **Bearer Token** credential. +6. Save the tool and confirm that Dograh discovered the remote tools. + +If the server exposes many tools, use filtering to keep only the ones your agent actually needs. + +## Attaching it to a node + +After the MCP tool is created: + +1. Open the workflow node where the tool should be available. +2. Add the MCP tool from the node's tool list. +3. Select only the remote MCP functions that should be callable on that node. +4. In the node prompt, tell the LLM exactly when to use those functions. + +The tighter the node-level selection and prompt guidance, the more reliable MCP tool usage becomes. + +## Best practices + +- Use one MCP server per logical integration when possible. +- Keep the tool description explicit about **when** the LLM should use that server. +- Expose only the minimum remote functions needed for each node. +- Prefer a **Bearer Token** credential unless the MCP server specifies another auth scheme. +- Test discovery first, then test a real phone/web call to confirm the LLM invokes the right MCP function with the right arguments. + + +If the remote MCP server is temporarily unavailable, Dograh degrades gracefully and the call can continue without those MCP tools rather than crashing the entire conversation. + diff --git a/pipecat b/pipecat index 13e98d0..ce4ee2d 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit 13e98d0d94aa5db3185e36ba411bae0ffb443b7b +Subproject commit ce4ee2d6fc0d0982ad03a1468a517cc5e568aaa9 diff --git a/scripts/setup_requirements.sh b/scripts/setup_requirements.sh index f38ed38..201b952 100755 --- a/scripts/setup_requirements.sh +++ b/scripts/setup_requirements.sh @@ -51,7 +51,7 @@ fi # Install pipecat in editable mode with all extras echo "Installing pipecat dependencies..." -pip install -e ./pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb] +pip install -e ./pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb,mcp] if [ "$DEV_MODE" -eq 1 ]; then echo "Installing pipecat dev dependencies..." diff --git a/ui/src/app/tools/[toolUuid]/page.tsx b/ui/src/app/tools/[toolUuid]/page.tsx index 065e2ba..7118cfd 100644 --- a/ui/src/app/tools/[toolUuid]/page.tsx +++ b/ui/src/app/tools/[toolUuid]/page.tsx @@ -9,16 +9,25 @@ import { listRecordingsApiV1WorkflowRecordingsGet, updateToolApiV1ToolsToolUuidPut, } from "@/client/sdk.gen"; -import type { RecordingResponseSchema, ToolResponse, TransferCallConfig as APITransferCallConfig } from "@/client/types.gen"; -import type { EndCallConfig } from "@/client/types.gen"; +import type { + EndCallConfig, + HttpApiToolDefinition, + RecordingResponseSchema, + ToolResponse, + TransferCallConfig as APITransferCallConfig, + UpdateToolRequest, +} from "@/client/types.gen"; import { + CredentialSelector, type HttpMethod, type KeyValueItem, + type ParameterType, type PresetToolParameter, type ToolParameter, validateUrl, } from "@/components/http"; import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Dialog, DialogContent, @@ -26,35 +35,33 @@ import { DialogHeader, DialogTitle, } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; import { Skeleton } from "@/components/ui/skeleton"; +import { Textarea } from "@/components/ui/textarea"; import { TOOL_DOCUMENTATION_URLS } from "@/constants/documentation"; import { useAuth } from "@/lib/auth"; import { + createMcpDefinition, DEFAULT_END_CALL_REASON_DESCRIPTION, type EndCallMessageType, getCategoryConfig, getToolTypeLabel, + MCP_URL_PATTERN, renderToolIcon, type ToolCategory, } from "../config"; import { BuiltinToolConfig, EndCallToolConfig, HttpApiToolConfig, TransferCallToolConfig } from "./components"; -// Extended HttpApiConfig with parameters (until client types are regenerated) -interface HttpApiConfigWithParams { - method?: string; - url?: string; - headers?: Record; - credential_uuid?: string; - parameters?: ToolParameter[]; - preset_parameters?: Array<{ - name?: string; - type?: PresetToolParameter["type"]; - value_template?: string; - required?: boolean; - }>; - timeout_ms?: number; - customMessage?: string; +function normalizeParameterType(value: string | null | undefined): ParameterType { + switch (value) { + case "number": + case "boolean": + return value; + default: + return "string"; + } } export default function ToolDetailPage() { @@ -108,6 +115,11 @@ export default function ToolDetailPage() { const [customMessageType, setCustomMessageType] = useState<'text' | 'audio'>('text'); const [customMessageRecordingId, setCustomMessageRecordingId] = useState(""); + // MCP form state + const [mcpUrl, setMcpUrl] = useState(""); + const [mcpCredentialUuid, setMcpCredentialUuid] = useState(""); + const [mcpToolsFilter, setMcpToolsFilter] = useState(""); + // Org-level recordings for audio dropdowns const [recordings, setRecordings] = useState([]); @@ -155,8 +167,7 @@ export default function ToolDetailPage() { if (config) { setEndCallMessageType(config.messageType || "none"); setCustomMessage(config.customMessage || ""); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - setAudioRecordingId((config as any).audioRecordingId || ""); + setAudioRecordingId(config.audioRecordingId || ""); setEndCallReason(config.endCallReason ?? false); setEndCallReasonDescription(config.endCallReasonDescription || ""); } else { @@ -173,8 +184,7 @@ export default function ToolDetailPage() { setTransferDestination(config.destination || ""); setTransferMessageType(config.messageType || "none"); setCustomMessage(config.customMessage || ""); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - setTransferAudioRecordingId((config as any).audioRecordingId || ""); + setTransferAudioRecordingId(config.audioRecordingId || ""); setTransferTimeout(config.timeout ?? 30); } else { setTransferDestination(""); @@ -183,19 +193,35 @@ export default function ToolDetailPage() { setTransferAudioRecordingId(""); setTransferTimeout(30); } + } else if (tool.category === "mcp") { + // Populate MCP specific fields + const config = tool.definition?.config as + | { url?: string; credential_uuid?: string | null; tools_filter?: string[] } + | undefined; + if (config) { + setMcpUrl(config.url || ""); + setMcpCredentialUuid(config.credential_uuid || ""); + setMcpToolsFilter( + Array.isArray(config.tools_filter) + ? config.tools_filter.join(", ") + : "" + ); + } else { + setMcpUrl(""); + setMcpCredentialUuid(""); + setMcpToolsFilter(""); + } } else { // Populate HTTP API specific fields - const config = tool.definition?.config as HttpApiConfigWithParams | undefined; + const config = tool.definition?.config as HttpApiToolDefinition["config"] | undefined; if (config) { setHttpMethod((config.method as HttpMethod) || "POST"); setUrl(config.url || ""); setCredentialUuid(config.credential_uuid || ""); setTimeoutMs(config.timeout_ms || 5000); setCustomMessage(config.customMessage || ""); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - setCustomMessageType((config as any).customMessageType || "text"); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - setCustomMessageRecordingId((config as any).customMessageRecordingId || ""); + setCustomMessageType(config.customMessageType || "text"); + setCustomMessageRecordingId(config.customMessageRecordingId || ""); // Convert headers object to array if (config.headers) { @@ -212,9 +238,9 @@ export default function ToolDetailPage() { // Load parameters if (config.parameters && Array.isArray(config.parameters)) { setParameters( - config.parameters.map((p: ToolParameter) => ({ + config.parameters.map((p) => ({ name: p.name || "", - type: p.type || "string", + type: normalizeParameterType(p.type), description: p.description || "", required: p.required ?? true, })) @@ -227,7 +253,7 @@ export default function ToolDetailPage() { setPresetParameters( config.preset_parameters.map((p) => ({ name: p.name || "", - type: p.type || "string", + type: normalizeParameterType(p.type), valueTemplate: p.value_template || "", required: p.required ?? true, })) @@ -275,6 +301,16 @@ export default function ToolDetailPage() { setError("Please enter a valid phone number (E.164 format) or SIP endpoint (e.g., PJSIP/1234)"); return; } + } else if (tool.category === "mcp") { + // Validate MCP server URL (must be http(s)) + if (!mcpUrl.trim()) { + setError("Please enter the MCP server URL"); + return; + } + if (!MCP_URL_PATTERN.test(mcpUrl.trim())) { + setError("MCP server URL must start with http:// or https://"); + return; + } } else if (tool.category !== "end_call") { // Validate URL for HTTP API tools const urlValidation = validateUrl(url); @@ -305,7 +341,7 @@ export default function ToolDetailPage() { setSaveSuccess(false); const accessToken = await getAccessToken(); - let requestBody; + let requestBody: UpdateToolRequest; if (tool.category === "calculator") { // Built-in tool - only name/description, no config @@ -351,6 +387,12 @@ export default function ToolDetailPage() { }, }, }; + } else if (tool.category === "mcp") { + requestBody = { + name, + description: description || undefined, + definition: createMcpDefinition(mcpUrl, mcpCredentialUuid, mcpToolsFilter), + }; } else { // Build HTTP API request body const headersObject: Record = {}; @@ -399,8 +441,7 @@ export default function ToolDetailPage() { const response = await updateToolApiV1ToolsToolUuidPut({ path: { tool_uuid: toolUuid }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - body: requestBody as any, + body: requestBody, headers: { Authorization: `Bearer ${accessToken}`, }, @@ -510,6 +551,7 @@ const data = await response.json();`; const isEndCallTool = tool.category === "end_call"; const isTransferCallTool = tool.category === "transfer_call"; const isBuiltinTool = tool.category === "calculator"; + const isMcpTool = tool.category === "mcp"; const categoryConfig = getCategoryConfig(tool.category as ToolCategory); return ( @@ -545,7 +587,7 @@ const data = await response.json();`;
- {!isEndCallTool && !isTransferCallTool && !isBuiltinTool && ( + {!isEndCallTool && !isTransferCallTool && !isBuiltinTool && !isMcpTool && (