mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: create tools using MCP
This commit is contained in:
parent
5c29b6ed94
commit
fcb7004c7a
17 changed files with 1989 additions and 572 deletions
|
|
@ -36,6 +36,11 @@ The guide tool is the authoritative source for prompt-authoring craft (turn-taki
|
|||
|
||||
## Call order
|
||||
|
||||
### Creating a reusable tool
|
||||
1. If authentication is needed, call `list_credentials` and use an existing `credential_uuid`; the user creates credential secrets in the UI.
|
||||
2. Build a typed tool definition and call `create_tool`. The request schema is authoritative for allowed tool categories and config fields.
|
||||
3. Use the returned `tool_uuid` in workflow node `tool_uuids`, then call `create_workflow` or `save_workflow`.
|
||||
|
||||
### Reading documentation
|
||||
1. `search_docs` — use first for keyword or acronym lookup when the user is asking how Dograh works or how to configure something.
|
||||
2. `read_doc` — fetch the full page once one result looks likely. Prefer this over reasoning from search summaries alone.
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from api.mcp_server.tools.docs_search import list_docs, read_doc, search_docs
|
|||
from api.mcp_server.tools.get_workflow_code import get_workflow_code
|
||||
from api.mcp_server.tools.node_types import get_node_type, list_node_types
|
||||
from api.mcp_server.tools.save_workflow import save_workflow
|
||||
from api.mcp_server.tools.tool_creation import create_tool
|
||||
from api.mcp_server.tools.voice_prompting_guide import get_voice_prompting_guide
|
||||
from api.mcp_server.tools.workflows import get_workflow, list_workflows
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ mcp = FastMCP("dograh", instructions=DOGRAH_MCP_INSTRUCTIONS)
|
|||
|
||||
for _tool in (
|
||||
create_workflow,
|
||||
create_tool,
|
||||
get_node_type,
|
||||
get_workflow,
|
||||
get_workflow_code,
|
||||
|
|
|
|||
63
api/mcp_server/tools/tool_creation.py
Normal file
63
api/mcp_server/tools/tool_creation.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
"""MCP tool for creating reusable Dograh tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
|
||||
from api.mcp_server.auth import authenticate_mcp_request
|
||||
from api.mcp_server.tracing import traced_tool
|
||||
from api.schemas.tool import CreateToolRequest
|
||||
from api.services.tool_management import ToolManagementError, create_tool_for_user
|
||||
|
||||
|
||||
def _error_result(code: str, message: str, **extra: Any) -> dict[str, Any]:
|
||||
return {"created": False, "error_code": code, "error": message, **extra}
|
||||
|
||||
|
||||
@traced_tool
|
||||
async def create_tool(request: CreateToolRequest) -> dict[str, Any]:
|
||||
"""Create a reusable tool the agent can invoke during calls.
|
||||
|
||||
The request schema is the same `CreateToolRequest` used by the REST API
|
||||
and generated SDKs. Use it to create HTTP API, end-call, transfer-call,
|
||||
calculator, or MCP-server tools. For authenticated HTTP or MCP tools,
|
||||
reference an existing `credential_uuid` from `list_credentials`; users
|
||||
create credential secrets in the UI, and this flow only stores the UUID
|
||||
reference. For MCP tools, the server best-effort discovers the remote
|
||||
tool catalog and caches it in `definition.config.discovered_tools`.
|
||||
|
||||
On success, returns `created: true` and the new `tool_uuid`; use that
|
||||
UUID in workflow node `tool_uuids`. On failure, returns `created: false`,
|
||||
a machine-readable `error_code`, and a human-readable `error`. Possible
|
||||
`error_code` values:
|
||||
- `validation_error` — the request failed schema validation.
|
||||
- `credential_not_found` — a supplied credential_uuid is not in this
|
||||
organization; ask the user to create/select it in the UI first.
|
||||
- `organization_required` — the API key user has no selected organization.
|
||||
- `create_failed` — unexpected persistence or backend failure; retry once,
|
||||
then surface the error.
|
||||
"""
|
||||
user = await authenticate_mcp_request()
|
||||
|
||||
try:
|
||||
parsed_request = CreateToolRequest.model_validate(request)
|
||||
except PydanticValidationError as e:
|
||||
return _error_result("validation_error", str(e))
|
||||
|
||||
try:
|
||||
tool = await create_tool_for_user(parsed_request, user, source="mcp")
|
||||
except ToolManagementError as e:
|
||||
return _error_result(e.error_code, e.message)
|
||||
except Exception as e: # noqa: BLE001
|
||||
return _error_result("create_failed", str(e))
|
||||
|
||||
return {
|
||||
"created": True,
|
||||
"tool_uuid": tool.tool_uuid,
|
||||
"name": tool.name,
|
||||
"category": tool.category,
|
||||
"status": tool.status,
|
||||
"definition": tool.definition,
|
||||
}
|
||||
|
|
@ -1,303 +1,68 @@
|
|||
"""API routes for managing tools."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent, ToolCategory, ToolStatus
|
||||
from api.enums import ToolCategory, ToolStatus
|
||||
from api.schemas.tool import (
|
||||
CalculatorToolDefinition,
|
||||
CreatedByResponse,
|
||||
CreateToolRequest,
|
||||
EndCallConfig,
|
||||
EndCallToolDefinition,
|
||||
HttpApiConfig,
|
||||
HttpApiToolDefinition,
|
||||
McpRefreshResponse,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
PresetToolParameter,
|
||||
ToolDefinition,
|
||||
ToolParameter,
|
||||
ToolResponse,
|
||||
TransferCallConfig,
|
||||
TransferCallToolDefinition,
|
||||
UpdateToolRequest,
|
||||
)
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
from api.services.tool_management import (
|
||||
ToolManagementError,
|
||||
build_tool_response,
|
||||
create_tool_for_user,
|
||||
refresh_mcp_tool_for_user,
|
||||
validate_tool_credential_references,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolConfig as SharedMcpToolConfig,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpToolDefinition as SharedMcpToolDefinition,
|
||||
from api.services.tool_management import (
|
||||
populate_discovered_tools as _populate_discovered_tools,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tools")
|
||||
|
||||
McpToolConfig = SharedMcpToolConfig
|
||||
McpToolDefinition = SharedMcpToolDefinition
|
||||
|
||||
|
||||
# Request/Response schemas
|
||||
class ToolParameter(BaseModel):
|
||||
"""A parameter that the tool accepts."""
|
||||
|
||||
name: str = Field(description="Parameter name (used as key in request body)")
|
||||
type: str = Field(description="Parameter type: string, number, or boolean")
|
||||
description: str = Field(description="Description of what this parameter is for")
|
||||
required: bool = Field(
|
||||
default=True, description="Whether this parameter is required"
|
||||
)
|
||||
|
||||
|
||||
class PresetToolParameter(BaseModel):
|
||||
"""A parameter injected by Dograh at runtime."""
|
||||
|
||||
name: str = Field(description="Parameter name (used as key in request body)")
|
||||
type: str = Field(description="Parameter type: string, number, or boolean")
|
||||
value_template: str = Field(
|
||||
description="Fixed value or template, e.g. {{initial_context.phone_number}}"
|
||||
)
|
||||
required: bool = Field(
|
||||
default=True,
|
||||
description="Whether the parameter must resolve to a non-empty value",
|
||||
)
|
||||
|
||||
|
||||
class HttpApiConfig(BaseModel):
|
||||
"""Configuration for HTTP API tools."""
|
||||
|
||||
method: str = Field(description="HTTP method (GET, POST, PUT, PATCH, DELETE)")
|
||||
url: str = Field(description="Target URL")
|
||||
headers: Optional[Dict[str, str]] = Field(
|
||||
default=None, description="Static headers to include"
|
||||
)
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None, description="Reference to ExternalCredentialModel for auth"
|
||||
)
|
||||
parameters: Optional[List[ToolParameter]] = Field(
|
||||
default=None, description="Parameters that the tool accepts from LLM"
|
||||
)
|
||||
preset_parameters: Optional[List[PresetToolParameter]] = Field(
|
||||
default=None,
|
||||
description="Parameters injected by Dograh from fixed values or workflow context templates",
|
||||
)
|
||||
timeout_ms: Optional[int] = Field(
|
||||
default=5000, description="Request timeout in milliseconds"
|
||||
)
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play after tool execution"
|
||||
)
|
||||
customMessageType: Optional[Literal["text", "audio"]] = Field(
|
||||
default=None, description="Type of custom message: text or audio"
|
||||
)
|
||||
customMessageRecordingId: Optional[str] = Field(
|
||||
default=None, description="Recording ID for audio custom message"
|
||||
)
|
||||
|
||||
|
||||
class EndCallConfig(BaseModel):
|
||||
"""Configuration for End Call tools."""
|
||||
|
||||
messageType: Literal["none", "custom", "audio"] = Field(
|
||||
default="none", description="Type of goodbye message"
|
||||
)
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play before ending the call"
|
||||
)
|
||||
audioRecordingId: Optional[str] = Field(
|
||||
default=None, description="Recording ID for audio goodbye message"
|
||||
)
|
||||
endCallReason: bool = Field(
|
||||
default=False,
|
||||
description="When enabled, LLM must provide a reason for ending the call. "
|
||||
"The reason is set as call disposition and added to call tags.",
|
||||
)
|
||||
endCallReasonDescription: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Description shown to the LLM for the reason parameter. "
|
||||
"Used only when endCallReason is enabled.",
|
||||
)
|
||||
|
||||
|
||||
class TransferCallConfig(BaseModel):
|
||||
"""Configuration for Transfer Call tools."""
|
||||
|
||||
destination: str = Field(
|
||||
description="Phone number or SIP endpoint to transfer the call to (E.164 format e.g., +1234567890, or SIP endpoint e.g., PJSIP/1234)"
|
||||
)
|
||||
messageType: Literal["none", "custom", "audio"] = Field(
|
||||
default="none", description="Type of message to play before transfer"
|
||||
)
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play before transferring the call"
|
||||
)
|
||||
audioRecordingId: Optional[str] = Field(
|
||||
default=None, description="Recording ID for audio message before transfer"
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=120,
|
||||
description="Maximum time in seconds to wait for destination to answer (5-120 seconds)",
|
||||
)
|
||||
|
||||
@field_validator("destination")
|
||||
@classmethod
|
||||
def validate_destination(cls, v: str) -> str:
|
||||
"""Validate that destination is a valid E.164 phone number or SIP endpoint."""
|
||||
# Allow empty string for initial creation (like HTTP API tools with empty URL)
|
||||
if not v.strip():
|
||||
return v
|
||||
|
||||
# E.164 format: +[1-9]\d{1,14}
|
||||
e164_pattern = r"^\+[1-9]\d{1,14}$"
|
||||
|
||||
# SIP endpoint format: PJSIP/extension or SIP/extension
|
||||
sip_pattern = r"^(PJSIP|SIP)/[\w\-\.@]+$"
|
||||
|
||||
is_valid_e164 = re.match(e164_pattern, v)
|
||||
is_valid_sip = re.match(sip_pattern, v, re.IGNORECASE)
|
||||
|
||||
if not (is_valid_e164 or is_valid_sip):
|
||||
raise ValueError(
|
||||
"Destination must be a valid E.164 phone number (e.g., +1234567890) or SIP endpoint (e.g., PJSIP/1234)"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class HttpApiToolDefinition(BaseModel):
|
||||
"""Tool definition for HTTP API tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["http_api"] = Field(description="Tool type")
|
||||
config: HttpApiConfig = Field(description="HTTP API configuration")
|
||||
|
||||
|
||||
class EndCallToolDefinition(BaseModel):
|
||||
"""Tool definition for End Call tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["end_call"] = Field(description="Tool type")
|
||||
config: EndCallConfig = Field(description="End Call configuration")
|
||||
|
||||
|
||||
class TransferCallToolDefinition(BaseModel):
|
||||
"""Tool definition for Transfer Call tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["transfer_call"] = Field(description="Tool type")
|
||||
config: TransferCallConfig = Field(description="Transfer Call configuration")
|
||||
|
||||
|
||||
class CalculatorToolDefinition(BaseModel):
|
||||
"""Tool definition for Calculator tools (no configuration needed)."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["calculator"] = Field(description="Tool type")
|
||||
|
||||
|
||||
# Union type for tool definitions - Pydantic will discriminate based on 'type' field
|
||||
ToolDefinition = Annotated[
|
||||
Union[
|
||||
HttpApiToolDefinition,
|
||||
EndCallToolDefinition,
|
||||
TransferCallToolDefinition,
|
||||
CalculatorToolDefinition,
|
||||
McpToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
__all__ = [
|
||||
"CalculatorToolDefinition",
|
||||
"CreateToolRequest",
|
||||
"CreatedByResponse",
|
||||
"EndCallConfig",
|
||||
"EndCallToolDefinition",
|
||||
"HttpApiConfig",
|
||||
"HttpApiToolDefinition",
|
||||
"McpRefreshResponse",
|
||||
"McpToolConfig",
|
||||
"McpToolDefinition",
|
||||
"PresetToolParameter",
|
||||
"ToolDefinition",
|
||||
"ToolParameter",
|
||||
"ToolResponse",
|
||||
"TransferCallConfig",
|
||||
"TransferCallToolDefinition",
|
||||
"UpdateToolRequest",
|
||||
"_populate_discovered_tools",
|
||||
]
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
"""Request schema for creating a tool."""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
description: Optional[str] = None
|
||||
category: str = Field(default=ToolCategory.HTTP_API.value)
|
||||
icon: Optional[str] = Field(default="globe", max_length=50)
|
||||
icon_color: Optional[str] = Field(default="#3B82F6", max_length=7)
|
||||
definition: ToolDefinition
|
||||
|
||||
@field_validator("category")
|
||||
@classmethod
|
||||
def validate_category(cls, v: str) -> str:
|
||||
"""Validate that category is a valid ToolCategory value."""
|
||||
valid_categories = [c.value for c in ToolCategory]
|
||||
if v not in valid_categories:
|
||||
raise ValueError(
|
||||
f"Invalid category '{v}'. Must be one of: {', '.join(valid_categories)}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class UpdateToolRequest(BaseModel):
|
||||
"""Request schema for updating a tool."""
|
||||
|
||||
name: Optional[str] = Field(default=None, max_length=255)
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = Field(default=None, max_length=50)
|
||||
icon_color: Optional[str] = Field(default=None, max_length=7)
|
||||
definition: Optional[ToolDefinition] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class CreatedByResponse(BaseModel):
|
||||
"""Response schema for the user who created a tool."""
|
||||
|
||||
id: int
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
"""Response schema for a tool."""
|
||||
|
||||
id: int
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
category: str
|
||||
icon: Optional[str]
|
||||
icon_color: Optional[str]
|
||||
status: str
|
||||
definition: Dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
created_by: Optional[CreatedByResponse] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class McpRefreshResponse(BaseModel):
|
||||
"""Result of re-discovering an MCP server's tool catalog."""
|
||||
|
||||
tool_uuid: str
|
||||
discovered_tools: list = Field(default_factory=list)
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
|
||||
"""Build a response from a tool model."""
|
||||
created_by = None
|
||||
if include_created_by and tool.created_by_user:
|
||||
created_by = CreatedByResponse(
|
||||
id=tool.created_by_user.id,
|
||||
provider_id=tool.created_by_user.provider_id,
|
||||
)
|
||||
|
||||
return ToolResponse(
|
||||
id=tool.id,
|
||||
tool_uuid=tool.tool_uuid,
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
category=tool.category,
|
||||
icon=tool.icon,
|
||||
icon_color=tool.icon_color,
|
||||
status=tool.status,
|
||||
definition=tool.definition,
|
||||
created_at=tool.created_at,
|
||||
updated_at=tool.updated_at,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
|
||||
def validate_category(category: str) -> None:
|
||||
"""Validate that the category is valid."""
|
||||
valid_categories = [c.value for c in ToolCategory]
|
||||
|
|
@ -361,53 +126,13 @@ async def list_tools(
|
|||
return [build_tool_response(tool) for tool in tools]
|
||||
|
||||
|
||||
async def _fetch_credential(credential_uuid: Optional[str], organization_id: int):
|
||||
"""Best-effort credential lookup for MCP auth. A missing/failed credential
|
||||
degrades to ``None`` (unauthenticated) rather than failing the request."""
|
||||
if not credential_uuid:
|
||||
return None
|
||||
try:
|
||||
return await db_client.get_credential_by_uuid(credential_uuid, organization_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP: credential fetch failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _populate_discovered_tools(definition: dict, *, organization_id: int) -> dict:
|
||||
"""Best-effort: for an MCP definition, connect to the server, list its
|
||||
tools, and overwrite ``config.discovered_tools``. Never raises and never
|
||||
blocks tool save — a dead server yields ``discovered_tools: []``. Non-MCP
|
||||
definitions pass through untouched."""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
return definition
|
||||
try:
|
||||
cfg = validate_mcp_definition(definition)
|
||||
except McpDefinitionError:
|
||||
return definition
|
||||
|
||||
credential = await _fetch_credential(cfg.get("credential_uuid"), organization_id)
|
||||
|
||||
# Run discovery in an isolated asyncio task so an anyio cancel-scope
|
||||
# CancelledError doesn't bleed into the parent task and corrupt the
|
||||
# subsequent DB write. _run() never raises (degrades to []).
|
||||
async def _run() -> list:
|
||||
try:
|
||||
return await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except BaseException as e: # noqa: BLE001
|
||||
logger.warning(f"MCP discovery failed; caching empty list: {e}")
|
||||
return []
|
||||
|
||||
discovered = await asyncio.ensure_future(_run())
|
||||
definition["config"]["discovered_tools"] = discovered
|
||||
return definition
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@router.post(
|
||||
"/",
|
||||
**sdk_expose(
|
||||
method="create_tool",
|
||||
description="Create a reusable tool for the authenticated organization.",
|
||||
),
|
||||
)
|
||||
async def create_tool(
|
||||
request: CreateToolRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
|
|
@ -421,40 +146,10 @@ async def create_tool(
|
|||
Returns:
|
||||
The created tool
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
validate_category(request.category)
|
||||
|
||||
definition = await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=definition,
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
)
|
||||
|
||||
capture_event(
|
||||
distinct_id=str(user.provider_id),
|
||||
event=PostHogEvent.TOOL_CREATED,
|
||||
properties={
|
||||
"tool_name": request.name,
|
||||
"tool_category": request.category,
|
||||
"organization_id": user.selected_organization_id,
|
||||
},
|
||||
)
|
||||
|
||||
return build_tool_response(tool)
|
||||
try:
|
||||
return await create_tool_for_user(request, user, source="api")
|
||||
except ToolManagementError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message) from e
|
||||
|
||||
|
||||
@router.get("/{tool_uuid}")
|
||||
|
|
@ -494,57 +189,10 @@ async def refresh_mcp_tools(
|
|||
"""Re-discover an MCP tool's server catalog and overwrite the cached
|
||||
``definition.config.discovered_tools``. Server down → 200 with error
|
||||
(cache not overwritten on transient failure)."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
tool = await db_client.get_tool_by_uuid(
|
||||
tool_uuid, user.selected_organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
raise HTTPException(status_code=400, detail="Tool is not an MCP tool")
|
||||
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid MCP definition: {e}")
|
||||
|
||||
credential = await _fetch_credential(
|
||||
cfg.get("credential_uuid"), user.selected_organization_id
|
||||
)
|
||||
|
||||
try:
|
||||
discovered = await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP refresh discovery failed: {e}")
|
||||
discovered = []
|
||||
|
||||
if not discovered:
|
||||
error = (
|
||||
f"Could not reach the MCP server at {cfg['url']} "
|
||||
f"(or it exposes no tools). Previously cached list retained."
|
||||
)
|
||||
# Do NOT clobber a previously-good cache with [] on a transient outage.
|
||||
return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)
|
||||
|
||||
new_def = dict(tool.definition or {})
|
||||
new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered}
|
||||
await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
definition=new_def,
|
||||
)
|
||||
return McpRefreshResponse(
|
||||
tool_uuid=tool_uuid, discovered_tools=discovered, error=None
|
||||
)
|
||||
return await refresh_mcp_tool_for_user(tool_uuid, user)
|
||||
except ToolManagementError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message) from e
|
||||
|
||||
|
||||
@router.put("/{tool_uuid}")
|
||||
|
|
@ -571,14 +219,20 @@ async def update_tool(
|
|||
if request.status:
|
||||
validate_status(request.status)
|
||||
|
||||
definition = (
|
||||
await _populate_discovered_tools(
|
||||
request.definition.model_dump(),
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
if request.definition
|
||||
else None
|
||||
)
|
||||
definition = None
|
||||
if request.definition:
|
||||
definition = request.definition.model_dump()
|
||||
try:
|
||||
await validate_tool_credential_references(
|
||||
definition,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
definition = await _populate_discovered_tools(
|
||||
definition,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ToolManagementError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message) from e
|
||||
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
|
|
|
|||
440
api/schemas/tool.py
Normal file
440
api/schemas/tool.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
"""Pydantic schemas for reusable Dograh tools.
|
||||
|
||||
These models are the single contract for tool creation/update across the
|
||||
REST API, generated SDKs, and the MCP authoring surface. Field descriptions
|
||||
are human/API-facing; ``llm_hint`` JSON schema extras are guidance for LLMs
|
||||
when the same schema is surfaced through MCP or SDK authoring flows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from api.enums import ToolCategory
|
||||
|
||||
DEFAULT_MCP_TIMEOUT_SECS = 30
|
||||
DEFAULT_MCP_SSE_READ_TIMEOUT_SECS = 300
|
||||
|
||||
ToolParameterType = Literal["string", "number", "boolean"]
|
||||
HttpMethod = Literal["GET", "POST", "PUT", "PATCH", "DELETE"]
|
||||
ToolCategoryValue = Literal[
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
"mcp",
|
||||
]
|
||||
|
||||
|
||||
def _llm_hint(text: str) -> dict[str, str]:
|
||||
return {"llm_hint": text}
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
"""A parameter that the tool accepts from the model at call time."""
|
||||
|
||||
name: str = Field(
|
||||
description="Parameter name used as a key in the tool request body.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use a stable snake_case name the agent can naturally fill."
|
||||
),
|
||||
)
|
||||
type: ToolParameterType = Field(
|
||||
description="JSON type for the parameter value.",
|
||||
json_schema_extra=_llm_hint("Allowed values are string, number, and boolean."),
|
||||
)
|
||||
description: str = Field(
|
||||
description="Description shown to the model for this parameter.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Write this as an instruction to the agent: what value to provide and when."
|
||||
),
|
||||
)
|
||||
required: bool = Field(
|
||||
default=True,
|
||||
description="Whether this parameter is required when the tool is called.",
|
||||
)
|
||||
|
||||
|
||||
class PresetToolParameter(BaseModel):
|
||||
"""A parameter injected by Dograh at runtime."""
|
||||
|
||||
name: str = Field(description="Parameter name used as a key in the request body.")
|
||||
type: ToolParameterType = Field(description="JSON type for the resolved value.")
|
||||
value_template: str = Field(
|
||||
description="Fixed value or template, e.g. {{initial_context.phone_number}}.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use {{initial_context.*}} for call-start context and "
|
||||
"{{gathered_context.*}} for values extracted during the call."
|
||||
),
|
||||
)
|
||||
required: bool = Field(
|
||||
default=True,
|
||||
description="Whether the parameter must resolve to a non-empty value.",
|
||||
)
|
||||
|
||||
|
||||
class HttpApiConfig(BaseModel):
|
||||
"""Configuration for HTTP API tools."""
|
||||
|
||||
method: HttpMethod = Field(
|
||||
description="HTTP method to use for the request.",
|
||||
json_schema_extra=_llm_hint("Use one of GET, POST, PUT, PATCH, DELETE."),
|
||||
)
|
||||
url: str = Field(
|
||||
description="Target HTTP or HTTPS URL.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use the final endpoint URL. Authentication belongs in credential_uuid, "
|
||||
"not embedded in the URL."
|
||||
),
|
||||
)
|
||||
headers: Optional[Dict[str, str]] = Field(
|
||||
default=None,
|
||||
description="Static headers to include with every request.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Do not place secrets here. Store secrets in the UI credential manager "
|
||||
"and reference them with credential_uuid."
|
||||
),
|
||||
)
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Reference to an external credential for request authentication.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use a credential_uuid returned by list_credentials. The MCP flow does "
|
||||
"not create credential secrets."
|
||||
),
|
||||
)
|
||||
parameters: Optional[List[ToolParameter]] = Field(
|
||||
default=None,
|
||||
description="Parameters the model must provide when calling this tool.",
|
||||
)
|
||||
preset_parameters: Optional[List[PresetToolParameter]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Parameters injected by Dograh from fixed values or workflow context "
|
||||
"templates."
|
||||
),
|
||||
)
|
||||
timeout_ms: Optional[int] = Field(
|
||||
default=5000,
|
||||
ge=1,
|
||||
description="Request timeout in milliseconds.",
|
||||
)
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play after tool execution."
|
||||
)
|
||||
customMessageType: Optional[Literal["text", "audio"]] = Field(
|
||||
default=None, description="Type of custom message."
|
||||
)
|
||||
customMessageRecordingId: Optional[str] = Field(
|
||||
default=None, description="Recording ID for an audio custom message."
|
||||
)
|
||||
|
||||
@field_validator("method", mode="before")
|
||||
@classmethod
|
||||
def validate_method(cls, v: Any) -> str:
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("method must be one of GET, POST, PUT, PATCH, DELETE")
|
||||
method = v.upper()
|
||||
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
|
||||
raise ValueError("method must be one of GET, POST, PUT, PATCH, DELETE")
|
||||
return method
|
||||
|
||||
|
||||
class EndCallConfig(BaseModel):
|
||||
"""Configuration for End Call tools."""
|
||||
|
||||
messageType: Literal["none", "custom", "audio"] = Field(
|
||||
default="none", description="Type of goodbye message."
|
||||
)
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play before ending the call."
|
||||
)
|
||||
audioRecordingId: Optional[str] = Field(
|
||||
default=None, description="Recording ID for audio goodbye message."
|
||||
)
|
||||
endCallReason: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"When enabled, the model must provide a reason for ending the call. "
|
||||
"The reason is set as call disposition and added to call tags."
|
||||
),
|
||||
)
|
||||
endCallReasonDescription: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Description shown to the model for the reason parameter. Used only "
|
||||
"when endCallReason is enabled."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TransferCallConfig(BaseModel):
|
||||
"""Configuration for Transfer Call tools."""
|
||||
|
||||
destination: str = Field(
|
||||
description=(
|
||||
"Phone number or SIP endpoint to transfer the call to, e.g. "
|
||||
"+1234567890 or PJSIP/1234."
|
||||
)
|
||||
)
|
||||
messageType: Literal["none", "custom", "audio"] = Field(
|
||||
default="none", description="Type of message to play before transfer."
|
||||
)
|
||||
customMessage: Optional[str] = Field(
|
||||
default=None, description="Custom message to play before transferring."
|
||||
)
|
||||
audioRecordingId: Optional[str] = Field(
|
||||
default=None, description="Recording ID for audio message before transfer."
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=120,
|
||||
description="Maximum seconds to wait for the destination to answer.",
|
||||
)
|
||||
|
||||
@field_validator("destination")
|
||||
@classmethod
|
||||
def validate_destination(cls, v: str) -> str:
|
||||
"""Validate that destination is a valid E.164 phone number or SIP endpoint."""
|
||||
if not v.strip():
|
||||
return v
|
||||
|
||||
e164_pattern = r"^\+[1-9]\d{1,14}$"
|
||||
sip_pattern = r"^(PJSIP|SIP)/[\w\-\.@]+$"
|
||||
|
||||
is_valid_e164 = re.match(e164_pattern, v)
|
||||
is_valid_sip = re.match(sip_pattern, v, re.IGNORECASE)
|
||||
|
||||
if not (is_valid_e164 or is_valid_sip):
|
||||
raise ValueError(
|
||||
"Destination must be a valid E.164 phone number "
|
||||
"(e.g., +1234567890) or SIP endpoint (e.g., PJSIP/1234)"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class McpToolConfig(BaseModel):
|
||||
"""Configuration for a customer MCP server tool definition."""
|
||||
|
||||
transport: Literal["streamable_http"] = Field(
|
||||
default="streamable_http",
|
||||
description="MCP transport protocol.",
|
||||
)
|
||||
url: str = Field(
|
||||
description="MCP server URL. Must use http:// or https://.",
|
||||
json_schema_extra=_llm_hint("Use the server's streamable HTTP MCP endpoint."),
|
||||
)
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Reference to an external credential for MCP server auth.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use a credential_uuid returned by list_credentials. Credentials are "
|
||||
"created by the user in the UI."
|
||||
),
|
||||
)
|
||||
tools_filter: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Allowlist of MCP tool names to expose. Empty exposes all tools.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use exact MCP tool names from the remote server catalog when you need "
|
||||
"to restrict the exposed tools."
|
||||
),
|
||||
)
|
||||
timeout_secs: int = Field(
|
||||
default=DEFAULT_MCP_TIMEOUT_SECS,
|
||||
ge=0,
|
||||
description="Connection timeout in seconds.",
|
||||
)
|
||||
sse_read_timeout_secs: int = Field(
|
||||
default=DEFAULT_MCP_SSE_READ_TIMEOUT_SECS,
|
||||
ge=0,
|
||||
description="SSE read timeout in seconds.",
|
||||
)
|
||||
discovered_tools: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Server-managed cache of the MCP server's tool catalog "
|
||||
"[{name, description}]. Populated best-effort by the backend."
|
||||
),
|
||||
json_schema_extra=_llm_hint("Do not author this field; the server fills it."),
|
||||
)
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
|
||||
raise ValueError("config.url must be an http(s) URL")
|
||||
return v
|
||||
|
||||
@field_validator("tools_filter")
|
||||
@classmethod
|
||||
def validate_tools_filter(cls, v: list[str]) -> list[str]:
|
||||
if not all(isinstance(tool_name, str) for tool_name in v):
|
||||
raise ValueError("config.tools_filter must be a list of strings")
|
||||
return v
|
||||
|
||||
|
||||
class HttpApiToolDefinition(BaseModel):
|
||||
"""Tool definition for HTTP API tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version.")
|
||||
type: Literal["http_api"] = Field(description="Tool type.")
|
||||
config: HttpApiConfig = Field(description="HTTP API configuration.")
|
||||
|
||||
|
||||
class EndCallToolDefinition(BaseModel):
|
||||
"""Tool definition for End Call tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version.")
|
||||
type: Literal["end_call"] = Field(description="Tool type.")
|
||||
config: EndCallConfig = Field(description="End Call configuration.")
|
||||
|
||||
|
||||
class TransferCallToolDefinition(BaseModel):
|
||||
"""Tool definition for Transfer Call tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version.")
|
||||
type: Literal["transfer_call"] = Field(description="Tool type.")
|
||||
config: TransferCallConfig = Field(description="Transfer Call configuration.")
|
||||
|
||||
|
||||
class CalculatorToolDefinition(BaseModel):
|
||||
"""Tool definition for Calculator tools."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version.")
|
||||
type: Literal["calculator"] = Field(description="Tool type.")
|
||||
|
||||
|
||||
class McpToolDefinition(BaseModel):
|
||||
"""Persisted MCP tool definition."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version.")
|
||||
type: Literal["mcp"] = Field(description="Tool type.")
|
||||
config: McpToolConfig = Field(description="MCP server configuration.")
|
||||
|
||||
|
||||
ToolDefinition = Annotated[
|
||||
Union[
|
||||
HttpApiToolDefinition,
|
||||
EndCallToolDefinition,
|
||||
TransferCallToolDefinition,
|
||||
CalculatorToolDefinition,
|
||||
McpToolDefinition,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
"""Request schema for creating a reusable tool."""
|
||||
|
||||
name: str = Field(
|
||||
max_length=255,
|
||||
description="Display name for the tool.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"Use a concise action-oriented name; this influences the function "
|
||||
"name shown to the agent."
|
||||
),
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Description shown to the agent when deciding whether to call it.",
|
||||
json_schema_extra=_llm_hint(
|
||||
"State exactly when the agent should call the tool and what result it gets."
|
||||
),
|
||||
)
|
||||
category: ToolCategoryValue = Field(
|
||||
default=ToolCategory.HTTP_API.value,
|
||||
description="Tool category. Must match definition.type.",
|
||||
)
|
||||
icon: Optional[str] = Field(
|
||||
default="globe", max_length=50, description="Lucide icon identifier."
|
||||
)
|
||||
icon_color: Optional[str] = Field(
|
||||
default="#3B82F6", max_length=7, description="Hex color for the tool icon."
|
||||
)
|
||||
definition: ToolDefinition = Field(description="Typed tool definition.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def default_category_from_definition(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if data.get("category"):
|
||||
return data
|
||||
definition = data.get("definition")
|
||||
if isinstance(definition, dict) and definition.get("type"):
|
||||
return {**data, "category": definition["type"]}
|
||||
return data
|
||||
|
||||
@field_validator("category")
|
||||
@classmethod
|
||||
def validate_category(cls, v: str) -> str:
|
||||
valid_categories = [c.value for c in ToolCategory]
|
||||
if v not in valid_categories:
|
||||
raise ValueError(
|
||||
f"Invalid category '{v}'. Must be one of: {', '.join(valid_categories)}"
|
||||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_category_matches_definition(self) -> "CreateToolRequest":
|
||||
definition_type = self.definition.type
|
||||
if self.category != definition_type:
|
||||
raise ValueError(
|
||||
f"category '{self.category}' must match definition.type "
|
||||
f"'{definition_type}'"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class UpdateToolRequest(BaseModel):
|
||||
"""Request schema for updating a reusable tool."""
|
||||
|
||||
name: Optional[str] = Field(default=None, max_length=255)
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = Field(default=None, max_length=50)
|
||||
icon_color: Optional[str] = Field(default=None, max_length=7)
|
||||
definition: Optional[ToolDefinition] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class CreatedByResponse(BaseModel):
|
||||
"""Response schema for the user who created a tool."""
|
||||
|
||||
id: int
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
"""Response schema for a reusable tool."""
|
||||
|
||||
id: int
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
category: str
|
||||
icon: Optional[str]
|
||||
icon_color: Optional[str]
|
||||
status: str
|
||||
definition: Dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
created_by: Optional[CreatedByResponse] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class McpRefreshResponse(BaseModel):
|
||||
"""Result of re-discovering an MCP server's tool catalog."""
|
||||
|
||||
tool_uuid: str
|
||||
discovered_tools: list = Field(default_factory=list)
|
||||
error: Optional[str] = None
|
||||
251
api/services/tool_management.py
Normal file
251
api/services/tool_management.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
"""Service layer for reusable tool management.
|
||||
|
||||
Routes and MCP tools both use this module so validation, credential
|
||||
scoping, MCP discovery, and analytics stay consistent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent, ToolCategory
|
||||
from api.schemas.tool import (
|
||||
CreatedByResponse,
|
||||
CreateToolRequest,
|
||||
McpRefreshResponse,
|
||||
ToolResponse,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
McpDefinitionError,
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
||||
|
||||
class ToolManagementError(ValueError):
|
||||
"""Recoverable tool-management error with an MCP/HTTP friendly code."""
|
||||
|
||||
def __init__(self, error_code: str, message: str, *, status_code: int = 400):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def build_tool_response(tool: Any, include_created_by: bool = False) -> ToolResponse:
|
||||
"""Build a public response from a ToolModel-like object."""
|
||||
created_by = None
|
||||
if include_created_by and tool.created_by_user:
|
||||
created_by = CreatedByResponse(
|
||||
id=tool.created_by_user.id,
|
||||
provider_id=tool.created_by_user.provider_id,
|
||||
)
|
||||
|
||||
return ToolResponse(
|
||||
id=tool.id,
|
||||
tool_uuid=tool.tool_uuid,
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
category=tool.category,
|
||||
icon=tool.icon,
|
||||
icon_color=tool.icon_color,
|
||||
status=tool.status,
|
||||
definition=tool.definition,
|
||||
created_at=tool.created_at,
|
||||
updated_at=tool.updated_at,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
|
||||
def _credential_uuid_from_definition(definition: dict[str, Any]) -> Optional[str]:
|
||||
config = definition.get("config")
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
credential_uuid = config.get("credential_uuid")
|
||||
return credential_uuid if isinstance(credential_uuid, str) else None
|
||||
|
||||
|
||||
async def fetch_credential(credential_uuid: Optional[str], organization_id: int):
|
||||
"""Best-effort credential lookup for MCP auth/discovery."""
|
||||
if not credential_uuid:
|
||||
return None
|
||||
try:
|
||||
return await db_client.get_credential_by_uuid(credential_uuid, organization_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"Tool credential fetch failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def validate_tool_credential_references(
|
||||
definition: dict[str, Any], *, organization_id: int
|
||||
) -> None:
|
||||
"""Ensure credential UUID references belong to the caller's organization."""
|
||||
credential_uuid = _credential_uuid_from_definition(definition)
|
||||
if not credential_uuid:
|
||||
return
|
||||
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
credential_uuid, organization_id
|
||||
)
|
||||
if not credential:
|
||||
raise ToolManagementError(
|
||||
"credential_not_found",
|
||||
(
|
||||
f"Credential '{credential_uuid}' was not found in this organization. "
|
||||
"Create it in the UI first, then retry with its credential_uuid."
|
||||
),
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
|
||||
async def populate_discovered_tools(
|
||||
definition: dict[str, Any], *, organization_id: int
|
||||
) -> dict[str, Any]:
|
||||
"""Best-effort MCP discovery before saving a tool definition.
|
||||
|
||||
Non-MCP definitions pass through untouched. For MCP definitions, a dead
|
||||
server yields ``discovered_tools: []`` and does not block creation.
|
||||
"""
|
||||
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
||||
return definition
|
||||
try:
|
||||
cfg = validate_mcp_definition(definition)
|
||||
except McpDefinitionError:
|
||||
return definition
|
||||
|
||||
credential = await fetch_credential(cfg.get("credential_uuid"), organization_id)
|
||||
|
||||
async def _run() -> list:
|
||||
try:
|
||||
return await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except BaseException as e: # noqa: BLE001
|
||||
logger.warning(f"MCP discovery failed; caching empty list: {e}")
|
||||
return []
|
||||
|
||||
discovered = await asyncio.ensure_future(_run())
|
||||
definition["config"]["discovered_tools"] = discovered
|
||||
return definition
|
||||
|
||||
|
||||
async def create_tool_for_user(
|
||||
request: CreateToolRequest,
|
||||
user: UserModel,
|
||||
*,
|
||||
source: str = "api",
|
||||
) -> ToolResponse:
|
||||
"""Create a reusable tool for the authenticated user's selected org."""
|
||||
if not user.selected_organization_id:
|
||||
raise ToolManagementError(
|
||||
"organization_required",
|
||||
"No organization selected for the user",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
definition = request.definition.model_dump()
|
||||
await validate_tool_credential_references(
|
||||
definition, organization_id=user.selected_organization_id
|
||||
)
|
||||
definition = await populate_discovered_tools(
|
||||
definition,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=definition,
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
)
|
||||
|
||||
capture_event(
|
||||
distinct_id=str(user.provider_id),
|
||||
event=PostHogEvent.TOOL_CREATED,
|
||||
properties={
|
||||
"tool_name": request.name,
|
||||
"tool_category": request.category,
|
||||
"source": source,
|
||||
"organization_id": user.selected_organization_id,
|
||||
},
|
||||
)
|
||||
|
||||
return build_tool_response(tool)
|
||||
|
||||
|
||||
async def refresh_mcp_tool_for_user(
|
||||
tool_uuid: str,
|
||||
user: UserModel,
|
||||
) -> McpRefreshResponse:
|
||||
"""Refresh cached MCP catalog for a tool owned by the user's org."""
|
||||
if not user.selected_organization_id:
|
||||
raise ToolManagementError(
|
||||
"organization_required",
|
||||
"No organization selected for the user",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
tool = await db_client.get_tool_by_uuid(
|
||||
tool_uuid, user.selected_organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
raise ToolManagementError("tool_not_found", "Tool not found", status_code=404)
|
||||
if tool.category != ToolCategory.MCP.value:
|
||||
raise ToolManagementError(
|
||||
"not_mcp_tool", "Tool is not an MCP tool", status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
cfg = validate_mcp_definition(tool.definition)
|
||||
except McpDefinitionError as e:
|
||||
raise ToolManagementError(
|
||||
"invalid_mcp_definition",
|
||||
f"Invalid MCP definition: {e}",
|
||||
status_code=400,
|
||||
) from e
|
||||
|
||||
credential = await fetch_credential(
|
||||
cfg.get("credential_uuid"), user.selected_organization_id
|
||||
)
|
||||
|
||||
try:
|
||||
discovered = await discover_mcp_tools(
|
||||
url=cfg["url"],
|
||||
credential=credential,
|
||||
timeout_secs=cfg["timeout_secs"],
|
||||
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"MCP refresh discovery failed: {e}")
|
||||
discovered = []
|
||||
|
||||
if not discovered:
|
||||
error = (
|
||||
f"Could not reach the MCP server at {cfg['url']} "
|
||||
f"(or it exposes no tools). Previously cached list retained."
|
||||
)
|
||||
return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)
|
||||
|
||||
new_def = dict(tool.definition or {})
|
||||
new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered}
|
||||
await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
definition=new_def,
|
||||
)
|
||||
return McpRefreshResponse(
|
||||
tool_uuid=tool_uuid, discovered_tools=discovered, error=None
|
||||
)
|
||||
|
|
@ -4,70 +4,27 @@ LLM-function-name namespacing. No I/O, no MCP protocol here."""
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from pydantic import ValidationError
|
||||
|
||||
DEFAULT_TIMEOUT_SECS = 30
|
||||
DEFAULT_SSE_READ_TIMEOUT_SECS = 300
|
||||
from api.schemas.tool import (
|
||||
DEFAULT_MCP_SSE_READ_TIMEOUT_SECS,
|
||||
DEFAULT_MCP_TIMEOUT_SECS,
|
||||
McpToolDefinition,
|
||||
)
|
||||
from api.schemas.tool import (
|
||||
McpToolConfig as McpToolConfig,
|
||||
)
|
||||
|
||||
DEFAULT_TIMEOUT_SECS = DEFAULT_MCP_TIMEOUT_SECS
|
||||
DEFAULT_SSE_READ_TIMEOUT_SECS = DEFAULT_MCP_SSE_READ_TIMEOUT_SECS
|
||||
|
||||
|
||||
class McpDefinitionError(ValueError):
|
||||
"""Raised when an MCP tool definition is structurally invalid."""
|
||||
|
||||
|
||||
class McpToolConfig(BaseModel):
|
||||
"""Configuration for an MCP tool definition."""
|
||||
|
||||
transport: Literal["streamable_http"] = Field(
|
||||
default="streamable_http", description="MCP transport protocol"
|
||||
)
|
||||
url: str = Field(description="MCP server URL (must be http:// or https://)")
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None, description="Reference to ExternalCredentialModel for auth"
|
||||
)
|
||||
tools_filter: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Allowlist of MCP tool names to expose (empty = all tools)",
|
||||
)
|
||||
timeout_secs: int = Field(
|
||||
default=DEFAULT_TIMEOUT_SECS, description="Connection timeout in seconds"
|
||||
)
|
||||
sse_read_timeout_secs: int = Field(
|
||||
default=DEFAULT_SSE_READ_TIMEOUT_SECS,
|
||||
description="SSE read timeout in seconds",
|
||||
)
|
||||
discovered_tools: list[dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Server-managed cache of the MCP server's tool catalog "
|
||||
"[{name, description}]. Populated best-effort by the backend."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
if not isinstance(v, str) or not v.startswith(("http://", "https://")):
|
||||
raise ValueError("config.url must be an http(s) URL")
|
||||
return v
|
||||
|
||||
@field_validator("tools_filter")
|
||||
@classmethod
|
||||
def validate_tools_filter(cls, v: list[str]) -> list[str]:
|
||||
if not all(isinstance(tool_name, str) for tool_name in v):
|
||||
raise ValueError("config.tools_filter must be a list of strings")
|
||||
return v
|
||||
|
||||
|
||||
class McpToolDefinition(BaseModel):
|
||||
"""Persisted MCP tool definition."""
|
||||
|
||||
schema_version: int = Field(default=1, description="Schema version")
|
||||
type: Literal["mcp"] = Field(description="Tool type")
|
||||
config: McpToolConfig = Field(description="MCP server configuration")
|
||||
|
||||
|
||||
def _format_validation_error(error: ValidationError) -> str:
|
||||
parts: list[str] = []
|
||||
for item in error.errors():
|
||||
|
|
|
|||
164
api/tests/test_mcp_tool_creation.py
Normal file
164
api/tests/test_mcp_tool_creation.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from api.app import app
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tools.tool_creation import create_tool
|
||||
from api.schemas.tool import CreateToolRequest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authed_user() -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = 11
|
||||
user.provider_id = "provider-11"
|
||||
user.selected_organization_id = 22
|
||||
return user
|
||||
|
||||
|
||||
def _tool_model(**overrides):
|
||||
now = datetime.now(UTC)
|
||||
values = {
|
||||
"id": 3,
|
||||
"tool_uuid": "tool-uuid-3",
|
||||
"name": "Lookup Account",
|
||||
"description": "Lookup an account by phone number",
|
||||
"category": "http_api",
|
||||
"icon": "globe",
|
||||
"icon_color": "#3B82F6",
|
||||
"status": "active",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {"method": "POST", "url": "https://api.example.com/lookup"},
|
||||
},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
values.update(overrides)
|
||||
return SimpleNamespace(**values)
|
||||
|
||||
|
||||
def _http_tool_request(**config_overrides) -> CreateToolRequest:
|
||||
config = {"method": "post", "url": "https://api.example.com/lookup"}
|
||||
config.update(config_overrides)
|
||||
return CreateToolRequest(
|
||||
name="Lookup Account",
|
||||
description="Lookup an account by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": config,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_create_tool_creates_reusable_tool(authed_user: MagicMock):
|
||||
create_tool_mock = AsyncMock(return_value=_tool_model())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.tool_creation.authenticate_mcp_request",
|
||||
AsyncMock(return_value=authed_user),
|
||||
),
|
||||
patch(
|
||||
"api.services.tool_management.db_client.create_tool",
|
||||
create_tool_mock,
|
||||
),
|
||||
patch("api.services.tool_management.capture_event") as capture_event_mock,
|
||||
):
|
||||
result = await create_tool(_http_tool_request())
|
||||
|
||||
assert result["created"] is True
|
||||
assert result["tool_uuid"] == "tool-uuid-3"
|
||||
assert result["category"] == "http_api"
|
||||
create_tool_mock.assert_awaited_once()
|
||||
assert create_tool_mock.call_args.kwargs["organization_id"] == 22
|
||||
assert create_tool_mock.call_args.kwargs["user_id"] == 11
|
||||
assert create_tool_mock.call_args.kwargs["definition"]["config"]["method"] == "POST"
|
||||
capture_event_mock.assert_called_once()
|
||||
assert capture_event_mock.call_args.kwargs["properties"]["source"] == "mcp"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_create_tool_rejects_unknown_credential(authed_user: MagicMock):
|
||||
create_tool_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.tool_creation.authenticate_mcp_request",
|
||||
AsyncMock(return_value=authed_user),
|
||||
),
|
||||
patch(
|
||||
"api.services.tool_management.db_client.get_credential_by_uuid",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"api.services.tool_management.db_client.create_tool",
|
||||
create_tool_mock,
|
||||
),
|
||||
):
|
||||
result = await create_tool(_http_tool_request(credential_uuid="cred-missing"))
|
||||
|
||||
assert result["created"] is False
|
||||
assert result["error_code"] == "credential_not_found"
|
||||
create_tool_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_sdk_openapi_exposes_create_tool_schema_and_llm_hints():
|
||||
sdk_routes = [
|
||||
r
|
||||
for r in app.routes
|
||||
if getattr(r, "openapi_extra", None)
|
||||
and "x-sdk-method" in (r.openapi_extra or {})
|
||||
]
|
||||
spec = get_openapi(title=app.title, version=app.version, routes=sdk_routes)
|
||||
operations = [
|
||||
op
|
||||
for path_item in spec["paths"].values()
|
||||
for op in path_item.values()
|
||||
if isinstance(op, dict)
|
||||
]
|
||||
assert any(op.get("x-sdk-method") == "create_tool" for op in operations)
|
||||
|
||||
credential_schema = spec["components"]["schemas"]["HttpApiConfig"]["properties"][
|
||||
"credential_uuid"
|
||||
]
|
||||
assert "list_credentials" in credential_schema["llm_hint"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_create_tool_schema_includes_validation_and_llm_hints():
|
||||
tools = await mcp.list_tools()
|
||||
create_tool_spec = next(t for t in tools if t.name == "create_tool")
|
||||
|
||||
request_schema = create_tool_spec.parameters["properties"]["request"]
|
||||
definition_schema = request_schema["properties"]["definition"]
|
||||
http_config = definition_schema["oneOf"][0]["properties"]["config"]
|
||||
|
||||
assert request_schema["properties"]["category"]["enum"] == [
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
"mcp",
|
||||
]
|
||||
assert http_config["properties"]["method"]["enum"] == [
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
]
|
||||
assert (
|
||||
"list_credentials" in http_config["properties"]["credential_uuid"]["llm_hint"]
|
||||
)
|
||||
|
|
@ -16,10 +16,20 @@ Test coverage:
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
|
||||
from api.routes.tool import (
|
||||
CreateToolRequest,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
UpdateToolRequest,
|
||||
_populate_discovered_tools,
|
||||
refresh_mcp_tools,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
|
@ -279,10 +289,6 @@ async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_ses
|
|||
|
||||
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from api.routes.tool import McpToolConfig, _populate_discovered_tools
|
||||
|
||||
|
||||
def test_mcp_config_accepts_discovered_tools():
|
||||
cfg = McpToolConfig(
|
||||
|
|
@ -296,10 +302,10 @@ def test_mcp_config_accepts_discovered_tools():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
tool_svc,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
|
|
@ -327,10 +333,10 @@ async def test_populate_discovered_tools_non_mcp_is_noop():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
tool_svc,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(side_effect=RuntimeError("connection refused")),
|
||||
)
|
||||
|
|
@ -345,10 +351,6 @@ async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
|||
|
||||
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.routes.tool import refresh_mcp_tools
|
||||
|
||||
|
||||
def _fake_user(org_id=1):
|
||||
u = MagicMock()
|
||||
|
|
@ -373,19 +375,19 @@ def _mcp_tool_model(org_id=1):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client,
|
||||
tool_svc.db_client,
|
||||
"update_tool",
|
||||
AsyncMock(return_value=tool),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
tool_svc,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
|
|
@ -396,29 +398,29 @@ async def test_refresh_success(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
monkeypatch.setattr(tool_svc.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_svc, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == []
|
||||
assert resp.error # non-empty human-readable message
|
||||
# update_tool should NOT be called when discovery returns empty
|
||||
tool_mod.db_client.update_tool.assert_not_called()
|
||||
tool_svc.db_client.update_tool.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_non_mcp_is_400(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
tool.category = "http_api"
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
|
|
@ -427,10 +429,10 @@ async def test_refresh_non_mcp_is_400(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_not_found_is_404(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("nope", user=_fake_user())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue