dograh/api/services/tool_management.py
2026-05-31 16:50:44 +05:30

251 lines
8 KiB
Python

"""Service layer for reusable tool management.
Routes and MCP tools both use this module so validation, credential
scoping, MCP discovery, and analytics stay consistent.
"""
from __future__ import annotations
import asyncio
from typing import Any, Optional
from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent, ToolCategory
from api.schemas.tool import (
CreatedByResponse,
CreateToolRequest,
McpRefreshResponse,
ToolResponse,
)
from api.services.posthog_client import capture_event
from api.services.workflow.mcp_tool_session import discover_mcp_tools
from api.services.workflow.tools.mcp_tool import (
McpDefinitionError,
validate_mcp_definition,
)
class ToolManagementError(ValueError):
    """Recoverable failure raised by the tool-management service layer.

    Besides the human-readable ``message`` (also what ``str(exc)`` returns),
    the exception carries a machine-readable ``error_code`` and an HTTP-style
    ``status_code`` so both API routes and MCP tools can turn it into a
    response without re-deriving either value.
    """

    def __init__(self, error_code: str, message: str, *, status_code: int = 400):
        # Stable identifier such as "tool_not_found" or "credential_not_found".
        self.error_code = error_code
        # HTTP status callers should surface; defaults to 400 (bad request).
        self.status_code = status_code
        # Kept as an attribute as well as in ``args`` for convenient access.
        self.message = message
        super().__init__(message)
def build_tool_response(tool: Any, include_created_by: bool = False) -> ToolResponse:
    """Serialize a ToolModel-like object into the public ``ToolResponse`` schema.

    Args:
        tool: Object exposing the ToolModel attributes copied below.
        include_created_by: When truthy and ``tool.created_by_user`` is set,
            attach a ``CreatedByResponse`` with the creator's ``id`` and
            ``provider_id``. When falsy, ``created_by_user`` is never read.

    Returns:
        A ``ToolResponse`` whose ``created_by`` is ``None`` unless it was
        requested and the creator is available.
    """
    # Only touch the relationship when asked for it.
    creator = tool.created_by_user if include_created_by else None
    if creator:
        created_by = CreatedByResponse(id=creator.id, provider_id=creator.provider_id)
    else:
        created_by = None

    return ToolResponse(
        id=tool.id,
        tool_uuid=tool.tool_uuid,
        name=tool.name,
        description=tool.description,
        category=tool.category,
        icon=tool.icon,
        icon_color=tool.icon_color,
        status=tool.status,
        definition=tool.definition,
        created_at=tool.created_at,
        updated_at=tool.updated_at,
        created_by=created_by,
    )
def _credential_uuid_from_definition(definition: dict[str, Any]) -> Optional[str]:
config = definition.get("config")
if not isinstance(config, dict):
return None
credential_uuid = config.get("credential_uuid")
return credential_uuid if isinstance(credential_uuid, str) else None
async def fetch_credential(credential_uuid: Optional[str], organization_id: int):
    """Look up an organization-scoped credential without ever raising.

    Returns ``None`` in three cases:

    - ``credential_uuid`` is empty or ``None``;
    - the database lookup raises (the error is logged as a warning);
    - the lookup itself returns ``None``.

    This lets MCP auth/discovery proceed without a credential.
    """
    if credential_uuid:
        try:
            credential = await db_client.get_credential_by_uuid(
                credential_uuid, organization_id
            )
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Tool credential fetch failed: {e}")
        else:
            return credential
    return None
async def validate_tool_credential_references(
    definition: dict[str, Any], *, organization_id: int
) -> None:
    """Reject a definition whose ``config.credential_uuid`` is outside the org.

    Definitions without a (non-empty, string) credential reference pass
    silently. Errors raised by the database lookup itself propagate unchanged.

    Raises:
        ToolManagementError: ``credential_not_found`` (status 404) when the
            lookup scoped to ``organization_id`` finds nothing.
    """
    referenced = _credential_uuid_from_definition(definition)
    if not referenced:
        return

    found = await db_client.get_credential_by_uuid(referenced, organization_id)
    if found:
        return

    raise ToolManagementError(
        "credential_not_found",
        f"Credential '{referenced}' was not found in this organization. "
        "Create it in the UI first, then retry with its credential_uuid.",
        status_code=404,
    )
async def populate_discovered_tools(
    definition: dict[str, Any], *, organization_id: int
) -> dict[str, Any]:
    """Best-effort MCP discovery before saving a tool definition.

    These inputs are returned unchanged:

    - anything that is not a dict with ``type == "mcp"``;
    - MCP definitions rejected by ``validate_mcp_definition``.

    For a valid MCP definition, ``definition["config"]["discovered_tools"]``
    is set in place to the discovered catalog. A dead server yields
    ``discovered_tools: []`` and does not block creation.

    Args:
        definition: Tool definition dict; mutated in place for MCP tools.
        organization_id: Organization used to scope the credential lookup.

    Returns:
        The same ``definition`` object that was passed in.

    Raises:
        asyncio.CancelledError: If the calling task is cancelled while
            discovery is in flight. The discovery task is cancelled as well.
    """
    if not isinstance(definition, dict) or definition.get("type") != "mcp":
        return definition
    try:
        cfg = validate_mcp_definition(definition)
    except McpDefinitionError:
        return definition

    credential = await fetch_credential(cfg.get("credential_uuid"), organization_id)

    async def _run() -> list:
        try:
            return await discover_mcp_tools(
                url=cfg["url"],
                credential=credential,
                timeout_secs=cfg["timeout_secs"],
                sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
            )
        except (KeyboardInterrupt, SystemExit):
            # Process-level signals must never be downgraded to "no tools".
            raise
        except BaseException as e:  # noqa: BLE001
            # Deliberately broad so any failure inside discovery (presumably
            # including non-Exception errors from the MCP client; confirm)
            # becomes an empty catalog instead of failing creation.
            logger.warning(f"MCP discovery failed; caching empty list: {e}")
            return []

    # Discovery runs in its own task so failures inside it stay contained.
    # Without the shield, cancelling *this* task would be forwarded to the
    # discovery task and absorbed by _run's broad handler. The caller's
    # cancellation would then be silently lost and the tool saved anyway.
    # With the shield, our cancellation raises here; we then stop discovery.
    discovery = asyncio.ensure_future(_run())
    try:
        discovered = await asyncio.shield(discovery)
    except asyncio.CancelledError:
        discovery.cancel()
        raise
    definition["config"]["discovered_tools"] = discovered
    return definition
async def create_tool_for_user(
    request: CreateToolRequest,
    user: UserModel,
    *,
    source: str = "api",
) -> ToolResponse:
    """Validate, enrich, persist and record a new reusable tool.

    The flow is:

    1. Require a selected organization.
    2. Dump ``request.definition`` to a dict.
    3. Check that any referenced credential belongs to that organization.
    4. Run best-effort MCP discovery.
    5. Store the tool via ``db_client.create_tool``.
    6. Emit a ``TOOL_CREATED`` analytics event.

    Args:
        request: Validated creation payload.
        user: Authenticated user; their selected organization owns the tool.
        source: Origin label recorded in the analytics event.

    Returns:
        The created tool as a ``ToolResponse`` (without ``created_by``).

    Raises:
        ToolManagementError: ``organization_required`` (400) when the user has
            no selected organization, or ``credential_not_found`` (404) from
            the credential check.
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise ToolManagementError(
            "organization_required",
            "No organization selected for the user",
            status_code=400,
        )

    # Work on a plain dict so discovery can attach its catalog before saving.
    payload = request.definition.model_dump()
    await validate_tool_credential_references(payload, organization_id=org_id)
    payload = await populate_discovered_tools(payload, organization_id=org_id)

    created = await db_client.create_tool(
        organization_id=org_id,
        user_id=user.id,
        name=request.name,
        definition=payload,
        category=request.category,
        description=request.description,
        icon=request.icon,
        icon_color=request.icon_color,
    )

    event_properties = {
        "tool_name": request.name,
        "tool_category": request.category,
        "source": source,
        "organization_id": org_id,
    }
    capture_event(
        distinct_id=str(user.provider_id),
        event=PostHogEvent.TOOL_CREATED,
        properties=event_properties,
    )
    return build_tool_response(created)
async def refresh_mcp_tool_for_user(
    tool_uuid: str,
    user: UserModel,
) -> McpRefreshResponse:
    """Re-run MCP discovery for one tool and persist the fresh catalog.

    The lookup is scoped to the user's selected organization and includes
    archived tools.

    When discovery fails or returns nothing, the stored definition is left
    untouched. The response then carries an empty list plus an explanatory
    ``error``.

    Raises:
        ToolManagementError: with one of these codes:
            - ``organization_required`` (400);
            - ``tool_not_found`` (404);
            - ``not_mcp_tool`` (400);
            - ``invalid_mcp_definition`` (400).
    """
    org_id = user.selected_organization_id
    if not org_id:
        raise ToolManagementError(
            "organization_required",
            "No organization selected for the user",
            status_code=400,
        )

    tool = await db_client.get_tool_by_uuid(tool_uuid, org_id, include_archived=True)
    if not tool:
        raise ToolManagementError("tool_not_found", "Tool not found", status_code=404)
    if tool.category != ToolCategory.MCP.value:
        raise ToolManagementError(
            "not_mcp_tool", "Tool is not an MCP tool", status_code=400
        )

    try:
        cfg = validate_mcp_definition(tool.definition)
    except McpDefinitionError as e:
        raise ToolManagementError(
            "invalid_mcp_definition",
            f"Invalid MCP definition: {e}",
            status_code=400,
        ) from e

    credential = await fetch_credential(cfg.get("credential_uuid"), org_id)
    try:
        catalog = await discover_mcp_tools(
            url=cfg["url"],
            credential=credential,
            timeout_secs=cfg["timeout_secs"],
            sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
        )
    except Exception as e:  # noqa: BLE001
        logger.warning(f"MCP refresh discovery failed: {e}")
        catalog = []

    if catalog:
        # Copy before editing so the loaded definition object is not mutated.
        refreshed = dict(tool.definition or {})
        refreshed["config"] = {
            **refreshed.get("config", {}),
            "discovered_tools": catalog,
        }
        await db_client.update_tool(
            tool_uuid=tool_uuid,
            organization_id=org_id,
            definition=refreshed,
        )
        return McpRefreshResponse(
            tool_uuid=tool_uuid, discovered_tools=catalog, error=None
        )

    failure = (
        f"Could not reach the MCP server at {cfg['url']} "
        f"(or it exposes no tools). Previously cached list retained."
    )
    return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=failure)