mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: user defined custom tools as part of workflow execution (#94)
* feat: add custom tools functionality * Show tools in nodes * integrate tool calling with pipeline engine
This commit is contained in:
parent
cc2d3e70d2
commit
3e55af9256
65 changed files with 5483 additions and 6673 deletions
|
|
@ -57,6 +57,7 @@ class NodeDataDTO(BaseModel):
|
|||
detect_voicemail: bool = False
|
||||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
render_template,
|
||||
|
|
@ -105,6 +106,16 @@ class PipecatEngine:
|
|||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
return await self._custom_tool_manager.get_organization_id()
|
||||
# Fallback for when manager is not yet initialized
|
||||
return await get_organization_id_from_workflow_run(self._workflow_run_id)
|
||||
|
||||
@property
|
||||
def builtin_function_schemas(self) -> list[dict]:
|
||||
"""Get built-in function schemas (calculator and timezone tools)."""
|
||||
|
|
@ -146,6 +157,9 @@ class PipecatEngine:
|
|||
# Helper that encapsulates variable extraction logic
|
||||
self._variable_extraction_manager = VariableExtractionManager(self)
|
||||
|
||||
# Helper that encapsulates custom tool management
|
||||
self._custom_tool_manager = CustomToolManager(self)
|
||||
|
||||
# Add current time in EST (America/New_York) to gathered context
|
||||
try:
|
||||
est_time_result = get_current_time("America/New_York")
|
||||
|
|
@ -360,6 +374,10 @@ class PipecatEngine:
|
|||
outgoing_edge.get_function_name(), outgoing_edge.target
|
||||
)
|
||||
|
||||
# Register custom tool handlers for this node
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
|
|
@ -492,9 +510,7 @@ class PipecatEngine:
|
|||
# Apply disposition mapping - first try call_disposition if it is,
|
||||
# extracted from the call conversation then fall back to reason
|
||||
call_disposition = self._gathered_context.get("call_disposition", "")
|
||||
organization_id = await get_organization_id_from_workflow_run(
|
||||
self._workflow_run_id
|
||||
)
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
# If client is disconnected before we get a chance to disconnect from
|
||||
# the bot, lets consider that as final disposition
|
||||
|
|
@ -618,6 +634,13 @@ class PipecatEngine:
|
|||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Add custom tools from node.tool_uuids
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
# Transition functions (schema only; registration handled elsewhere)
|
||||
for outgoing_edge in node.out_edges:
|
||||
function_schema = self._get_function_schema(
|
||||
|
|
|
|||
189
api/services/workflow/pipecat_engine_custom_tools.py
Normal file
189
api/services/workflow/pipecat_engine_custom_tools.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""Custom tool management for PipecatEngine.
|
||||
|
||||
This module handles fetching, registering, and executing user-defined tools
|
||||
during workflow execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_utils import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import FunctionCallResultProperties
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
class CustomToolManager:
|
||||
"""Manager for custom tool registration and execution.
|
||||
|
||||
This class handles:
|
||||
1. Fetching tools from the database based on tool UUIDs
|
||||
2. Converting tools to LLM function schemas
|
||||
3. Registering tool execution handlers with the LLM
|
||||
4. Executing HTTP API tools when invoked by the LLM
|
||||
"""
|
||||
|
||||
def __init__(self, engine: "PipecatEngine") -> None:
|
||||
self._engine = engine
|
||||
self._organization_id: Optional[int] = None
|
||||
# Cache: maps function_name -> (tool, schema)
|
||||
self._tools_cache: dict[str, tuple[Any, dict]] = {}
|
||||
|
||||
async def get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._organization_id is None:
|
||||
self._organization_id = await get_organization_id_from_workflow_run(
|
||||
self._engine._workflow_run_id
|
||||
)
|
||||
return self._organization_id
|
||||
|
||||
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
|
||||
"""Fetch custom tools and convert them to function schemas.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
|
||||
Returns:
|
||||
List of FunctionSchema objects for LLM
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning("Cannot fetch custom tools: organization_id not available")
|
||||
return []
|
||||
|
||||
try:
|
||||
tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id)
|
||||
|
||||
schemas: list[FunctionSchema] = []
|
||||
for tool in tools:
|
||||
raw_schema = tool_to_function_schema(tool)
|
||||
function_name = raw_schema["function"]["name"]
|
||||
|
||||
# Cache the tool for later execution
|
||||
self._tools_cache[function_name] = (tool, raw_schema)
|
||||
|
||||
# Convert to FunctionSchema object for compatibility with update_llm_context
|
||||
func_schema = get_function_schema(
|
||||
function_name,
|
||||
raw_schema["function"]["description"],
|
||||
properties=raw_schema["function"]["parameters"].get(
|
||||
"properties", {}
|
||||
),
|
||||
required=raw_schema["function"]["parameters"].get("required", []),
|
||||
)
|
||||
schemas.append(func_schema)
|
||||
|
||||
logger.debug(
|
||||
f"Loaded {len(schemas)} custom tools for node: "
|
||||
f"{[s.name for s in schemas]}"
|
||||
)
|
||||
return schemas
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch custom tools: {e}")
|
||||
return []
|
||||
|
||||
async def register_handlers(self, tool_uuids: list[str]) -> None:
|
||||
"""Register custom tool execution handlers with the LLM.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to register handlers for
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning(
|
||||
"Cannot register custom tool handlers: organization_id not available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id)
|
||||
|
||||
for tool in tools:
|
||||
schema = tool_to_function_schema(tool)
|
||||
function_name = schema["function"]["name"]
|
||||
|
||||
# Cache the tool for potential later use
|
||||
self._tools_cache[function_name] = (tool, schema)
|
||||
|
||||
# Create and register the handler
|
||||
handler = self._create_handler(tool, function_name)
|
||||
self._engine.llm.register_function(function_name, handler)
|
||||
|
||||
logger.debug(
|
||||
f"Registered custom tool handler: {function_name} "
|
||||
f"(tool_uuid: {tool.tool_uuid})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register custom tool handlers: {e}")
|
||||
|
||||
def _create_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for a custom tool.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Async handler function for the tool
|
||||
"""
|
||||
# Run LLM after tool execution to continue conversation
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
|
||||
async def custom_tool_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: {function_name}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
try:
|
||||
# Execute the HTTP API tool
|
||||
result = await execute_http_tool(
|
||||
tool=tool,
|
||||
arguments=function_call_params.arguments,
|
||||
call_context_vars=self._engine._call_context_vars,
|
||||
organization_id=self._organization_id,
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom tool '{function_name}' execution failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)},
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
return custom_tool_handler
|
||||
|
||||
def get_cached_tool(self, function_name: str) -> Optional[tuple[Any, dict]]:
|
||||
"""Get a cached tool by its function name.
|
||||
|
||||
Args:
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Tuple of (tool, schema) if found, None otherwise
|
||||
"""
|
||||
return self._tools_cache.get(function_name)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the tools cache."""
|
||||
self._tools_cache.clear()
|
||||
180
api/services/workflow/tools/custom_tool.py
Normal file
180
api/services/workflow/tools/custom_tool.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""Custom tool execution for user-defined HTTP API tools."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
|
||||
# Map tool parameter types to JSON schema types
|
||||
TYPE_MAP = {
|
||||
"string": "string",
|
||||
"number": "number",
|
||||
"boolean": "boolean",
|
||||
}
|
||||
|
||||
|
||||
def tool_to_function_schema(tool: Any) -> Dict[str, Any]:
|
||||
"""Convert a ToolModel to an LLM function schema.
|
||||
|
||||
Args:
|
||||
tool: ToolModel instance with name, description, and definition
|
||||
|
||||
Returns:
|
||||
Function schema dict compatible with OpenAI/Anthropic function calling
|
||||
"""
|
||||
definition = tool.definition or {}
|
||||
config = definition.get("config", {})
|
||||
parameters = config.get("parameters", []) or []
|
||||
|
||||
# Build properties and required list from parameters
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type = param.get("type", "string")
|
||||
param_desc = param.get("description", "")
|
||||
param_required = param.get("required", True)
|
||||
|
||||
if not param_name:
|
||||
continue
|
||||
|
||||
properties[param_name] = {
|
||||
"type": TYPE_MAP.get(param_type, "string"),
|
||||
"description": param_desc,
|
||||
}
|
||||
|
||||
if param_required:
|
||||
required.append(param_name)
|
||||
|
||||
# Sanitize tool name for function name (lowercase, underscores only)
|
||||
function_name = re.sub(r"[^a-z0-9_]", "_", tool.name.lower())
|
||||
# Remove consecutive underscores and trim
|
||||
function_name = re.sub(r"_+", "_", function_name).strip("_")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"description": tool.description or f"Execute {tool.name} tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
"_tool_uuid": tool.tool_uuid,
|
||||
}
|
||||
|
||||
|
||||
async def execute_http_tool(
|
||||
tool: Any,
|
||||
arguments: Dict[str, Any],
|
||||
call_context_vars: Optional[Dict[str, Any]] = None,
|
||||
organization_id: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute an HTTP API tool.
|
||||
|
||||
Args:
|
||||
tool: ToolModel instance
|
||||
arguments: Arguments passed by the LLM (parameter name -> value)
|
||||
call_context_vars: Additional context variables from the call (unused for now)
|
||||
organization_id: Organization ID for credential lookup
|
||||
|
||||
Returns:
|
||||
Result dict with response data or error
|
||||
"""
|
||||
definition = tool.definition or {}
|
||||
config = definition.get("config", {})
|
||||
|
||||
# Get HTTP method and URL
|
||||
method = config.get("method", "POST").upper()
|
||||
url = config.get("url", "")
|
||||
|
||||
# Get headers from config
|
||||
headers = dict(config.get("headers", {}) or {})
|
||||
|
||||
# Add auth header if credential is configured
|
||||
credential_uuid = config.get("credential_uuid")
|
||||
if credential_uuid and organization_id:
|
||||
try:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
credential_uuid, organization_id
|
||||
)
|
||||
if credential:
|
||||
auth_header = build_auth_header(credential)
|
||||
headers.update(auth_header)
|
||||
logger.debug(f"Applied credential '{credential.name}' to tool request")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Credential {credential_uuid} not found for tool '{tool.name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch credential for tool '{tool.name}': {e}")
|
||||
|
||||
# Get timeout
|
||||
timeout_ms = config.get("timeout_ms", 5000)
|
||||
timeout_seconds = timeout_ms / 1000
|
||||
|
||||
# Build request: JSON body for POST/PUT/PATCH, query params for GET/DELETE
|
||||
body = None
|
||||
params = None
|
||||
if method in ("POST", "PUT", "PATCH"):
|
||||
body = arguments
|
||||
elif method in ("GET", "DELETE") and arguments:
|
||||
params = arguments
|
||||
|
||||
logger.info(
|
||||
f"Executing custom tool '{tool.name}' ({tool.tool_uuid}): {method} {url}"
|
||||
)
|
||||
logger.debug(f"Request body: {body}, params: {params}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout_seconds) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
json=body,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Try to parse JSON response
|
||||
try:
|
||||
response_data = response.json()
|
||||
except Exception:
|
||||
response_data = {"raw_response": response.text}
|
||||
|
||||
result = {
|
||||
"status": "success",
|
||||
"status_code": response.status_code,
|
||||
"data": response_data,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Custom tool '{tool.name}' completed with status {response.status_code}"
|
||||
)
|
||||
return result
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error(f"Custom tool '{tool.name}' timed out after {timeout_seconds}s")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Request timed out after {timeout_seconds} seconds",
|
||||
}
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Custom tool '{tool.name}' request failed: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Request failed: {str(e)}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Custom tool '{tool.name}' execution failed: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Tool execution failed: {str(e)}",
|
||||
}
|
||||
|
|
@ -47,6 +47,7 @@ class Node:
|
|||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
self.tool_uuids = data.tool_uuids
|
||||
|
||||
self.data = data
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue