mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* feat(mcp): generic MCP tool source with per-node function filtering
Adds a Model Context Protocol tool category: connect a customer MCP
server and expose its tools to the agent, with optional per-node
allow-listing of individual MCP functions.
- ToolCategory.MCP enum + alembic migration
- MCP definition validator and collision-safe function-name namespacing
- McpToolSession wrapper: graceful-degrade, per-call open/close lifecycle
- CustomToolManager MCP branch (schemas + proxy handlers)
- Per-node mcp_tool_filters threaded through DTO/graph/engine
- Best-effort discovered_tools catalog cache + POST /tools/{uuid}/mcp/refresh
- UI: MCP create/edit config, tabbed ToolSelector with per-node toggles
* feat: refactor for code standardisation and documentation
---------
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
651 lines
20 KiB
Python
651 lines
20 KiB
Python
"""API routes for managing tools."""
|
|
|
|
import asyncio
|
|
import re
|
|
from datetime import datetime
|
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from loguru import logger
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from api.db import db_client
|
|
from api.db.models import UserModel
|
|
from api.enums import PostHogEvent, ToolCategory, ToolStatus
|
|
from api.sdk_expose import sdk_expose
|
|
from api.services.auth.depends import get_user
|
|
from api.services.posthog_client import capture_event
|
|
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
|
from api.services.workflow.tools.mcp_tool import (
|
|
McpDefinitionError,
|
|
validate_mcp_definition,
|
|
)
|
|
from api.services.workflow.tools.mcp_tool import (
|
|
McpToolConfig as SharedMcpToolConfig,
|
|
)
|
|
from api.services.workflow.tools.mcp_tool import (
|
|
McpToolDefinition as SharedMcpToolDefinition,
|
|
)
|
|
|
|
# Every endpoint in this module is mounted under the /tools prefix.
router = APIRouter(prefix="/tools")

# Expose the shared MCP schemas under the local names used by the
# ToolDefinition union in this module.
McpToolConfig, McpToolDefinition = SharedMcpToolConfig, SharedMcpToolDefinition
|
|
|
|
|
|
# Request/Response schemas
class ToolParameter(BaseModel):
    """A parameter that the tool accepts."""

    # One argument the LLM fills in when it calls the tool; ``name`` doubles
    # as the key under which the value is sent in the request body.
    name: str = Field(description="Parameter name (used as key in request body)")
    type: str = Field(description="Parameter type: string, number, or boolean")
    description: str = Field(description="Description of what this parameter is for")
    required: bool = Field(True, description="Whether this parameter is required")
|
|
|
|
|
|
class PresetToolParameter(BaseModel):
    """A parameter injected by Dograh at runtime."""

    # The LLM never supplies this value; it is produced from
    # ``value_template``, which is either a literal or a template such as
    # {{initial_context.phone_number}}.
    name: str = Field(description="Parameter name (used as key in request body)")
    type: str = Field(description="Parameter type: string, number, or boolean")
    value_template: str = Field(
        description="Fixed value or template, e.g. {{initial_context.phone_number}}"
    )
    required: bool = Field(
        True,
        description="Whether the parameter must resolve to a non-empty value",
    )
|
|
|
|
|
|
class HttpApiConfig(BaseModel):
    """Configuration for HTTP API tools."""

    # -- request target -------------------------------------------------------
    method: str = Field(description="HTTP method (GET, POST, PUT, PATCH, DELETE)")
    url: str = Field(description="Target URL")
    headers: Optional[Dict[str, str]] = Field(
        None, description="Static headers to include"
    )
    credential_uuid: Optional[str] = Field(
        None, description="Reference to ExternalCredentialModel for auth"
    )

    # -- request parameters ---------------------------------------------------
    # ``parameters`` come from the LLM; ``preset_parameters`` are filled in by
    # Dograh itself.
    parameters: Optional[List[ToolParameter]] = Field(
        None, description="Parameters that the tool accepts from LLM"
    )
    preset_parameters: Optional[List[PresetToolParameter]] = Field(
        None,
        description="Parameters injected by Dograh from fixed values or workflow context templates",
    )

    timeout_ms: Optional[int] = Field(
        5000, description="Request timeout in milliseconds"
    )

    # -- post-execution message -----------------------------------------------
    # These field names are camelCase; renaming them would change the payload
    # keys this model accepts and emits.
    customMessage: Optional[str] = Field(
        None, description="Custom message to play after tool execution"
    )
    customMessageType: Optional[Literal["text", "audio"]] = Field(
        None, description="Type of custom message: text or audio"
    )
    customMessageRecordingId: Optional[str] = Field(
        None, description="Recording ID for audio custom message"
    )
|
|
|
|
|
|
class EndCallConfig(BaseModel):
    """Configuration for End Call tools."""

    # Goodbye message played before hanging up. customMessage / audioRecordingId
    # presumably apply to the "custom" / "audio" messageType respectively.
    messageType: Literal["none", "custom", "audio"] = Field(
        "none", description="Type of goodbye message"
    )
    customMessage: Optional[str] = Field(
        None, description="Custom message to play before ending the call"
    )
    audioRecordingId: Optional[str] = Field(
        None, description="Recording ID for audio goodbye message"
    )

    # Optional reason the LLM must give when it ends the call.
    endCallReason: bool = Field(
        False,
        description=(
            "When enabled, LLM must provide a reason for ending the call. "
            "The reason is set as call disposition and added to call tags."
        ),
    )
    endCallReasonDescription: Optional[str] = Field(
        None,
        description=(
            "Description shown to the LLM for the reason parameter. "
            "Used only when endCallReason is enabled."
        ),
    )
|
|
|
|
|
|
class TransferCallConfig(BaseModel):
    """Configuration for Transfer Call tools."""

    destination: str = Field(
        description="Phone number or SIP endpoint to transfer the call to (E.164 format e.g., +1234567890, or SIP endpoint e.g., PJSIP/1234)"
    )
    messageType: Literal["none", "custom", "audio"] = Field(
        default="none", description="Type of message to play before transfer"
    )
    customMessage: Optional[str] = Field(
        default=None, description="Custom message to play before transferring the call"
    )
    audioRecordingId: Optional[str] = Field(
        default=None, description="Recording ID for audio message before transfer"
    )
    timeout: int = Field(
        default=30,
        ge=5,
        le=120,
        description="Maximum time in seconds to wait for destination to answer (5-120 seconds)",
    )

    @field_validator("destination")
    @classmethod
    def validate_destination(cls, v: str) -> str:
        """Validate that destination is a valid E.164 phone number or SIP endpoint.

        A blank (empty or whitespace-only) value is returned unchanged so a
        tool can be saved before its destination is filled in.

        Args:
            v: The raw ``destination`` value.

        Returns:
            ``v`` unchanged when it is blank or valid.

        Raises:
            ValueError: if a non-blank value is neither an E.164 number nor a
                ``PJSIP/...`` / ``SIP/...`` endpoint.
        """
        # Allow empty string for initial creation (like HTTP API tools with empty URL)
        if not v.strip():
            return v

        # E.164: "+", a non-zero leading digit, then 1-14 more digits.
        # An explicit ASCII class is used because ``\d`` on str patterns also
        # matches non-ASCII Unicode digits (e.g. Arabic-Indic numerals).
        e164_pattern = r"\+[1-9][0-9]{1,14}"

        # SIP endpoint format: PJSIP/extension or SIP/extension
        sip_pattern = r"(PJSIP|SIP)/[\w\-\.@]+"

        # fullmatch, not match + "$": "$" also matches just before a trailing
        # "\n", so "+1234567890\n" used to validate and the newline was kept
        # in the stored dial destination.
        is_valid_e164 = re.fullmatch(e164_pattern, v)
        is_valid_sip = re.fullmatch(sip_pattern, v, re.IGNORECASE)

        if not (is_valid_e164 or is_valid_sip):
            raise ValueError(
                "Destination must be a valid E.164 phone number (e.g., +1234567890) or SIP endpoint (e.g., PJSIP/1234)"
            )
        return v
|
|
|
|
|
|
class HttpApiToolDefinition(BaseModel):
    """Tool definition for HTTP API tools."""

    # ``type`` is the tag the ToolDefinition union discriminates on.
    schema_version: int = Field(1, description="Schema version")
    type: Literal["http_api"] = Field(description="Tool type")
    config: HttpApiConfig = Field(description="HTTP API configuration")
|
|
|
|
|
|
class EndCallToolDefinition(BaseModel):
    """Tool definition for End Call tools."""

    # ``type`` is the tag the ToolDefinition union discriminates on.
    schema_version: int = Field(1, description="Schema version")
    type: Literal["end_call"] = Field(description="Tool type")
    config: EndCallConfig = Field(description="End Call configuration")
|
|
|
|
|
|
class TransferCallToolDefinition(BaseModel):
    """Tool definition for Transfer Call tools."""

    # ``type`` is the tag the ToolDefinition union discriminates on.
    schema_version: int = Field(1, description="Schema version")
    type: Literal["transfer_call"] = Field(description="Tool type")
    config: TransferCallConfig = Field(description="Transfer Call configuration")
|
|
|
|
|
|
class CalculatorToolDefinition(BaseModel):
    """Tool definition for Calculator tools (no configuration needed)."""

    # No ``config`` field: the calculator has nothing to configure.
    schema_version: int = Field(1, description="Schema version")
    type: Literal["calculator"] = Field(description="Tool type")
|
|
|
|
|
|
# Every concrete tool-definition model accepted by the API.
_ToolDefinitionVariants = Union[
    HttpApiToolDefinition,
    EndCallToolDefinition,
    TransferCallToolDefinition,
    CalculatorToolDefinition,
    McpToolDefinition,
]

# Tagged union: Pydantic picks the variant whose literal ``type`` field matches
# the incoming payload's "type" value.
ToolDefinition = Annotated[_ToolDefinitionVariants, Field(discriminator="type")]
|
|
|
|
|
|
class CreateToolRequest(BaseModel):
    """Request schema for creating a tool."""

    name: str = Field(max_length=255)
    description: Optional[str] = None
    # Raw ToolCategory value; checked by ``validate_category`` below.
    category: str = Field(ToolCategory.HTTP_API.value)
    icon: Optional[str] = Field("globe", max_length=50)
    icon_color: Optional[str] = Field("#3B82F6", max_length=7)
    definition: ToolDefinition

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Validate that category is a valid ToolCategory value.

        Raises:
            ValueError: listing the accepted category values.
        """
        allowed = [member.value for member in ToolCategory]
        if v in allowed:
            return v
        raise ValueError(
            f"Invalid category '{v}'. Must be one of: {', '.join(allowed)}"
        )
|
|
|
|
|
|
class UpdateToolRequest(BaseModel):
    """Request schema for updating a tool."""

    # Every field is optional. Omitted fields arrive as None and are forwarded
    # as-is to db_client.update_tool (presumably "leave unchanged" — confirm
    # there).
    name: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = None
    icon: Optional[str] = Field(None, max_length=50)
    icon_color: Optional[str] = Field(None, max_length=7)
    definition: Optional[ToolDefinition] = None
    # Checked against ToolStatus by the update endpoint, not by this model.
    status: Optional[str] = None
|
|
|
|
|
|
class CreatedByResponse(BaseModel):
    """Response schema for the user who created a tool."""

    # Populated from the tool's ``created_by_user`` relation.
    id: int
    provider_id: str
|
|
|
|
|
|
class ToolResponse(BaseModel):
    """Response schema for a tool."""

    # Pydantic v2 configuration. The nested ``class Config`` form is
    # deprecated in v2 (this module already relies on v2 via
    # ``field_validator``). ``from_attributes`` lets the model be built from
    # attribute-bearing objects such as ORM rows.
    model_config = {"from_attributes": True}

    id: int
    tool_uuid: str
    name: str
    description: Optional[str]
    category: str
    icon: Optional[str]
    icon_color: Optional[str]
    status: str
    definition: Dict[str, Any]
    created_at: datetime
    updated_at: Optional[datetime]
    # Only filled in by endpoints that request creator details.
    created_by: Optional[CreatedByResponse] = None
|
|
|
|
|
|
class McpRefreshResponse(BaseModel):
    """Result of re-discovering an MCP server's tool catalog."""

    tool_uuid: str
    # Tools found by this refresh; empty when discovery failed or found none.
    discovered_tools: list = Field(default_factory=list)
    # Human-readable explanation when the refresh did not update the cache.
    error: Optional[str] = None
|
|
|
|
|
|
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
    """Build a ``ToolResponse`` from a tool model.

    Args:
        tool: The tool model; its attributes are copied onto the response.
        include_created_by: When True and the tool has a ``created_by_user``,
            that user's id and provider_id are embedded as ``created_by``.
            When False, ``created_by_user`` is never accessed.

    Returns:
        The populated ``ToolResponse``.
    """
    # Only touch the relation when the caller asked for it.
    creator = tool.created_by_user if include_created_by else None
    created_by = (
        CreatedByResponse(id=creator.id, provider_id=creator.provider_id)
        if creator
        else None
    )

    return ToolResponse(
        id=tool.id,
        tool_uuid=tool.tool_uuid,
        name=tool.name,
        description=tool.description,
        category=tool.category,
        icon=tool.icon,
        icon_color=tool.icon_color,
        status=tool.status,
        definition=tool.definition,
        created_at=tool.created_at,
        updated_at=tool.updated_at,
        created_by=created_by,
    )
|
|
|
|
|
|
def validate_category(category: str) -> None:
    """Ensure ``category`` is one of the ToolCategory values.

    Raises:
        HTTPException: 400, listing the accepted category values.
    """
    allowed = [member.value for member in ToolCategory]
    if category in allowed:
        return
    raise HTTPException(
        status_code=400,
        detail=f"Invalid category '{category}'. Must be one of: {', '.join(allowed)}",
    )
|
|
|
|
|
|
def validate_status(status: str) -> None:
    """Ensure every comma-separated entry in ``status`` is a ToolStatus value.

    Entries are whitespace-stripped before checking. The first invalid entry,
    in input order, is reported.

    Raises:
        HTTPException: 400 naming the first invalid entry and the accepted values.
    """
    allowed = [member.value for member in ToolStatus]
    tokens = (part.strip() for part in status.split(","))
    offending = next((tok for tok in tokens if tok not in allowed), None)
    if offending is not None:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid status '{offending}'. Must be one of: {', '.join(allowed)}",
        )
|
|
|
|
|
|
@router.get(
    "/",
    **sdk_expose(
        method="list_tools",
        description="List tools available to the authenticated organization.",
    ),
)
async def list_tools(
    status: Optional[str] = None,
    category: Optional[str] = None,
    user: UserModel = Depends(get_user),
) -> List[ToolResponse]:
    """
    List all tools for the user's organization.

    Args:
        status: Optional filter: one ToolStatus value or a comma-separated
            list of them.
        category: Optional filter: a ToolCategory value (e.g. http_api, mcp).

    Returns:
        List of tools (without creator details).

    Raises:
        HTTPException: 400 if no organization is selected or a filter value
            is invalid.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    # Reject malformed filters before querying.
    if status:
        validate_status(status)
    if category:
        validate_category(category)

    tools = await db_client.get_tools_for_organization(
        org_id, status=status, category=category
    )
    return list(map(build_tool_response, tools))
|
|
|
|
|
|
async def _fetch_credential(credential_uuid: Optional[str], organization_id: int):
    """Look up the credential used to authenticate to an MCP server.

    Best-effort: returns ``None`` when ``credential_uuid`` is empty or when
    the lookup raises. The caller then proceeds unauthenticated instead of
    failing the request.
    """
    credential = None
    if credential_uuid:
        try:
            credential = await db_client.get_credential_by_uuid(
                credential_uuid, organization_id
            )
        except Exception as e:  # noqa: BLE001
            logger.warning(f"MCP: credential fetch failed: {e}")
    return credential
|
|
|
|
|
|
async def _populate_discovered_tools(definition: dict, *, organization_id: int) -> dict:
    """Best-effort: for an MCP definition, connect to the server, list its
    tools, and overwrite ``config.discovered_tools``.

    Never raises for ordinary failures and never blocks tool save — a dead
    server yields ``discovered_tools: []``. Non-MCP definitions, and MCP
    definitions that fail validation, pass through untouched (same object).

    Args:
        definition: A tool definition dict (e.g. ``model_dump()`` output).
        organization_id: Organization used to scope the credential lookup.

    Returns:
        For MCP definitions, a new dict with ``config.discovered_tools`` set;
        the caller's ``definition`` is not mutated. Otherwise ``definition``.
    """
    if not isinstance(definition, dict) or definition.get("type") != "mcp":
        return definition
    try:
        cfg = validate_mcp_definition(definition)
    except McpDefinitionError:
        return definition

    credential = await _fetch_credential(cfg.get("credential_uuid"), organization_id)

    # Run discovery in an isolated asyncio task so an anyio cancel-scope
    # CancelledError doesn't bleed into the parent task and corrupt the
    # subsequent DB write. _run() degrades ordinary failures to [], but it
    # deliberately does NOT catch every BaseException: KeyboardInterrupt /
    # SystemExit must still reach the event loop.
    async def _run() -> list:
        try:
            return await discover_mcp_tools(
                url=cfg["url"],
                credential=credential,
                timeout_secs=cfg["timeout_secs"],
                sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
            )
        except (Exception, asyncio.CancelledError) as e:  # noqa: BLE001
            logger.warning(f"MCP discovery failed; caching empty list: {e}")
            return []

    discovered = await asyncio.create_task(_run())

    # Build a fresh config dict instead of writing into the caller's nested one.
    config = dict(definition.get("config") or {})
    config["discovered_tools"] = discovered
    return {**definition, "config": config}
|
|
|
|
|
|
@router.post("/")
async def create_tool(
    request: CreateToolRequest,
    user: UserModel = Depends(get_user),
) -> ToolResponse:
    """
    Create a new tool in the user's selected organization.

    For MCP definitions the server's tool catalog is discovered (best-effort)
    and cached in the stored definition before saving.

    Args:
        request: The tool creation request

    Returns:
        The created tool

    Raises:
        HTTPException: 400 if no organization is selected or the category is invalid.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    validate_category(request.category)

    definition = await _populate_discovered_tools(
        request.definition.model_dump(), organization_id=org_id
    )

    tool = await db_client.create_tool(
        organization_id=org_id,
        user_id=user.id,
        name=request.name,
        definition=definition,
        category=request.category,
        description=request.description,
        icon=request.icon,
        icon_color=request.icon_color,
    )

    analytics_properties = {
        "tool_name": request.name,
        "tool_category": request.category,
        "organization_id": org_id,
    }
    capture_event(
        distinct_id=str(user.provider_id),
        event=PostHogEvent.TOOL_CREATED,
        properties=analytics_properties,
    )

    return build_tool_response(tool)
|
|
|
|
|
|
@router.get("/{tool_uuid}")
async def get_tool(
    tool_uuid: str,
    user: UserModel = Depends(get_user),
) -> ToolResponse:
    """
    Get a specific tool by UUID, including archived tools.

    Args:
        tool_uuid: The UUID of the tool

    Returns:
        The tool, with creator details populated.

    Raises:
        HTTPException: 400 if no organization is selected; 404 if the tool
            is not found in that organization.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    found = await db_client.get_tool_by_uuid(tool_uuid, org_id, include_archived=True)
    if not found:
        raise HTTPException(status_code=404, detail="Tool not found")

    return build_tool_response(found, include_created_by=True)
|
|
|
|
|
|
@router.post("/{tool_uuid}/mcp/refresh")
async def refresh_mcp_tools(
    tool_uuid: str,
    user: UserModel = Depends(get_user),
) -> McpRefreshResponse:
    """Re-discover an MCP tool's server catalog and overwrite the cached
    ``definition.config.discovered_tools``.

    Server down (or an empty catalog) → 200 with ``error`` set; the cache is
    not overwritten on such a (possibly transient) failure.

    Raises:
        HTTPException: 400 if no organization is selected, the tool is not an
            MCP tool, or its stored definition is invalid; 404 if the tool is
            not found in the user's organization.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    tool = await db_client.get_tool_by_uuid(tool_uuid, org_id, include_archived=True)
    if not tool:
        raise HTTPException(status_code=404, detail="Tool not found")
    if tool.category != ToolCategory.MCP.value:
        raise HTTPException(status_code=400, detail="Tool is not an MCP tool")

    try:
        cfg = validate_mcp_definition(tool.definition)
    except McpDefinitionError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid MCP definition: {e}"
        ) from e

    credential = await _fetch_credential(cfg.get("credential_uuid"), org_id)

    # Run discovery in its own asyncio task, mirroring _populate_discovered_tools.
    # An anyio cancel scope inside the MCP client can surface as CancelledError,
    # which is a BaseException. The previous ``except Exception`` missed it, so
    # it aborted this request and could leak into this task ahead of the DB
    # write below.
    async def _run() -> list:
        try:
            return await discover_mcp_tools(
                url=cfg["url"],
                credential=credential,
                timeout_secs=cfg["timeout_secs"],
                sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
            )
        except (Exception, asyncio.CancelledError) as e:  # noqa: BLE001
            logger.warning(f"MCP refresh discovery failed: {e}")
            return []

    discovered = await asyncio.create_task(_run())

    if not discovered:
        # Do NOT clobber a previously-good cache with [] on a transient outage.
        error = (
            f"Could not reach the MCP server at {cfg['url']} "
            "(or it exposes no tools). Previously cached list retained."
        )
        return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)

    # Copy the stored definition so the update writes a new object.
    new_def = dict(tool.definition or {})
    new_def["config"] = {
        **(new_def.get("config") or {}),
        "discovered_tools": discovered,
    }
    await db_client.update_tool(
        tool_uuid=tool_uuid,
        organization_id=org_id,
        definition=new_def,
    )
    return McpRefreshResponse(
        tool_uuid=tool_uuid, discovered_tools=discovered, error=None
    )
