mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
251 lines
8 KiB
Python
251 lines
8 KiB
Python
"""Service layer for reusable tool management.
|
|
|
|
Routes and MCP tools both use this module so validation, credential
|
|
scoping, MCP discovery, and analytics stay consistent.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any, Optional
|
|
|
|
from loguru import logger
|
|
|
|
from api.db import db_client
|
|
from api.db.models import UserModel
|
|
from api.enums import PostHogEvent, ToolCategory
|
|
from api.schemas.tool import (
|
|
CreatedByResponse,
|
|
CreateToolRequest,
|
|
McpRefreshResponse,
|
|
ToolResponse,
|
|
)
|
|
from api.services.posthog_client import capture_event
|
|
from api.services.workflow.mcp_tool_session import discover_mcp_tools
|
|
from api.services.workflow.tools.mcp_tool import (
|
|
McpDefinitionError,
|
|
validate_mcp_definition,
|
|
)
|
|
|
|
|
|
class ToolManagementError(ValueError):
|
|
"""Recoverable tool-management error with an MCP/HTTP friendly code."""
|
|
|
|
def __init__(self, error_code: str, message: str, *, status_code: int = 400):
|
|
super().__init__(message)
|
|
self.error_code = error_code
|
|
self.message = message
|
|
self.status_code = status_code
|
|
|
|
|
|
def build_tool_response(tool: Any, include_created_by: bool = False) -> ToolResponse:
|
|
"""Build a public response from a ToolModel-like object."""
|
|
created_by = None
|
|
if include_created_by and tool.created_by_user:
|
|
created_by = CreatedByResponse(
|
|
id=tool.created_by_user.id,
|
|
provider_id=tool.created_by_user.provider_id,
|
|
)
|
|
|
|
return ToolResponse(
|
|
id=tool.id,
|
|
tool_uuid=tool.tool_uuid,
|
|
name=tool.name,
|
|
description=tool.description,
|
|
category=tool.category,
|
|
icon=tool.icon,
|
|
icon_color=tool.icon_color,
|
|
status=tool.status,
|
|
definition=tool.definition,
|
|
created_at=tool.created_at,
|
|
updated_at=tool.updated_at,
|
|
created_by=created_by,
|
|
)
|
|
|
|
|
|
def _credential_uuid_from_definition(definition: dict[str, Any]) -> Optional[str]:
|
|
config = definition.get("config")
|
|
if not isinstance(config, dict):
|
|
return None
|
|
credential_uuid = config.get("credential_uuid")
|
|
return credential_uuid if isinstance(credential_uuid, str) else None
|
|
|
|
|
|
async def fetch_credential(credential_uuid: Optional[str], organization_id: int):
|
|
"""Best-effort credential lookup for MCP auth/discovery."""
|
|
if not credential_uuid:
|
|
return None
|
|
try:
|
|
return await db_client.get_credential_by_uuid(credential_uuid, organization_id)
|
|
except Exception as e: # noqa: BLE001
|
|
logger.warning(f"Tool credential fetch failed: {e}")
|
|
return None
|
|
|
|
|
|
async def validate_tool_credential_references(
|
|
definition: dict[str, Any], *, organization_id: int
|
|
) -> None:
|
|
"""Ensure credential UUID references belong to the caller's organization."""
|
|
credential_uuid = _credential_uuid_from_definition(definition)
|
|
if not credential_uuid:
|
|
return
|
|
|
|
credential = await db_client.get_credential_by_uuid(
|
|
credential_uuid, organization_id
|
|
)
|
|
if not credential:
|
|
raise ToolManagementError(
|
|
"credential_not_found",
|
|
(
|
|
f"Credential '{credential_uuid}' was not found in this organization. "
|
|
"Create it in the UI first, then retry with its credential_uuid."
|
|
),
|
|
status_code=404,
|
|
)
|
|
|
|
|
|
async def populate_discovered_tools(
|
|
definition: dict[str, Any], *, organization_id: int
|
|
) -> dict[str, Any]:
|
|
"""Best-effort MCP discovery before saving a tool definition.
|
|
|
|
Non-MCP definitions pass through untouched. For MCP definitions, a dead
|
|
server yields ``discovered_tools: []`` and does not block creation.
|
|
"""
|
|
if not isinstance(definition, dict) or definition.get("type") != "mcp":
|
|
return definition
|
|
try:
|
|
cfg = validate_mcp_definition(definition)
|
|
except McpDefinitionError:
|
|
return definition
|
|
|
|
credential = await fetch_credential(cfg.get("credential_uuid"), organization_id)
|
|
|
|
async def _run() -> list:
|
|
try:
|
|
return await discover_mcp_tools(
|
|
url=cfg["url"],
|
|
credential=credential,
|
|
timeout_secs=cfg["timeout_secs"],
|
|
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
|
)
|
|
except BaseException as e: # noqa: BLE001
|
|
logger.warning(f"MCP discovery failed; caching empty list: {e}")
|
|
return []
|
|
|
|
discovered = await asyncio.ensure_future(_run())
|
|
definition["config"]["discovered_tools"] = discovered
|
|
return definition
|
|
|
|
|
|
async def create_tool_for_user(
|
|
request: CreateToolRequest,
|
|
user: UserModel,
|
|
*,
|
|
source: str = "api",
|
|
) -> ToolResponse:
|
|
"""Create a reusable tool for the authenticated user's selected org."""
|
|
if not user.selected_organization_id:
|
|
raise ToolManagementError(
|
|
"organization_required",
|
|
"No organization selected for the user",
|
|
status_code=400,
|
|
)
|
|
|
|
definition = request.definition.model_dump()
|
|
await validate_tool_credential_references(
|
|
definition, organization_id=user.selected_organization_id
|
|
)
|
|
definition = await populate_discovered_tools(
|
|
definition,
|
|
organization_id=user.selected_organization_id,
|
|
)
|
|
|
|
tool = await db_client.create_tool(
|
|
organization_id=user.selected_organization_id,
|
|
user_id=user.id,
|
|
name=request.name,
|
|
definition=definition,
|
|
category=request.category,
|
|
description=request.description,
|
|
icon=request.icon,
|
|
icon_color=request.icon_color,
|
|
)
|
|
|
|
capture_event(
|
|
distinct_id=str(user.provider_id),
|
|
event=PostHogEvent.TOOL_CREATED,
|
|
properties={
|
|
"tool_name": request.name,
|
|
"tool_category": request.category,
|
|
"source": source,
|
|
"organization_id": user.selected_organization_id,
|
|
},
|
|
)
|
|
|
|
return build_tool_response(tool)
|
|
|
|
|
|
async def refresh_mcp_tool_for_user(
|
|
tool_uuid: str,
|
|
user: UserModel,
|
|
) -> McpRefreshResponse:
|
|
"""Refresh cached MCP catalog for a tool owned by the user's org."""
|
|
if not user.selected_organization_id:
|
|
raise ToolManagementError(
|
|
"organization_required",
|
|
"No organization selected for the user",
|
|
status_code=400,
|
|
)
|
|
|
|
tool = await db_client.get_tool_by_uuid(
|
|
tool_uuid, user.selected_organization_id, include_archived=True
|
|
)
|
|
if not tool:
|
|
raise ToolManagementError("tool_not_found", "Tool not found", status_code=404)
|
|
if tool.category != ToolCategory.MCP.value:
|
|
raise ToolManagementError(
|
|
"not_mcp_tool", "Tool is not an MCP tool", status_code=400
|
|
)
|
|
|
|
try:
|
|
cfg = validate_mcp_definition(tool.definition)
|
|
except McpDefinitionError as e:
|
|
raise ToolManagementError(
|
|
"invalid_mcp_definition",
|
|
f"Invalid MCP definition: {e}",
|
|
status_code=400,
|
|
) from e
|
|
|
|
credential = await fetch_credential(
|
|
cfg.get("credential_uuid"), user.selected_organization_id
|
|
)
|
|
|
|
try:
|
|
discovered = await discover_mcp_tools(
|
|
url=cfg["url"],
|
|
credential=credential,
|
|
timeout_secs=cfg["timeout_secs"],
|
|
sse_read_timeout_secs=cfg["sse_read_timeout_secs"],
|
|
)
|
|
except Exception as e: # noqa: BLE001
|
|
logger.warning(f"MCP refresh discovery failed: {e}")
|
|
discovered = []
|
|
|
|
if not discovered:
|
|
error = (
|
|
f"Could not reach the MCP server at {cfg['url']} "
|
|
f"(or it exposes no tools). Previously cached list retained."
|
|
)
|
|
return McpRefreshResponse(tool_uuid=tool_uuid, discovered_tools=[], error=error)
|
|
|
|
new_def = dict(tool.definition or {})
|
|
new_def["config"] = {**new_def.get("config", {}), "discovered_tools": discovered}
|
|
await db_client.update_tool(
|
|
tool_uuid=tool_uuid,
|
|
organization_id=user.selected_organization_id,
|
|
definition=new_def,
|
|
)
|
|
return McpRefreshResponse(
|
|
tool_uuid=tool_uuid, discovered_tools=discovered, error=None
|
|
)
|