mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: user defined custom tools as part of workflow execution (#94)
* feat: add custom tools functionality * Show tools in nodes * integrate tool calling with pipeline engine
This commit is contained in:
parent
cc2d3e70d2
commit
3e55af9256
65 changed files with 5483 additions and 6673 deletions
276
api/db/tool_client.py
Normal file
276
api/db/tool_client.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""Database client for managing tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import ToolModel
|
||||
from api.enums import ToolCategory, ToolStatus
|
||||
|
||||
|
||||
class ToolClient(BaseDBClient):
|
||||
"""Client for managing tools (organization-scoped, UUID-referenced)."""
|
||||
|
||||
async def create_tool(
|
||||
self,
|
||||
organization_id: int,
|
||||
user_id: int,
|
||||
name: str,
|
||||
definition: dict,
|
||||
category: str = ToolCategory.HTTP_API.value,
|
||||
description: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_color: Optional[str] = None,
|
||||
) -> ToolModel:
|
||||
"""Create a new tool.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
user_id: ID of the user creating the tool
|
||||
name: Display name for the tool
|
||||
definition: JSON definition of the tool
|
||||
category: Tool category (http_api, native, integration)
|
||||
description: Optional description
|
||||
icon: Optional icon identifier
|
||||
icon_color: Optional hex color code
|
||||
|
||||
Returns:
|
||||
The created ToolModel with auto-generated UUID
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
tool = ToolModel(
|
||||
organization_id=organization_id,
|
||||
created_by=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
category=category,
|
||||
icon=icon,
|
||||
icon_color=icon_color,
|
||||
definition=definition,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
)
|
||||
|
||||
session.add(tool)
|
||||
await session.commit()
|
||||
await session.refresh(tool)
|
||||
|
||||
logger.info(
|
||||
f"Created tool '{name}' ({tool.tool_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return tool
|
||||
|
||||
async def get_tools_for_organization(
|
||||
self,
|
||||
organization_id: int,
|
||||
status: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
) -> List[ToolModel]:
|
||||
"""Get all tools for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
status: Optional filter by status (active, archived, draft)
|
||||
category: Optional filter by category (http_api, native, integration)
|
||||
|
||||
Returns:
|
||||
List of ToolModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(ToolModel).where(
|
||||
ToolModel.organization_id == organization_id
|
||||
)
|
||||
|
||||
if status:
|
||||
query = query.where(ToolModel.status == status)
|
||||
else:
|
||||
# By default, exclude archived tools
|
||||
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
|
||||
|
||||
if category:
|
||||
query = query.where(ToolModel.category == category)
|
||||
|
||||
query = query.order_by(ToolModel.name)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_tool_by_uuid(
|
||||
self,
|
||||
tool_uuid: str,
|
||||
organization_id: int,
|
||||
include_archived: bool = False,
|
||||
) -> Optional[ToolModel]:
|
||||
"""Get a tool by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
include_archived: If True, include archived tools
|
||||
|
||||
Returns:
|
||||
ToolModel if found and authorized, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
)
|
||||
.options(selectinload(ToolModel.created_by_user))
|
||||
)
|
||||
|
||||
if not include_archived:
|
||||
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_tool(
|
||||
self,
|
||||
tool_uuid: str,
|
||||
organization_id: int,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
definition: Optional[dict] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_color: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
) -> Optional[ToolModel]:
|
||||
"""Update a tool by UUID.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
name: New name (if provided)
|
||||
description: New description (if provided)
|
||||
definition: New definition (if provided)
|
||||
icon: New icon (if provided)
|
||||
icon_color: New icon color (if provided)
|
||||
status: New status (if provided)
|
||||
|
||||
Returns:
|
||||
Updated ToolModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# First check if tool exists and belongs to organization
|
||||
tool = await self.get_tool_by_uuid(
|
||||
tool_uuid, organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
return None
|
||||
|
||||
# Build update values
|
||||
update_values = {"updated_at": datetime.now(UTC)}
|
||||
if name is not None:
|
||||
update_values["name"] = name
|
||||
if description is not None:
|
||||
update_values["description"] = description
|
||||
if definition is not None:
|
||||
update_values["definition"] = definition
|
||||
if icon is not None:
|
||||
update_values["icon"] = icon
|
||||
if icon_color is not None:
|
||||
update_values["icon_color"] = icon_color
|
||||
if status is not None:
|
||||
update_values["status"] = status
|
||||
|
||||
await session.execute(
|
||||
update(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
)
|
||||
.values(**update_values)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Fetch updated tool
|
||||
result = await session.execute(
|
||||
select(ToolModel)
|
||||
.where(ToolModel.tool_uuid == tool_uuid)
|
||||
.options(selectinload(ToolModel.created_by_user))
|
||||
)
|
||||
updated_tool = result.scalar_one()
|
||||
|
||||
logger.info(f"Updated tool {tool_uuid} for organization {organization_id}")
|
||||
return updated_tool
|
||||
|
||||
async def archive_tool(self, tool_uuid: str, organization_id: int) -> bool:
|
||||
"""Soft delete a tool by setting its status to archived.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if tool was archived, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
update(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
ToolModel.status != ToolStatus.ARCHIVED.value,
|
||||
)
|
||||
.values(
|
||||
status=ToolStatus.ARCHIVED.value,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.info(
|
||||
f"Archived tool {tool_uuid} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def validate_tool_uuid(self, tool_uuid: str, organization_id: int) -> bool:
|
||||
"""Check if a tool UUID exists and belongs to the organization.
|
||||
|
||||
This is useful for workflow validation to ensure referenced tools exist.
|
||||
|
||||
Args:
|
||||
tool_uuid: The tool UUID to validate
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
tool = await self.get_tool_by_uuid(tool_uuid, organization_id)
|
||||
return tool is not None
|
||||
|
||||
async def get_tools_by_uuids(
|
||||
self,
|
||||
tool_uuids: List[str],
|
||||
organization_id: int,
|
||||
) -> List[ToolModel]:
|
||||
"""Get multiple tools by their UUIDs.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
List of ToolModel instances (only active tools)
|
||||
"""
|
||||
if not tool_uuids:
|
||||
return []
|
||||
|
||||
async with self.async_session() as session:
|
||||
query = select(ToolModel).where(
|
||||
ToolModel.tool_uuid.in_(tool_uuids),
|
||||
ToolModel.organization_id == organization_id,
|
||||
ToolModel.status == ToolStatus.ACTIVE.value,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
Loading…
Add table
Add a link
Reference in a new issue