|
|
|
|
|
|
@router.put("/{tool_uuid}")
async def update_tool(
    tool_uuid: str,
    request: UpdateToolRequest,
    user: UserModel = Depends(get_user),
) -> ToolResponse:
    """
    Update a tool in the user's selected organization.

    When a new definition is supplied and it is an MCP definition, the
    server's tool catalog is re-discovered (best-effort) before saving.

    Args:
        tool_uuid: The UUID of the tool to update
        request: The update request

    Returns:
        The updated tool, with creator details populated.

    Raises:
        HTTPException: 400 if no organization is selected or the status is
            invalid; 404 if the tool is not found.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    if request.status:
        validate_status(request.status)

    definition = None
    if request.definition is not None:
        definition = await _populate_discovered_tools(
            request.definition.model_dump(), organization_id=org_id
        )

    updated = await db_client.update_tool(
        tool_uuid=tool_uuid,
        organization_id=org_id,
        name=request.name,
        description=request.description,
        definition=definition,
        icon=request.icon,
        icon_color=request.icon_color,
        status=request.status,
    )
    if not updated:
        raise HTTPException(status_code=404, detail="Tool not found")

    return build_tool_response(updated, include_created_by=True)
|
|
|
|
|
|
@router.delete("/{tool_uuid}")
async def delete_tool(
    tool_uuid: str,
    user: UserModel = Depends(get_user),
) -> dict:
    """
    Archive (soft delete) a tool.

    Args:
        tool_uuid: The UUID of the tool to delete

    Returns:
        ``{"status": "archived", "tool_uuid": ...}`` on success.

    Raises:
        HTTPException: 400 if no organization is selected; 404 if archiving
            reports nothing was found.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    if not await db_client.archive_tool(tool_uuid, org_id):
        raise HTTPException(status_code=404, detail="Tool not found")

    return {"status": "archived", "tool_uuid": tool_uuid}
|
|
|
|
|
|
@router.post("/{tool_uuid}/unarchive")
async def unarchive_tool(
    tool_uuid: str,
    user: UserModel = Depends(get_user),
) -> ToolResponse:
    """
    Unarchive a tool (restore it from the archived state).

    Args:
        tool_uuid: The UUID of the tool to unarchive

    Returns:
        The unarchived tool

    Raises:
        HTTPException: 400 if no organization is selected; 404 if the tool
            is not found.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise HTTPException(
            status_code=400, detail="No organization selected for the user"
        )

    restored = await db_client.unarchive_tool(tool_uuid, org_id)
    if not restored:
        raise HTTPException(status_code=404, detail="Tool not found")

    return build_tool_response(restored)
|