feat: user defined custom tools as part of workflow execution (#94)

* feat: add custom tools functionality

* Show tools in nodes

* integrate tool calling with pipeline engine
This commit is contained in:
Abhishek 2026-01-02 13:11:02 +05:30 committed by GitHub
parent cc2d3e70d2
commit 3e55af9256
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 5483 additions and 6673 deletions

View file

@ -0,0 +1,92 @@
"""add tools model
Revision ID: ebc80cea7965
Revises: 36b5dbf670e4
Create Date: 2026-01-01 10:13:50.807135
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "ebc80cea7965"
down_revision: Union[str, None] = "36b5dbf670e4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum("active", "archived", "draft", name="tool_status").create(op.get_bind())
sa.Enum("http_api", "native", "integration", name="tool_category").create(
op.get_bind()
)
op.create_table(
"tools",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("tool_uuid", sa.String(length=36), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.Column(
"category",
postgresql.ENUM(
"http_api",
"native",
"integration",
name="tool_category",
create_type=False,
),
nullable=False,
),
sa.Column("icon", sa.String(length=50), nullable=True),
sa.Column("icon_color", sa.String(length=7), nullable=True),
sa.Column(
"status",
postgresql.ENUM(
"active", "archived", "draft", name="tool_status", create_type=False
),
server_default=sa.text("'active'::tool_status"),
nullable=False,
),
sa.Column("definition", sa.JSON(), nullable=False),
sa.Column("created_by", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["created_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("organization_id", "name", name="unique_org_tool_name"),
)
op.create_index("ix_tools_category", "tools", ["category"], unique=False)
op.create_index(
"ix_tools_organization_id", "tools", ["organization_id"], unique=False
)
op.create_index("ix_tools_status", "tools", ["status"], unique=False)
op.create_index(op.f("ix_tools_tool_uuid"), "tools", ["tool_uuid"], unique=True)
op.create_index("ix_tools_uuid", "tools", ["tool_uuid"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_tools_uuid", table_name="tools")
op.drop_index(op.f("ix_tools_tool_uuid"), table_name="tools")
op.drop_index("ix_tools_status", table_name="tools")
op.drop_index("ix_tools_organization_id", table_name="tools")
op.drop_index("ix_tools_category", table_name="tools")
op.drop_table("tools")
sa.Enum("http_api", "native", "integration", name="tool_category").drop(
op.get_bind()
)
sa.Enum("active", "archived", "draft", name="tool_status").drop(op.get_bind())
# ### end Alembic commands ###

View file

@ -1,143 +1,315 @@
"""
Shared pytest fixtures for the API tests.
This file contains database setup, test client configuration, and utility fixtures
that can be reused across all test files.
Pytest configuration and fixtures for async database testing.
This module sets up the test infrastructure using:
- A separate test database (appends _test to the database name)
- Alembic migrations run once per test session
- Transaction-based isolation for each test (savepoint pattern)
References:
- https://www.core27.co/post/transactional-unit-tests-with-pytest-and-async-sqlalchemy
- https://docs.sqlalchemy.org/en/20/orm/session_transaction.html
"""
import os
import subprocess
import uuid
from pathlib import Path
from typing import AsyncGenerator
from urllib.parse import urlparse, urlunparse
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
# Load environment variables before importing anything else
from dotenv import load_dotenv
from api.app import app
from api.db import db_client
# Load .env.test from api directory for test configuration
env_path = Path(__file__).parent / ".env.test"
load_dotenv(env_path)
# Test database setup globals
TEST_DATABASE_NAME = None
TEST_DATABASE_URL = None
import pytest
from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import SessionTransaction
from sqlalchemy.pool import NullPool
@pytest_asyncio.fixture
async def test_database():
def get_test_database_url() -> str:
"""
Set up a temporary PostgreSQL database for testing.
This fixture creates a unique test database, runs migrations, and cleans up afterward.
Get the test database URL by appending _test to the database name.
Example:
postgresql+asyncpg://user:pass@host/mydb
-> postgresql+asyncpg://user:pass@host/mydb_test
"""
global TEST_DATABASE_NAME, TEST_DATABASE_URL
original_url = os.environ.get("DATABASE_URL")
if not original_url:
raise ValueError("DATABASE_URL environment variable is not set")
# Generate a unique test database name
TEST_DATABASE_NAME = f"test_dograh_{uuid.uuid4().hex[:8]}"
parsed = urlparse(original_url)
# Append _test to the database name (path without leading slash)
original_db_name = parsed.path.lstrip("/")
test_db_name = f"{original_db_name}_test"
# Get the base DATABASE_URL and parse it
base_url = os.environ.get("DATABASE_URL")
# Extract connection parts and replace database name
url_parts = base_url.split("/")
base_connection = "/".join(url_parts[:-1])
TEST_DATABASE_URL = f"{base_connection}/{TEST_DATABASE_NAME}"
# Reconstruct the URL with the new database name
test_url = urlunparse(
(
parsed.scheme,
parsed.netloc,
f"/{test_db_name}",
parsed.params,
parsed.query,
parsed.fragment,
)
)
return test_url
# Create a connection to the default postgres database to create our test database
default_engine = create_async_engine(base_url)
def get_base_database_url() -> str:
"""
Get base database URL (postgres) for creating/dropping test database.
"""
original_url = os.environ.get("DATABASE_URL")
parsed = urlparse(original_url)
# Connect to 'postgres' database for admin operations
base_url = urlunparse(
(
parsed.scheme,
parsed.netloc,
"/postgres",
parsed.params,
parsed.query,
parsed.fragment,
)
)
return base_url
def get_test_db_name() -> str:
"""Extract the test database name."""
original_url = os.environ.get("DATABASE_URL")
parsed = urlparse(original_url)
original_db_name = parsed.path.lstrip("/")
return f"{original_db_name}_test"
@pytest.fixture(scope="session")
async def setup_test_database():
"""
Session-scoped fixture that creates the test database and runs migrations.
This runs once at the start of the test session.
"""
test_db_name = get_test_db_name()
base_url = get_base_database_url()
test_url = get_test_database_url()
# Create engine to connect to postgres database (for admin operations)
admin_engine = create_async_engine(
base_url,
poolclass=NullPool,
isolation_level="AUTOCOMMIT", # Required for CREATE DATABASE
)
# Create test database if it doesn't exist
async with admin_engine.connect() as conn:
# Check if database exists
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
{"dbname": test_db_name},
)
exists = result.scalar() is not None
if not exists:
print(f"\n Creating test database: {test_db_name}")
# Use template0 to avoid collation version mismatch issues
await conn.execute(
text(f'CREATE DATABASE "{test_db_name}" TEMPLATE template0')
)
else:
print(f"\n Using existing test database: {test_db_name}")
await admin_engine.dispose()
# Run alembic migrations on the test database
print(f" Running migrations on {test_db_name}...")
await run_migrations(test_url)
print(" Migrations complete!")
yield test_url
# Cleanup: Optionally drop the test database after tests
# Commented out to allow inspection of test data after failures
# async with admin_engine.connect() as conn:
# await conn.execute(text(f'DROP DATABASE IF EXISTS "{test_db_name}"'))
async def run_migrations(database_url: str):
"""
Run alembic migrations programmatically on the given database.
"""
from alembic import command
from alembic.config import Config
# Get alembic.ini path
alembic_ini_path = Path(__file__).parent / "alembic.ini"
# Create alembic config
alembic_cfg = Config(str(alembic_ini_path))
# Override the database URL - need to patch both os.environ AND api.constants
# because api.constants.DATABASE_URL is cached at import time
original_env_url = os.environ.get("DATABASE_URL")
os.environ["DATABASE_URL"] = database_url
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
# Also patch the cached value in api.constants
import api.constants
original_constants_url = api.constants.DATABASE_URL
api.constants.DATABASE_URL = database_url
# Run migrations in a thread to avoid blocking the event loop
import asyncio
def _run_upgrade():
command.upgrade(alembic_cfg, "head")
try:
# Create the test database
async with default_engine.connect() as conn:
# Use autocommit mode to create database
await conn.execute(text("COMMIT"))
await conn.execute(text(f"CREATE DATABASE {TEST_DATABASE_NAME}"))
await default_engine.dispose()
# Run migrations on the test database
env = os.environ.copy()
env["DATABASE_URL"] = TEST_DATABASE_URL
# Add the parent directory to PYTHONPATH so alembic can find the api module
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
env["PYTHONPATH"] = parent_dir + ":" + env.get("PYTHONPATH", "")
# Run alembic upgrade to create all tables
result = subprocess.run(
[
"conda",
"run",
"-n",
"dograh",
"python",
"-m",
"alembic",
"-c",
"alembic.ini",
"upgrade",
"head",
],
env=env,
capture_output=True,
text=True,
)
if result.returncode != 0:
logger.error(f"Alembic stderr: {result.stderr}")
logger.error(f"Alembic stdout: {result.stdout}")
raise RuntimeError(f"Alembic migration failed: {result.stderr}")
logger.info(f"Created test database: {TEST_DATABASE_NAME}")
yield TEST_DATABASE_URL
await asyncio.get_event_loop().run_in_executor(None, _run_upgrade)
finally:
# Cleanup: Drop the test database
cleanup_engine = create_async_engine(base_url)
try:
async with cleanup_engine.connect() as conn:
# Terminate any connections to the test database
await conn.execute(text("COMMIT"))
await conn.execute(
text(f"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = '{TEST_DATABASE_NAME}' AND pid <> pg_backend_pid()
""")
)
await conn.execute(
text(f"DROP DATABASE IF EXISTS {TEST_DATABASE_NAME}")
)
logger.info(f"Cleaned up test database: {TEST_DATABASE_NAME}")
except Exception as e:
logger.error(
f"Warning: Could not clean up test database {TEST_DATABASE_NAME}: {e}"
)
finally:
await cleanup_engine.dispose()
# Restore original DATABASE_URL
if original_env_url:
os.environ["DATABASE_URL"] = original_env_url
api.constants.DATABASE_URL = original_constants_url
@pytest_asyncio.fixture
async def db_session(test_database):
@pytest.fixture(scope="session")
async def test_engine(setup_test_database):
"""
Create a test database client that uses the temporary database.
This fixture replaces the global db_client with a test version.
Create a test database engine (session-scoped).
Uses NullPool to avoid connection issues with async tests.
"""
test_url = setup_test_database
engine = create_async_engine(
test_url,
poolclass=NullPool,
echo=False, # Set to True for SQL debugging
)
yield engine
await engine.dispose()
@pytest.fixture(scope="function")
async def db_connection(test_engine) -> AsyncGenerator[AsyncConnection, None]:
"""
Create a database connection for each test.
This connection wraps all operations in a transaction that
will be rolled back at the end of the test.
"""
async with test_engine.connect() as connection:
yield connection
@pytest.fixture(scope="function")
async def async_session(
db_connection: AsyncConnection,
) -> AsyncGenerator[AsyncSession, None]:
"""
Create a database session with transaction isolation for each test.
This fixture:
1. Begins a transaction on the connection
2. Creates a savepoint (nested transaction)
3. Yields the session for test use
4. Rolls back all changes after the test
Tests can call session.commit() and it will only commit to the savepoint,
not to the actual database. The outer transaction rollback ensures
complete isolation between tests.
"""
# Begin outer transaction
trans = await db_connection.begin()
# Create session bound to this connection
async_session_maker = async_sessionmaker(
bind=db_connection,
expire_on_commit=False,
autoflush=False,
)
async with async_session_maker() as session:
# Begin a nested transaction (savepoint)
nested = await session.begin_nested()
# Set up event listener to restart savepoint after commits
@event.listens_for(session.sync_session, "after_transaction_end")
def reopen_nested_transaction(session_sync, transaction: SessionTransaction):
nonlocal nested
if not nested.is_active:
nested = session.sync_session.begin_nested()
yield session
# Rollback everything
await trans.rollback()
class _TestSessionContext:
"""
Context manager wrapper for test session.
Mimics the behavior of async_sessionmaker() context manager
but uses the existing test session without closing it.
"""
def __init__(self, session: AsyncSession):
self._session = session
async def __aenter__(self) -> AsyncSession:
return self._session
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
await self._session.flush()
return False
@pytest.fixture(scope="function")
async def db_session(async_session: AsyncSession):
"""
Create a DBClient instance that uses the test session.
This patches the DBClient's async_session to use our test session,
ensuring all database operations go through the transactional test session.
Note: This fixture yields a DBClient (not a raw session) for backward
compatibility with existing tests that call db_session.get_or_create_user_by_provider_id(), etc.
"""
from api.db import db_client
def test_session_maker():
return _TestSessionContext(async_session)
# Store originals
original_engine = db_client.engine
original_session = db_client.async_session
original_async_session = db_client.async_session
# Replace the database client's engine and session with test ones
test_engine = create_async_engine(test_database)
test_session_maker = async_sessionmaker(bind=test_engine)
db_client.engine = test_engine
# Patch the db_client to use our test session
db_client.async_session = test_session_maker
yield db_client
# Restore original database client
await test_engine.dispose()
# Restore originals
db_client.engine = original_engine
db_client.async_session = original_session
db_client.async_session = original_async_session
@pytest_asyncio.fixture
@pytest.fixture
async def test_client_factory(db_session):
"""
Factory fixture that creates test clients for specific users.
@ -155,6 +327,9 @@ async def test_client_factory(db_session):
"""
from contextlib import asynccontextmanager
from httpx import ASGITransport, AsyncClient
from api.app import app
from api.services.auth.depends import get_user
@asynccontextmanager

View file

@ -8,6 +8,7 @@ from api.db.organization_client import OrganizationClient
from api.db.organization_configuration_client import OrganizationConfigurationClient
from api.db.organization_usage_client import OrganizationUsageClient
from api.db.reports_client import ReportsClient
from api.db.tool_client import ToolClient
from api.db.user_client import UserClient
from api.db.webhook_credential_client import WebhookCredentialClient
from api.db.workflow_client import WorkflowClient
@ -31,6 +32,7 @@ class DBClient(
EmbedTokenClient,
AgentTriggerClient,
WebhookCredentialClient,
ToolClient,
):
"""
Unified database client that combines all specialized database operations.
@ -51,6 +53,7 @@ class DBClient(
- EmbedTokenClient: handles embed token and session operations
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
- WebhookCredentialClient: handles webhook credential operations
- ToolClient: handles tool operations for reusable HTTP API tools
"""
pass

View file

@ -22,6 +22,8 @@ from sqlalchemy.orm import declarative_base, relationship
from ..enums import (
IntegrationAction,
ToolCategory,
ToolStatus,
TriggerState,
WebhookCredentialType,
WorkflowRunMode,
@ -800,3 +802,85 @@ class ExternalCredentialModel(Base):
Index("ix_webhook_credentials_uuid", "credential_uuid"),
UniqueConstraint("organization_id", "name", name="unique_org_credential_name"),
)
class ToolModel(Base):
"""Model for storing reusable tools that can be invoked during workflows.
Tools provide a standardized way to integrate external functionality - from
HTTP API calls to native integrations.
"""
__tablename__ = "tools"
id = Column(Integer, primary_key=True, index=True)
# Public identifier (used in APIs and workflow references)
tool_uuid = Column(
String(36),
unique=True,
nullable=False,
index=True,
default=lambda: str(uuid.uuid4()),
)
# Organization scoping
organization_id = Column(
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
)
# Tool metadata
name = Column(String(255), nullable=False)
description = Column(String, nullable=True)
# Tool category - uses enum from api/enums.py
category = Column(
Enum(
*[c.value for c in ToolCategory],
name="tool_category",
),
nullable=False,
default=ToolCategory.HTTP_API.value,
)
# Icon configuration (for UI display)
icon = Column(String(50), nullable=True) # Icon identifier
icon_color = Column(String(7), nullable=True) # Hex color code
# Status management
status = Column(
Enum(
*[s.value for s in ToolStatus],
name="tool_status",
),
nullable=False,
default=ToolStatus.ACTIVE.value,
server_default=text("'active'::tool_status"),
)
# The tool definition (JSONB) - contains schema_version for compatibility
# Structure depends on category:
# - http_api: {"schema_version": 1, "type": "http_api", "config": {...}}
definition = Column(JSON, nullable=False, default=dict)
# Audit fields
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships
organization = relationship("OrganizationModel")
created_by_user = relationship("UserModel")
# Indexes and constraints
__table_args__ = (
Index("ix_tools_organization_id", "organization_id"),
Index("ix_tools_uuid", "tool_uuid"),
Index("ix_tools_status", "status"),
Index("ix_tools_category", "category"),
UniqueConstraint("organization_id", "name", name="unique_org_tool_name"),
)

276
api/db/tool_client.py Normal file
View file

@ -0,0 +1,276 @@
"""Database client for managing tools."""
from datetime import UTC, datetime
from typing import List, Optional
from loguru import logger
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from api.db.base_client import BaseDBClient
from api.db.models import ToolModel
from api.enums import ToolCategory, ToolStatus
class ToolClient(BaseDBClient):
"""Client for managing tools (organization-scoped, UUID-referenced)."""
async def create_tool(
self,
organization_id: int,
user_id: int,
name: str,
definition: dict,
category: str = ToolCategory.HTTP_API.value,
description: Optional[str] = None,
icon: Optional[str] = None,
icon_color: Optional[str] = None,
) -> ToolModel:
"""Create a new tool.
Args:
organization_id: ID of the organization
user_id: ID of the user creating the tool
name: Display name for the tool
definition: JSON definition of the tool
category: Tool category (http_api, native, integration)
description: Optional description
icon: Optional icon identifier
icon_color: Optional hex color code
Returns:
The created ToolModel with auto-generated UUID
"""
async with self.async_session() as session:
tool = ToolModel(
organization_id=organization_id,
created_by=user_id,
name=name,
description=description,
category=category,
icon=icon,
icon_color=icon_color,
definition=definition,
status=ToolStatus.ACTIVE.value,
)
session.add(tool)
await session.commit()
await session.refresh(tool)
logger.info(
f"Created tool '{name}' ({tool.tool_uuid}) "
f"for organization {organization_id}"
)
return tool
async def get_tools_for_organization(
self,
organization_id: int,
status: Optional[str] = None,
category: Optional[str] = None,
) -> List[ToolModel]:
"""Get all tools for an organization.
Args:
organization_id: ID of the organization
status: Optional filter by status (active, archived, draft)
category: Optional filter by category (http_api, native, integration)
Returns:
List of ToolModel instances
"""
async with self.async_session() as session:
query = select(ToolModel).where(
ToolModel.organization_id == organization_id
)
if status:
query = query.where(ToolModel.status == status)
else:
# By default, exclude archived tools
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
if category:
query = query.where(ToolModel.category == category)
query = query.order_by(ToolModel.name)
result = await session.execute(query)
return list(result.scalars().all())
async def get_tool_by_uuid(
self,
tool_uuid: str,
organization_id: int,
include_archived: bool = False,
) -> Optional[ToolModel]:
"""Get a tool by its UUID, scoped to organization.
Args:
tool_uuid: The unique tool UUID
organization_id: ID of the organization (for authorization)
include_archived: If True, include archived tools
Returns:
ToolModel if found and authorized, None otherwise
"""
async with self.async_session() as session:
query = (
select(ToolModel)
.where(
ToolModel.tool_uuid == tool_uuid,
ToolModel.organization_id == organization_id,
)
.options(selectinload(ToolModel.created_by_user))
)
if not include_archived:
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
result = await session.execute(query)
return result.scalar_one_or_none()
async def update_tool(
self,
tool_uuid: str,
organization_id: int,
name: Optional[str] = None,
description: Optional[str] = None,
definition: Optional[dict] = None,
icon: Optional[str] = None,
icon_color: Optional[str] = None,
status: Optional[str] = None,
) -> Optional[ToolModel]:
"""Update a tool by UUID.
Args:
tool_uuid: The unique tool UUID
organization_id: ID of the organization (for authorization)
name: New name (if provided)
description: New description (if provided)
definition: New definition (if provided)
icon: New icon (if provided)
icon_color: New icon color (if provided)
status: New status (if provided)
Returns:
Updated ToolModel if found, None otherwise
"""
async with self.async_session() as session:
# First check if tool exists and belongs to organization
tool = await self.get_tool_by_uuid(
tool_uuid, organization_id, include_archived=True
)
if not tool:
return None
# Build update values
update_values = {"updated_at": datetime.now(UTC)}
if name is not None:
update_values["name"] = name
if description is not None:
update_values["description"] = description
if definition is not None:
update_values["definition"] = definition
if icon is not None:
update_values["icon"] = icon
if icon_color is not None:
update_values["icon_color"] = icon_color
if status is not None:
update_values["status"] = status
await session.execute(
update(ToolModel)
.where(
ToolModel.tool_uuid == tool_uuid,
ToolModel.organization_id == organization_id,
)
.values(**update_values)
)
await session.commit()
# Fetch updated tool
result = await session.execute(
select(ToolModel)
.where(ToolModel.tool_uuid == tool_uuid)
.options(selectinload(ToolModel.created_by_user))
)
updated_tool = result.scalar_one()
logger.info(f"Updated tool {tool_uuid} for organization {organization_id}")
return updated_tool
async def archive_tool(self, tool_uuid: str, organization_id: int) -> bool:
"""Soft delete a tool by setting its status to archived.
Args:
tool_uuid: The unique tool UUID
organization_id: ID of the organization (for authorization)
Returns:
True if tool was archived, False if not found
"""
async with self.async_session() as session:
result = await session.execute(
update(ToolModel)
.where(
ToolModel.tool_uuid == tool_uuid,
ToolModel.organization_id == organization_id,
ToolModel.status != ToolStatus.ARCHIVED.value,
)
.values(
status=ToolStatus.ARCHIVED.value,
updated_at=datetime.now(UTC),
)
)
await session.commit()
if result.rowcount > 0:
logger.info(
f"Archived tool {tool_uuid} for organization {organization_id}"
)
return True
return False
async def validate_tool_uuid(self, tool_uuid: str, organization_id: int) -> bool:
"""Check if a tool UUID exists and belongs to the organization.
This is useful for workflow validation to ensure referenced tools exist.
Args:
tool_uuid: The tool UUID to validate
organization_id: ID of the organization
Returns:
True if valid, False otherwise
"""
tool = await self.get_tool_by_uuid(tool_uuid, organization_id)
return tool is not None
async def get_tools_by_uuids(
self,
tool_uuids: List[str],
organization_id: int,
) -> List[ToolModel]:
"""Get multiple tools by their UUIDs.
Args:
tool_uuids: List of tool UUIDs to fetch
organization_id: ID of the organization (for authorization)
Returns:
List of ToolModel instances (only active tools)
"""
if not tool_uuids:
return []
async with self.async_session() as session:
query = select(ToolModel).where(
ToolModel.tool_uuid.in_(tool_uuids),
ToolModel.organization_id == organization_id,
ToolModel.status == ToolStatus.ACTIVE.value,
)
result = await session.execute(query)
return list(result.scalars().all())

View file

@ -109,3 +109,21 @@ class WebhookCredentialType(Enum):
BEARER_TOKEN = "bearer_token" # Bearer token auth
BASIC_AUTH = "basic_auth" # Username/password
CUSTOM_HEADER = "custom_header" # Custom header key-value
class ToolCategory(Enum):
"""Tool category types"""
HTTP_API = "http_api" # Custom HTTP API calls (implemented)
NATIVE = (
"native" # Built-in integrations (future: call_transfer, dtmf_input, end_call)
)
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
class ToolStatus(Enum):
"""Tool status values"""
ACTIVE = "active" # Tool is available for use
ARCHIVED = "archived" # Tool is soft-deleted
DRAFT = "draft" # Tool is being configured (not ready for use)

View file

@ -1,7 +1,8 @@
[pytest]
asyncio_mode = strict
asyncio_default_fixture_loop_scope = function
testpaths = .
asyncio_mode = auto
asyncio_default_fixture_loop_scope = session
asyncio_default_test_loop_scope = session
testpaths = tests
python_files = test_*.py *_test.py
python_classes = Test*
python_functions = test_*

View file

@ -3,4 +3,5 @@ ruff==0.11.3
pytest==8.3.5
pytest-asyncio==0.26.0
pre-commit==4.2.0
watchfiles==1.1.0
watchfiles==1.1.0
python-dotenv==1.2.1

View file

@ -15,6 +15,7 @@ from api.routes.s3_signed_url import router as s3_router
from api.routes.service_keys import router as service_keys_router
from api.routes.superuser import router as superuser_router
from api.routes.telephony import router as telephony_router
from api.routes.tool import router as tool_router
from api.routes.user import router as user_router
from api.routes.webrtc_signaling import router as webrtc_signaling_router
from api.routes.workflow import router as workflow_router
@ -32,6 +33,7 @@ router.include_router(workflow_router)
router.include_router(user_router)
router.include_router(campaign_router)
router.include_router(credentials_router)
router.include_router(tool_router)
router.include_router(integration_router)
router.include_router(organization_router)
router.include_router(s3_router)

336
api/routes/tool.py Normal file
View file

@ -0,0 +1,336 @@
"""API routes for managing tools."""
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel
from api.enums import ToolCategory, ToolStatus
from api.services.auth.depends import get_user
router = APIRouter(prefix="/tools")
# Request/Response schemas
class ToolParameter(BaseModel):
"""A parameter that the tool accepts."""
name: str = Field(description="Parameter name (used as key in request body)")
type: str = Field(description="Parameter type: string, number, or boolean")
description: str = Field(description="Description of what this parameter is for")
required: bool = Field(
default=True, description="Whether this parameter is required"
)
class HttpApiConfig(BaseModel):
"""Configuration for HTTP API tools."""
method: str = Field(description="HTTP method (GET, POST, PUT, PATCH, DELETE)")
url: str = Field(description="Target URL")
headers: Optional[Dict[str, str]] = Field(
default=None, description="Static headers to include"
)
credential_uuid: Optional[str] = Field(
default=None, description="Reference to ExternalCredentialModel for auth"
)
parameters: Optional[List[ToolParameter]] = Field(
default=None, description="Parameters that the tool accepts from LLM"
)
timeout_ms: Optional[int] = Field(
default=5000, description="Request timeout in milliseconds"
)
class ToolDefinition(BaseModel):
"""Tool definition schema."""
schema_version: int = Field(
default=1, description="Schema version for compatibility"
)
type: str = Field(description="Tool type (http_api)")
config: HttpApiConfig = Field(description="Tool configuration")
class CreateToolRequest(BaseModel):
"""Request schema for creating a tool."""
name: str = Field(max_length=255)
description: Optional[str] = None
category: str = Field(default=ToolCategory.HTTP_API.value)
icon: Optional[str] = Field(default="globe", max_length=50)
icon_color: Optional[str] = Field(default="#3B82F6", max_length=7)
definition: ToolDefinition
class UpdateToolRequest(BaseModel):
"""Request schema for updating a tool."""
name: Optional[str] = Field(default=None, max_length=255)
description: Optional[str] = None
icon: Optional[str] = Field(default=None, max_length=50)
icon_color: Optional[str] = Field(default=None, max_length=7)
definition: Optional[ToolDefinition] = None
status: Optional[str] = None
class CreatedByResponse(BaseModel):
"""Response schema for the user who created a tool."""
id: int
provider_id: str
class ToolResponse(BaseModel):
"""Response schema for a tool."""
id: int
tool_uuid: str
name: str
description: Optional[str]
category: str
icon: Optional[str]
icon_color: Optional[str]
status: str
definition: Dict[str, Any]
created_at: datetime
updated_at: Optional[datetime]
created_by: Optional[CreatedByResponse] = None
class Config:
from_attributes = True
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
"""Build a response from a tool model."""
created_by = None
if include_created_by and tool.created_by_user:
created_by = CreatedByResponse(
id=tool.created_by_user.id,
provider_id=tool.created_by_user.provider_id,
)
return ToolResponse(
id=tool.id,
tool_uuid=tool.tool_uuid,
name=tool.name,
description=tool.description,
category=tool.category,
icon=tool.icon,
icon_color=tool.icon_color,
status=tool.status,
definition=tool.definition,
created_at=tool.created_at,
updated_at=tool.updated_at,
created_by=created_by,
)
def validate_category(category: str) -> None:
"""Validate that the category is valid."""
valid_categories = [c.value for c in ToolCategory]
if category not in valid_categories:
raise HTTPException(
status_code=400,
detail=f"Invalid category '{category}'. Must be one of: {', '.join(valid_categories)}",
)
def validate_status(status: str) -> None:
"""Validate that the status is valid."""
valid_statuses = [s.value for s in ToolStatus]
if status not in valid_statuses:
raise HTTPException(
status_code=400,
detail=f"Invalid status '{status}'. Must be one of: {', '.join(valid_statuses)}",
)
@router.get("/")
async def list_tools(
status: Optional[str] = None,
category: Optional[str] = None,
user: UserModel = Depends(get_user),
) -> List[ToolResponse]:
"""
List all tools for the user's organization.
Args:
status: Optional filter by status (active, archived, draft)
category: Optional filter by category (http_api, native, integration)
Returns:
List of tools
"""
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected for the user"
)
if status:
validate_status(status)
if category:
validate_category(category)
tools = await db_client.get_tools_for_organization(
user.selected_organization_id,
status=status,
category=category,
)
return [build_tool_response(tool) for tool in tools]
@router.post("/")
async def create_tool(
request: CreateToolRequest,
user: UserModel = Depends(get_user),
) -> ToolResponse:
"""
Create a new tool.
Args:
request: The tool creation request
Returns:
The created tool
"""
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected for the user"
)
validate_category(request.category)
try:
tool = await db_client.create_tool(
organization_id=user.selected_organization_id,
user_id=user.id,
name=request.name,
definition=request.definition.model_dump(),
category=request.category,
description=request.description,
icon=request.icon,
icon_color=request.icon_color,
)
return build_tool_response(tool)
except Exception as e:
if "unique_org_tool_name" in str(e):
raise HTTPException(
status_code=409,
detail=f"A tool with the name '{request.name}' already exists",
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_uuid}")
async def get_tool(
tool_uuid: str,
user: UserModel = Depends(get_user),
) -> ToolResponse:
"""
Get a specific tool by UUID.
Args:
tool_uuid: The UUID of the tool
Returns:
The tool
"""
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected for the user"
)
tool = await db_client.get_tool_by_uuid(
tool_uuid, user.selected_organization_id, include_archived=True
)
if not tool:
raise HTTPException(status_code=404, detail="Tool not found")
return build_tool_response(tool, include_created_by=True)
@router.put("/{tool_uuid}")
async def update_tool(
tool_uuid: str,
request: UpdateToolRequest,
user: UserModel = Depends(get_user),
) -> ToolResponse:
"""
Update a tool.
Args:
tool_uuid: The UUID of the tool to update
request: The update request
Returns:
The updated tool
"""
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected for the user"
)
if request.status:
validate_status(request.status)
try:
tool = await db_client.update_tool(
tool_uuid=tool_uuid,
organization_id=user.selected_organization_id,
name=request.name,
description=request.description,
definition=request.definition.model_dump() if request.definition else None,
icon=request.icon,
icon_color=request.icon_color,
status=request.status,
)
if not tool:
raise HTTPException(status_code=404, detail="Tool not found")
return build_tool_response(tool, include_created_by=True)
except HTTPException:
raise
except Exception as e:
if "unique_org_tool_name" in str(e):
raise HTTPException(
status_code=409,
detail=f"A tool with the name '{request.name}' already exists",
)
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{tool_uuid}")
async def delete_tool(
tool_uuid: str,
user: UserModel = Depends(get_user),
) -> dict:
"""
Archive (soft delete) a tool.
Args:
tool_uuid: The UUID of the tool to delete
Returns:
Success message
"""
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected for the user"
)
deleted = await db_client.archive_tool(tool_uuid, user.selected_organization_id)
if not deleted:
raise HTTPException(status_code=404, detail="Tool not found")
return {"status": "archived", "tool_uuid": tool_uuid}

View file

@ -57,6 +57,7 @@ class NodeDataDTO(BaseModel):
detect_voicemail: bool = False
delayed_start: bool = False
delayed_start_duration: Optional[float] = None
tool_uuids: Optional[List[str]] = None
trigger_path: Optional[str] = None
# Webhook node specific fields
enabled: bool = True

View file

@ -38,6 +38,7 @@ import asyncio
from loguru import logger
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
render_template,
@ -105,6 +106,16 @@ class PipecatEngine:
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
async def _get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._custom_tool_manager:
return await self._custom_tool_manager.get_organization_id()
# Fallback for when manager is not yet initialized
return await get_organization_id_from_workflow_run(self._workflow_run_id)
@property
def builtin_function_schemas(self) -> list[dict]:
"""Get built-in function schemas (calculator and timezone tools)."""
@ -146,6 +157,9 @@ class PipecatEngine:
# Helper that encapsulates variable extraction logic
self._variable_extraction_manager = VariableExtractionManager(self)
# Helper that encapsulates custom tool management
self._custom_tool_manager = CustomToolManager(self)
# Add current time in EST (America/New_York) to gathered context
try:
est_time_result = get_current_time("America/New_York")
@ -360,6 +374,10 @@ class PipecatEngine:
outgoing_edge.get_function_name(), outgoing_edge.target
)
# Register custom tool handlers for this node
if node.tool_uuids and self._custom_tool_manager:
await self._custom_tool_manager.register_handlers(node.tool_uuids)
# Set up system message and functions
(
system_message,
@ -492,9 +510,7 @@ class PipecatEngine:
# Apply disposition mapping - first try call_disposition if it is,
# extracted from the call conversation then fall back to reason
call_disposition = self._gathered_context.get("call_disposition", "")
organization_id = await get_organization_id_from_workflow_run(
self._workflow_run_id
)
organization_id = await self._get_organization_id()
# If client is disconnected before we get a chance to disconnect from
# the bot, lets consider that as final disposition
@ -618,6 +634,13 @@ class PipecatEngine:
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
node.tool_uuids
)
functions.extend(custom_tool_schemas)
# Transition functions (schema only; registration handled elsewhere)
for outgoing_edge in node.out_edges:
function_schema = self._get_function_schema(

View file

@ -0,0 +1,189 @@
"""Custom tool management for PipecatEngine.
This module handles fetching, registering, and executing user-defined tools
during workflow execution.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from loguru import logger
from api.db import db_client
from api.services.workflow.disposition_mapper import (
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine_utils import get_function_schema
from api.services.workflow.tools.custom_tool import (
execute_http_tool,
tool_to_function_schema,
)
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.frames.frames import FunctionCallResultProperties
from pipecat.services.llm_service import FunctionCallParams
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
class CustomToolManager:
"""Manager for custom tool registration and execution.
This class handles:
1. Fetching tools from the database based on tool UUIDs
2. Converting tools to LLM function schemas
3. Registering tool execution handlers with the LLM
4. Executing HTTP API tools when invoked by the LLM
"""
def __init__(self, engine: "PipecatEngine") -> None:
self._engine = engine
self._organization_id: Optional[int] = None
# Cache: maps function_name -> (tool, schema)
self._tools_cache: dict[str, tuple[Any, dict]] = {}
async def get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._organization_id is None:
self._organization_id = await get_organization_id_from_workflow_run(
self._engine._workflow_run_id
)
return self._organization_id
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
"""Fetch custom tools and convert them to function schemas.
Args:
tool_uuids: List of tool UUIDs to fetch
Returns:
List of FunctionSchema objects for LLM
"""
organization_id = await self.get_organization_id()
if not organization_id:
logger.warning("Cannot fetch custom tools: organization_id not available")
return []
try:
tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id)
schemas: list[FunctionSchema] = []
for tool in tools:
raw_schema = tool_to_function_schema(tool)
function_name = raw_schema["function"]["name"]
# Cache the tool for later execution
self._tools_cache[function_name] = (tool, raw_schema)
# Convert to FunctionSchema object for compatibility with update_llm_context
func_schema = get_function_schema(
function_name,
raw_schema["function"]["description"],
properties=raw_schema["function"]["parameters"].get(
"properties", {}
),
required=raw_schema["function"]["parameters"].get("required", []),
)
schemas.append(func_schema)
logger.debug(
f"Loaded {len(schemas)} custom tools for node: "
f"{[s.name for s in schemas]}"
)
return schemas
except Exception as e:
logger.error(f"Failed to fetch custom tools: {e}")
return []
async def register_handlers(self, tool_uuids: list[str]) -> None:
"""Register custom tool execution handlers with the LLM.
Args:
tool_uuids: List of tool UUIDs to register handlers for
"""
organization_id = await self.get_organization_id()
if not organization_id:
logger.warning(
"Cannot register custom tool handlers: organization_id not available"
)
return
try:
tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id)
for tool in tools:
schema = tool_to_function_schema(tool)
function_name = schema["function"]["name"]
# Cache the tool for potential later use
self._tools_cache[function_name] = (tool, schema)
# Create and register the handler
handler = self._create_handler(tool, function_name)
self._engine.llm.register_function(function_name, handler)
logger.debug(
f"Registered custom tool handler: {function_name} "
f"(tool_uuid: {tool.tool_uuid})"
)
except Exception as e:
logger.error(f"Failed to register custom tool handlers: {e}")
def _create_handler(self, tool: Any, function_name: str):
"""Create a handler function for a custom tool.
Args:
tool: The ToolModel instance
function_name: The function name used by the LLM
Returns:
Async handler function for the tool
"""
# Run LLM after tool execution to continue conversation
properties = FunctionCallResultProperties(run_llm=True)
async def custom_tool_handler(
function_call_params: FunctionCallParams,
) -> None:
logger.info(f"LLM Function Call EXECUTED: {function_name}")
logger.info(f"Arguments: {function_call_params.arguments}")
try:
# Execute the HTTP API tool
result = await execute_http_tool(
tool=tool,
arguments=function_call_params.arguments,
call_context_vars=self._engine._call_context_vars,
organization_id=self._organization_id,
)
await function_call_params.result_callback(
result, properties=properties
)
except Exception as e:
logger.error(f"Custom tool '{function_name}' execution failed: {e}")
await function_call_params.result_callback(
{"status": "error", "error": str(e)},
properties=properties,
)
return custom_tool_handler
def get_cached_tool(self, function_name: str) -> Optional[tuple[Any, dict]]:
"""Get a cached tool by its function name.
Args:
function_name: The function name used by the LLM
Returns:
Tuple of (tool, schema) if found, None otherwise
"""
return self._tools_cache.get(function_name)
def clear_cache(self) -> None:
"""Clear the tools cache."""
self._tools_cache.clear()

View file

@ -0,0 +1,180 @@
"""Custom tool execution for user-defined HTTP API tools."""
import re
from typing import Any, Dict, Optional
import httpx
from loguru import logger
from api.db import db_client
from api.utils.credential_auth import build_auth_header
# Map tool parameter types to JSON schema types
TYPE_MAP = {
"string": "string",
"number": "number",
"boolean": "boolean",
}
def tool_to_function_schema(tool: Any) -> Dict[str, Any]:
"""Convert a ToolModel to an LLM function schema.
Args:
tool: ToolModel instance with name, description, and definition
Returns:
Function schema dict compatible with OpenAI/Anthropic function calling
"""
definition = tool.definition or {}
config = definition.get("config", {})
parameters = config.get("parameters", []) or []
# Build properties and required list from parameters
properties = {}
required = []
for param in parameters:
param_name = param.get("name", "")
param_type = param.get("type", "string")
param_desc = param.get("description", "")
param_required = param.get("required", True)
if not param_name:
continue
properties[param_name] = {
"type": TYPE_MAP.get(param_type, "string"),
"description": param_desc,
}
if param_required:
required.append(param_name)
# Sanitize tool name for function name (lowercase, underscores only)
function_name = re.sub(r"[^a-z0-9_]", "_", tool.name.lower())
# Remove consecutive underscores and trim
function_name = re.sub(r"_+", "_", function_name).strip("_")
return {
"type": "function",
"function": {
"name": function_name,
"description": tool.description or f"Execute {tool.name} tool",
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
"_tool_uuid": tool.tool_uuid,
}
async def execute_http_tool(
tool: Any,
arguments: Dict[str, Any],
call_context_vars: Optional[Dict[str, Any]] = None,
organization_id: Optional[int] = None,
) -> Dict[str, Any]:
"""Execute an HTTP API tool.
Args:
tool: ToolModel instance
arguments: Arguments passed by the LLM (parameter name -> value)
call_context_vars: Additional context variables from the call (unused for now)
organization_id: Organization ID for credential lookup
Returns:
Result dict with response data or error
"""
definition = tool.definition or {}
config = definition.get("config", {})
# Get HTTP method and URL
method = config.get("method", "POST").upper()
url = config.get("url", "")
# Get headers from config
headers = dict(config.get("headers", {}) or {})
# Add auth header if credential is configured
credential_uuid = config.get("credential_uuid")
if credential_uuid and organization_id:
try:
credential = await db_client.get_credential_by_uuid(
credential_uuid, organization_id
)
if credential:
auth_header = build_auth_header(credential)
headers.update(auth_header)
logger.debug(f"Applied credential '{credential.name}' to tool request")
else:
logger.warning(
f"Credential {credential_uuid} not found for tool '{tool.name}'"
)
except Exception as e:
logger.error(f"Failed to fetch credential for tool '{tool.name}': {e}")
# Get timeout
timeout_ms = config.get("timeout_ms", 5000)
timeout_seconds = timeout_ms / 1000
# Build request: JSON body for POST/PUT/PATCH, query params for GET/DELETE
body = None
params = None
if method in ("POST", "PUT", "PATCH"):
body = arguments
elif method in ("GET", "DELETE") and arguments:
params = arguments
logger.info(
f"Executing custom tool '{tool.name}' ({tool.tool_uuid}): {method} {url}"
)
logger.debug(f"Request body: {body}, params: {params}")
try:
async with httpx.AsyncClient(timeout=timeout_seconds) as client:
response = await client.request(
method=method,
url=url,
headers=headers,
json=body,
params=params,
)
# Try to parse JSON response
try:
response_data = response.json()
except Exception:
response_data = {"raw_response": response.text}
result = {
"status": "success",
"status_code": response.status_code,
"data": response_data,
}
logger.debug(
f"Custom tool '{tool.name}' completed with status {response.status_code}"
)
return result
except httpx.TimeoutException:
logger.error(f"Custom tool '{tool.name}' timed out after {timeout_seconds}s")
return {
"status": "error",
"error": f"Request timed out after {timeout_seconds} seconds",
}
except httpx.RequestError as e:
logger.error(f"Custom tool '{tool.name}' request failed: {e}")
return {
"status": "error",
"error": f"Request failed: {str(e)}",
}
except Exception as e:
logger.error(f"Custom tool '{tool.name}' execution failed: {e}")
return {
"status": "error",
"error": f"Tool execution failed: {str(e)}",
}

View file

@ -47,6 +47,7 @@ class Node:
self.detect_voicemail = data.detect_voicemail
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration
self.tool_uuids = data.tool_uuids
self.data = data

View file

@ -1,13 +1,13 @@
"""Execute webhook integrations after workflow run completion."""
import base64
from typing import Any, Dict
import httpx
from loguru import logger
from api.db import db_client
from api.db.models import ExternalCredentialModel, WorkflowRunModel
from api.db.models import WorkflowRunModel
from api.utils.credential_auth import build_auth_header
from api.utils.template_renderer import render_template
from pipecat.utils.context import set_current_run_id
@ -133,7 +133,7 @@ async def _execute_webhook_node(
credential_uuid, organization_id
)
if credential:
auth_header = _build_auth_header(credential)
auth_header = build_auth_header(credential)
headers.update(auth_header)
logger.debug(f"Applied credential '{credential.name}' to webhook")
else:
@ -189,39 +189,3 @@ async def _execute_webhook_node(
except Exception as e:
logger.error(f"Webhook '{webhook_name}' unexpected error: {e}")
return False
def _build_auth_header(credential: ExternalCredentialModel) -> Dict[str, str]:
"""
Build authentication header based on credential type.
Args:
credential: The credential model
Returns:
Dict with header name and value
"""
cred_type = credential.credential_type
cred_data = credential.credential_data or {}
if cred_type == "bearer_token":
token = cred_data.get("token", "")
return {"Authorization": f"Bearer {token}"}
elif cred_type == "api_key":
header_name = cred_data.get("header_name", "X-API-Key")
api_key = cred_data.get("api_key", "")
return {header_name: api_key}
elif cred_type == "basic_auth":
username = cred_data.get("username", "")
password = cred_data.get("password", "")
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
return {"Authorization": f"Basic {encoded}"}
elif cred_type == "custom_header":
header_name = cred_data.get("header_name", "X-Custom")
header_value = cred_data.get("header_value", "")
return {header_name: header_value}
return {}

View file

@ -1,138 +0,0 @@
import asyncio
import pytest
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartInterruptionFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
@pytest.mark.asyncio
async def test_reordering_after_completion():
context = OpenAILLMContext()
aggr = OpenAIAssistantContextAggregator(context)
# Initialize task manager properly using PipelineTask
pipeline = Pipeline([aggr])
task = PipelineTask(pipeline)
runner = PipelineRunner()
# Start the task to properly initialize the frame processor
task_coroutine = asyncio.create_task(runner.run(task))
# Give the task a moment to initialize
await asyncio.sleep(0.01)
# start new LLM response
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
# simulate a pending function call
await aggr.process_frame(
FunctionCallInProgressFrame(
function_name="transition",
tool_call_id="1",
arguments={},
cancel_on_interruption=False,
),
FrameDirection.DOWNSTREAM,
)
# now text arrives
await aggr.process_frame(TextFrame("Hi there"), FrameDirection.DOWNSTREAM)
# end response
await aggr.process_frame(LLMFullResponseEndFrame(), FrameDirection.DOWNSTREAM)
msgs = context.get_messages()
# Assert order: assistant text first, then tool_call assistant, then tool response
assert msgs[0]["role"] == "assistant" and "tool_calls" not in msgs[0]
# Fix: content is a string, not a structured object
assert msgs[0]["content"] == "Hi there"
assert any(m.get("role") == "assistant" and m.get("tool_calls") for m in msgs[1:])
assert any(m.get("role") == "tool" for m in msgs[1:])
# Clean up the running task
await task.cancel()
task_coroutine.cancel()
try:
await task_coroutine
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_interruption_removes_pending_function_calls_and_marks():
context = OpenAILLMContext()
aggr = OpenAIAssistantContextAggregator(context)
# Initialize task manager properly using PipelineTask
pipeline = Pipeline([aggr])
task = PipelineTask(pipeline)
runner = PipelineRunner()
# Start the task to properly initialize the frame processor
task_coroutine = asyncio.create_task(runner.run(task))
# Give the task a moment to initialize
await asyncio.sleep(0.01)
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
await aggr.process_frame(
FunctionCallInProgressFrame(
function_name="transition",
tool_call_id="1",
arguments={},
cancel_on_interruption=False,
),
FrameDirection.DOWNSTREAM,
)
# Debug: Check the state before interruption
print(
f"Function calls in progress before interruption: {aggr._function_calls_in_progress}"
)
print(f"Messages before interruption: {context.get_messages()}")
# no text yet - still aggregation
await aggr.process_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
msgs = context.get_messages()
# Debug: Print messages to understand what's happening
print(f"Messages after interruption: {msgs}")
print(
f"Function calls in progress after interruption: {aggr._function_calls_in_progress}"
)
# After interruption before any response is complete, context should be cleared
# This is the actual behavior - interruptions clear pending function calls
if len(msgs) == 0:
# Context was cleared due to interruption before completion
assert True
else:
# If there are messages, ensure no tool calls remain
assert not any(m.get("tool_calls") for m in msgs)
assert not any(m.get("role") == "tool" for m in msgs)
# Check if interruption marker is present
if msgs:
assert msgs[-1]["role"] == "assistant"
assert "<<interrupted_by_user>>" in msgs[-1]["content"]
# Clean up the running task
await task.cancel()
task_coroutine.cancel()
try:
await task_coroutine
except asyncio.CancelledError:
pass

View file

@ -1,120 +0,0 @@
import os
import wave
import pytest
from api.services.pipecat.audio_transcript_buffers import (
InMemoryAudioBuffer,
InMemoryTranscriptBuffer,
)
@pytest.mark.asyncio
async def test_audio_buffer_append_and_write():
"""Test that audio buffer can append data and write to temp file."""
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000, num_channels=1)
# Create some test PCM data
test_pcm = b"\x00\x01" * 1000 # 2000 bytes
# Append data
await buffer.append(test_pcm)
await buffer.append(test_pcm)
assert buffer.size == 4000
assert not buffer.is_empty
# Write to temp file
temp_path = await buffer.write_to_temp_file()
try:
# Verify file exists and is valid WAV
assert os.path.exists(temp_path)
with wave.open(temp_path, "rb") as wf:
assert wf.getnchannels() == 1
assert wf.getsampwidth() == 2
assert wf.getframerate() == 16000
# Each frame is 2 bytes (16-bit), so 4000 bytes = 2000 frames
assert wf.getnframes() == 2000
finally:
# Clean up
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_audio_buffer_memory_limit():
"""Test that audio buffer enforces memory limit."""
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000)
# Set a smaller limit for testing
buffer._max_size = 1000
# This should work
await buffer.append(b"\x00" * 500)
# This should fail
with pytest.raises(MemoryError):
await buffer.append(b"\x00" * 600)
@pytest.mark.asyncio
async def test_transcript_buffer_append_and_write():
"""Test that transcript buffer can append data and write to temp file."""
buffer = InMemoryTranscriptBuffer(workflow_run_id=456)
# Append some transcript lines
await buffer.append("[00:00:01] user: Hello\n")
await buffer.append("[00:00:02] assistant: Hi there!\n")
await buffer.append("[00:00:03] user: How are you?\n")
assert not buffer.is_empty
# Write to temp file
temp_path = await buffer.write_to_temp_file()
try:
# Verify file exists and has correct content
assert os.path.exists(temp_path)
with open(temp_path, "r") as f:
content = f.read()
assert "[00:00:01] user: Hello\n" in content
assert "[00:00:02] assistant: Hi there!\n" in content
assert "[00:00:03] user: How are you?\n" in content
finally:
# Clean up
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_empty_buffers():
"""Test that empty buffers are handled correctly."""
audio_buffer = InMemoryAudioBuffer(workflow_run_id=789, sample_rate=16000)
transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id=789)
assert audio_buffer.is_empty
assert transcript_buffer.is_empty
# Should still be able to write empty files
audio_path = await audio_buffer.write_to_temp_file()
transcript_path = await transcript_buffer.write_to_temp_file()
try:
assert os.path.exists(audio_path)
assert os.path.exists(transcript_path)
# Empty WAV file should still have valid header
with wave.open(audio_path, "rb") as wf:
assert wf.getnframes() == 0
# Empty transcript file
with open(transcript_path, "r") as f:
assert f.read() == ""
finally:
if os.path.exists(audio_path):
os.remove(audio_path)
if os.path.exists(transcript_path):
os.remove(transcript_path)

View file

@ -1,330 +0,0 @@
"""Tests for concurrent call limiting functionality."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.enums import OrganizationConfigurationKey
from api.services.campaign.rate_limiter import RateLimiter
class TestConcurrentCallLimiting:
"""Test suite for concurrent call limiting."""
@pytest.mark.asyncio
async def test_acquire_concurrent_slot_success(self):
"""Test successful acquisition of concurrent slot."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.eval = AsyncMock(return_value="test_slot_123")
mock_redis.return_value = mock_client
# Try to acquire slot
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id=1, max_concurrent=20
)
assert slot_id == "test_slot_123"
mock_client.eval.assert_called_once()
@pytest.mark.asyncio
async def test_acquire_concurrent_slot_limit_reached(self):
"""Test slot acquisition when limit is reached."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.eval = AsyncMock(return_value=None) # Limit reached
mock_redis.return_value = mock_client
# Try to acquire slot
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id=1, max_concurrent=20
)
assert slot_id is None
mock_client.eval.assert_called_once()
@pytest.mark.asyncio
async def test_release_concurrent_slot(self):
"""Test releasing a concurrent slot."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.zrem = AsyncMock(return_value=1) # Successfully removed
mock_redis.return_value = mock_client
# Release slot
success = await rate_limiter.release_concurrent_slot(
organization_id=1, slot_id="test_slot_123"
)
assert success is True
mock_client.zrem.assert_called_once_with(
"concurrent_calls:1", "test_slot_123"
)
@pytest.mark.asyncio
async def test_get_concurrent_count(self):
"""Test getting current concurrent call count."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.zremrangebyscore = AsyncMock() # Cleanup stale entries
mock_client.zcard = AsyncMock(return_value=5) # 5 active calls
mock_redis.return_value = mock_client
# Get count
count = await rate_limiter.get_concurrent_count(organization_id=1)
assert count == 5
mock_client.zremrangebyscore.assert_called_once()
mock_client.zcard.assert_called_once()
@pytest.mark.asyncio
async def test_stale_entry_cleanup(self):
"""Test that stale entries are cleaned up automatically."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
# Mock eval to simulate cleanup in Lua script
mock_client.eval = AsyncMock(return_value="new_slot_123")
mock_redis.return_value = mock_client
# Try to acquire slot (which should trigger cleanup)
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id=1, max_concurrent=20
)
assert slot_id == "new_slot_123"
# Verify Lua script was called with proper stale cutoff
call_args = mock_client.eval.call_args[0]
lua_script = call_args[0]
assert "ZREMRANGEBYSCORE" in lua_script # Cleanup command in script
@pytest.mark.asyncio
async def test_workflow_slot_mapping_operations(self):
"""Test storing, retrieving, and deleting workflow slot mappings."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.hset = AsyncMock(return_value=1)
mock_client.expire = AsyncMock(return_value=True)
mock_client.hgetall = AsyncMock(
return_value={"org_id": "1", "slot_id": "test_slot_123"}
)
mock_client.delete = AsyncMock(return_value=1)
mock_redis.return_value = mock_client
# Test storing mapping
success = await rate_limiter.store_workflow_slot_mapping(
workflow_run_id=999, organization_id=1, slot_id="test_slot_123"
)
assert success is True
mock_client.hset.assert_called_once()
mock_client.expire.assert_called_once()
# Test retrieving mapping
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id=999)
assert mapping == (1, "test_slot_123")
mock_client.hgetall.assert_called_once_with("workflow_slot_mapping:999")
# Test deleting mapping
deleted = await rate_limiter.delete_workflow_slot_mapping(
workflow_run_id=999
)
assert deleted is True
mock_client.delete.assert_called_once_with("workflow_slot_mapping:999")
class TestCampaignCallDispatcher:
"""Test suite for CampaignCallDispatcher with concurrent limiting."""
@pytest.mark.asyncio
async def test_dispatch_call_waits_for_slot(self):
"""Test that dispatch_call waits for available slot."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock dependencies
mock_campaign = MagicMock(
organization_id=1, workflow_id=123, id=456, created_by=789
)
mock_queued_run = MagicMock(
id=111, context_variables={"phone_number": "+1234567890"}
)
# Mock rate limiter to simulate waiting
slot_acquired = False
call_count = 0
async def mock_try_acquire(org_id, max_concurrent):
nonlocal slot_acquired, call_count
call_count += 1
if call_count > 2: # Succeed on third try
slot_acquired = True
return "test_slot_123"
return None
with patch(
"api.services.campaign.call_dispatcher.rate_limiter"
) as mock_limiter:
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
side_effect=mock_try_acquire
)
mock_limiter.release_concurrent_slot = AsyncMock()
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_db.get_configuration = AsyncMock(return_value=None)
mock_db.get_workflow_by_id = AsyncMock(
return_value=MagicMock(template_context_variables={})
)
mock_db.create_workflow_run = AsyncMock(
return_value=MagicMock(id=999, logs={})
)
with patch.object(
dispatcher.twilio_service, "initiate_call"
) as mock_twilio:
mock_twilio.return_value = {"sid": "test_sid"}
# Dispatch call (should wait and retry)
workflow_run = await dispatcher.dispatch_call(
mock_queued_run, mock_campaign
)
assert workflow_run is not None
assert slot_acquired is True
assert call_count == 3 # Tried 3 times
assert mock_limiter.try_acquire_concurrent_slot.call_count == 3
@pytest.mark.asyncio
async def test_dispatch_call_stores_slot_mapping(self):
"""Test that dispatch_call stores slot mapping in Redis."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock dependencies
mock_campaign = MagicMock(
organization_id=1, workflow_id=123, id=456, created_by=789
)
mock_queued_run = MagicMock(
id=111, context_variables={"phone_number": "+1234567890"}
)
with patch(
"api.services.campaign.call_dispatcher.rate_limiter"
) as mock_limiter:
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
return_value="test_slot_123"
)
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_db.get_configuration = AsyncMock(return_value=None)
mock_db.get_workflow_by_id = AsyncMock(
return_value=MagicMock(template_context_variables={})
)
mock_db.create_workflow_run = AsyncMock(
return_value=MagicMock(id=999, logs={})
)
with patch.object(
dispatcher.twilio_service, "initiate_call"
) as mock_twilio:
mock_twilio.return_value = {"sid": "test_sid"}
# Dispatch call
workflow_run = await dispatcher.dispatch_call(
mock_queued_run, mock_campaign
)
# Verify slot mapping was stored
mock_limiter.store_workflow_slot_mapping.assert_called_once_with(
999, 1, "test_slot_123"
)
@pytest.mark.asyncio
async def test_org_specific_concurrent_limit(self):
"""Test that organization-specific concurrent limit is used."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock db_client to return org-specific limit
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_config = MagicMock(value={"value": 10}) # Org limit is 10
mock_db.get_configuration = AsyncMock(return_value=mock_config)
# Get org limit
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
assert limit == 10 # Should use org-specific limit
mock_db.get_configuration.assert_called_once_with(
1, OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value
)
@pytest.mark.asyncio
async def test_default_concurrent_limit(self):
"""Test that default limit is used when org config not found."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock db_client to return None (no config)
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_db.get_configuration = AsyncMock(return_value=None)
# Get org limit
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
assert limit == 20 # Should use default limit
@pytest.mark.asyncio
async def test_release_call_slot(self):
"""Test releasing call slot when workflow completes."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock rate limiter
with patch(
"api.services.campaign.call_dispatcher.rate_limiter"
) as mock_limiter:
# Mock getting the slot mapping from Redis
mock_limiter.get_workflow_slot_mapping = AsyncMock(
return_value=(1, "test_slot_123")
)
mock_limiter.release_concurrent_slot = AsyncMock(return_value=True)
mock_limiter.delete_workflow_slot_mapping = AsyncMock(return_value=True)
# Release slot
success = await dispatcher.release_call_slot(workflow_run_id=999)
assert success is True
mock_limiter.get_workflow_slot_mapping.assert_called_once_with(999)
mock_limiter.release_concurrent_slot.assert_called_once_with(
1, "test_slot_123"
)
mock_limiter.delete_workflow_slot_mapping.assert_called_once_with(999)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -1,78 +0,0 @@
import pytest
from pydantic import ValidationError
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
from api.services.configuration.merge import merge_user_configurations
from api.services.configuration.registry import (
OpenAILLMService,
)
REAL_KEY = "sk-1234567890abcdef"
def _build_config_with_openai(key: str) -> UserConfiguration:
return UserConfiguration(
llm=OpenAILLMService(api_key=key),
stt=None,
tts=None,
)
def test_mask_key_basic():
masked = mask_key(REAL_KEY)
# Should reveal only last 4 chars
assert masked.endswith(REAL_KEY[-4:])
assert set(masked[:-4]) == {"*"}
assert len(masked) == len(REAL_KEY)
# is_mask_of round-trip
assert is_mask_of(masked, REAL_KEY)
def test_mask_user_config_masks_api_keys():
cfg = _build_config_with_openai(REAL_KEY)
dumped = mask_user_config(cfg)
assert dumped["llm"]["api_key"].endswith(REAL_KEY[-4:])
assert dumped["llm"]["api_key"].startswith("*" * (len(REAL_KEY) - 4))
def test_merge_preserves_key_when_mask_sent():
existing = _build_config_with_openai(REAL_KEY)
incoming_partial = {
"llm": {
"provider": "openai",
"model": existing.llm.model,
"api_key": mask_key(REAL_KEY), # masked placeholder
}
}
merged = merge_user_configurations(existing, incoming_partial)
assert merged.llm.api_key == REAL_KEY # key preserved
def test_merge_replaces_key_when_new_key_provided():
existing = _build_config_with_openai(REAL_KEY)
new_key = "sk-replaced-9999"
incoming_partial = {
"llm": {
"provider": "openai",
"model": existing.llm.model,
"api_key": new_key,
}
}
merged = merge_user_configurations(existing, incoming_partial)
assert merged.llm.api_key == new_key
def test_merge_drops_old_key_when_provider_changes():
existing = _build_config_with_openai(REAL_KEY)
incoming_partial = {
"llm": {
"provider": "groq",
"model": "llama-3.3-70b-versatile",
# api_key intentionally absent should NOT inherit old key
}
}
with pytest.raises(ValidationError):
merge_user_configurations(existing, incoming_partial)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,512 @@
"""Integration tests for CustomToolManager with update_llm_context.
This module tests the full flow of:
1. CustomToolManager fetching and converting tool schemas
2. update_llm_context setting those tools on the LLM context
3. Verifying the context is properly configured for LLM generation
"""
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
update_llm_context,
)
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.processors.aggregators.llm_context import LLMContext
@dataclass
class MockToolModel:
"""Mock tool model for testing."""
tool_uuid: str
name: str
description: str
definition: Dict[str, Any]
class TestCustomToolManagerContextIntegration:
"""Integration tests for CustomToolManager with LLMContext."""
@pytest.fixture
def mock_engine(self):
"""Create a mock PipecatEngine."""
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {"customer_name": "John Doe"}
engine.llm = Mock()
engine.llm.register_function = Mock()
return engine
@pytest.fixture
def sample_tools(self):
"""Create sample mock tools for testing."""
return [
MockToolModel(
tool_uuid="weather-uuid-123",
name="Get Weather",
description="Get current weather for a location",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.weather.com/current",
"parameters": [
{
"name": "location",
"type": "string",
"description": "City name (e.g., San Francisco, CA)",
"required": True,
},
{
"name": "units",
"type": "string",
"description": "Temperature units: celsius or fahrenheit",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="booking-uuid-456",
name="Book Appointment",
description="Book an appointment for the customer",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/appointments",
"parameters": [
{
"name": "customer_name",
"type": "string",
"description": "Customer's full name",
"required": True,
},
{
"name": "date",
"type": "string",
"description": "Appointment date (YYYY-MM-DD)",
"required": True,
},
{
"name": "time",
"type": "string",
"description": "Appointment time (HH:MM)",
"required": True,
},
{
"name": "notes",
"type": "string",
"description": "Additional notes",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="lookup-uuid-789",
name="Customer Lookup",
description="Look up customer information by phone number",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.example.com/customers/lookup",
"parameters": [
{
"name": "phone",
"type": "string",
"description": "Customer phone number",
"required": True,
},
],
},
},
),
]
@pytest.mark.asyncio
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
schemas = await manager.get_tool_schemas(tool_uuids)
# Verify schemas were returned as FunctionSchema objects
assert len(schemas) == 3
assert all(isinstance(s, FunctionSchema) for s in schemas)
# Create context with conversation history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "I need to check the weather and book an appointment.",
},
{
"role": "assistant",
"content": "I can help with both. Where would you like to check the weather?",
},
{"role": "user", "content": "San Francisco"},
]
)
# Update context with new system message and tools
# Now we can pass schemas directly since they're FunctionSchema objects
new_system = {
"role": "system",
"content": "You are a scheduling assistant with access to weather and booking tools.",
}
update_llm_context(context, new_system, schemas)
# Verify context was updated correctly
messages = context.messages
assert len(messages) == 4
assert (
messages[0]["content"]
== "You are a scheduling assistant with access to weather and booking tools."
)
assert messages[1]["role"] == "user"
assert messages[3]["content"] == "San Francisco"
# Verify tools were set
tools = context.tools
assert tools is not None
assert len(tools.standard_tools) == 3
# Verify tool names
tool_names = {t.name for t in tools.standard_tools}
assert tool_names == {
"get_weather",
"book_appointment",
"customer_lookup",
}
@pytest.mark.asyncio
async def test_tool_schemas_have_correct_properties(
self, mock_engine, sample_tools
):
"""Test that tool schemas from CustomToolManager have correct parameter properties."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
schemas = await manager.get_tool_schemas(
["weather-uuid-123", "booking-uuid-456"]
)
# Find the booking schema - now using FunctionSchema attributes
booking_schema = next(
s for s in schemas if s.name == "book_appointment"
)
# Verify parameter properties
assert "customer_name" in booking_schema.properties
assert "date" in booking_schema.properties
assert "time" in booking_schema.properties
assert "notes" in booking_schema.properties
# Verify types
assert booking_schema.properties["customer_name"]["type"] == "string"
assert booking_schema.properties["date"]["type"] == "string"
# Verify required
assert "customer_name" in booking_schema.required
assert "date" in booking_schema.required
assert "time" in booking_schema.required
assert "notes" not in booking_schema.required
@pytest.mark.asyncio
async def test_context_update_with_builtin_and_custom_tools(
self, mock_engine, sample_tools
):
"""Test updating context with both built-in and custom tools."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(
return_value=[sample_tools[0]]
) # Just weather
# Get custom tool schemas - returns FunctionSchema objects
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create built-in function schemas (like calculator, timezone)
builtin_functions = [
get_function_schema(
"safe_calculator",
"Evaluate a mathematical expression safely",
properties={
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate",
}
},
required=["expression"],
),
get_function_schema(
"get_current_time",
"Get the current time in a timezone",
properties={
"timezone": {
"type": "string",
"description": "Timezone name (e.g., America/New_York)",
}
},
required=["timezone"],
),
]
# Combine built-in and custom functions - both are FunctionSchema objects
all_functions = builtin_functions + custom_schemas
# Update context
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old prompt"}])
new_system = {
"role": "system",
"content": "Assistant with calculator and weather tools",
}
update_llm_context(context, new_system, all_functions)
# Verify all tools are present
tools = context.tools
assert len(tools.standard_tools) == 3
tool_names = {t.name for t in tools.standard_tools}
assert "safe_calculator" in tool_names
assert "get_current_time" in tool_names
assert "get_weather" in tool_names
@pytest.mark.asyncio
async def test_tools_cached_after_first_fetch(self, mock_engine, sample_tools):
"""Test that CustomToolManager caches tools after first fetch."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
# First fetch
await manager.get_tool_schemas(["weather-uuid-123"])
# Verify tool is cached (cache stores raw schema dict, not FunctionSchema)
cached = manager.get_cached_tool("get_weather")
assert cached is not None
tool, raw_schema = cached
assert tool.tool_uuid == "weather-uuid-123"
assert raw_schema["function"]["name"] == "get_weather"
@pytest.mark.asyncio
async def test_context_preserves_function_call_history(
self, mock_engine, sample_tools
):
"""Test that update_llm_context preserves function call messages in history."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create context with function call history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "Old system prompt"},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "New York, NY"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": '{"temperature": 72, "condition": "sunny"}',
},
{
"role": "assistant",
"content": "The weather in NYC is 72°F and sunny!",
},
]
)
new_system = {"role": "system", "content": "Updated weather assistant"}
update_llm_context(context, new_system, schemas)
messages = context.messages
# System + user + assistant(tool_call) + tool + assistant = 5
assert len(messages) == 5
# Verify function call messages are preserved
tool_call_msg = messages[2]
assert tool_call_msg["role"] == "assistant"
assert "tool_calls" in tool_call_msg
tool_result_msg = messages[3]
assert tool_result_msg["role"] == "tool"
assert tool_result_msg["tool_call_id"] == "call_123"
@pytest.mark.asyncio
async def test_empty_tool_list_does_not_set_tools(self, mock_engine):
"""Test that empty tool list doesn't set tools on context."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
schemas = await manager.get_tool_schemas([])
assert schemas == []
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
new_system = {"role": "system", "content": "No tools available"}
update_llm_context(context, new_system, [])
# Context should have updated message but no tools set
assert context.messages[0]["content"] == "No tools available"
@pytest.mark.asyncio
async def test_numeric_and_boolean_parameter_types(self, mock_engine):
"""Test that numeric and boolean parameter types are correctly handled."""
tool_with_types = MockToolModel(
tool_uuid="order-uuid",
name="Place Order",
description="Place an order for items",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/orders",
"parameters": [
{
"name": "item_id",
"type": "string",
"description": "Item identifier",
"required": True,
},
{
"name": "quantity",
"type": "number",
"description": "Number of items",
"required": True,
},
{
"name": "express_shipping",
"type": "boolean",
"description": "Use express shipping",
"required": False,
},
],
},
},
)
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["order-uuid"])
schema = schemas[0]
# Verify types using FunctionSchema attributes
assert schema.properties["item_id"]["type"] == "string"
assert schema.properties["quantity"]["type"] == "number"
assert schema.properties["express_shipping"]["type"] == "boolean"
# Update context - pass schema directly
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
update_llm_context(
context, {"role": "system", "content": "Order assistant"}, schemas
)
# Verify tool was set with correct types
tool = context.tools.standard_tools[0]
assert tool.name == "place_order"
assert tool.properties["quantity"]["type"] == "number"
assert tool.properties["express_shipping"]["type"] == "boolean"

View file

@ -1,33 +0,0 @@
import os
import uuid
import pytest
from api.db.user_client import UserClient
from api.services.configuration.registry import ServiceProviders
@pytest.mark.asyncio
async def test_default_configuration_created(db_session):
# Set env variable for openai to simulate availability of default key
os.environ["OPENAI_API_KEY"] = "sk-test-openai-key"
# Ensure deepgram env variable absent to focus test
os.environ.pop("DEEPGRAM_API_KEY", None)
# Generate a unique (random) provider user ID for each test run
test_provider_user_id = f"provider_user_{uuid.uuid4().hex}"
user_client: UserClient = db_session # db_session fixture yields the client
user_model = await user_client.get_or_create_user_by_provider_id(
test_provider_user_id
)
config = await user_client.get_user_configurations(user_model.id)
assert config.llm is not None, "LLM config should be created when env key present"
assert config.llm.provider == ServiceProviders.OPENAI
assert config.llm.api_key == "sk-test-openai-key"
# Cleanup / restore env variable side-effects
os.environ.pop("OPENAI_API_KEY", None)

View file

@ -1,122 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
@pytest.mark.asyncio
async def test_apply_disposition_mapping_with_valid_mapping():
"""Test disposition mapping with valid configuration."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock disposition mapping configuration
mock_db_client.get_configuration_value = AsyncMock(
return_value={
"XFER": "TRANSFERRED",
"ND": "NOT_QUALIFIED",
"user_hangup": "HANGUP",
}
)
# Test mapping exists
result = await apply_disposition_mapping("XFER", 1)
assert result == "TRANSFERRED"
# Test mapping doesn't exist
result = await apply_disposition_mapping("UNKNOWN", 1)
assert result == "UNKNOWN"
# Verify db_client was called correctly
mock_db_client.get_configuration_value.assert_called_with(
1, "DISPOSITION_CODE_MAPPING", default={}
)
@pytest.mark.asyncio
async def test_apply_disposition_mapping_no_organization_id():
"""Test disposition mapping with no organization ID."""
# Should return original value
result = await apply_disposition_mapping("XFER", None)
assert result == "XFER"
@pytest.mark.asyncio
async def test_apply_disposition_mapping_empty_value():
"""Test disposition mapping with empty value."""
# Should return original empty value
result = await apply_disposition_mapping("", 1)
assert result == ""
@pytest.mark.asyncio
async def test_apply_disposition_mapping_error_handling():
"""Test disposition mapping handles errors gracefully."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock database error
mock_db_client.get_configuration_value = AsyncMock(
side_effect=Exception("Database error")
)
# Should return original value on error
result = await apply_disposition_mapping("XFER", 1)
assert result == "XFER"
@pytest.mark.asyncio
async def test_get_organization_id_from_workflow_run():
"""Test getting organization ID from workflow run ID."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock workflow run with organization
mock_workflow_run = MagicMock()
mock_workflow_run.workflow.user.selected_organization_id = 123
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
result = await get_organization_id_from_workflow_run(1)
assert result == 123
# Verify db_client was called correctly
mock_db_client.get_workflow_run_by_id.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_get_organization_id_no_workflow_run():
"""Test getting organization ID when workflow run doesn't exist."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock no workflow run found
mock_db_client.get_workflow_run_by_id = AsyncMock(return_value=None)
result = await get_organization_id_from_workflow_run(1)
assert result is None
@pytest.mark.asyncio
async def test_get_organization_id_no_user():
"""Test getting organization ID when workflow has no user."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock workflow run with no user
mock_workflow_run = MagicMock()
mock_workflow_run.workflow.user = None
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
result = await get_organization_id_from_workflow_run(1)
assert result is None
@pytest.mark.asyncio
async def test_get_organization_id_error_handling():
"""Test getting organization ID handles errors gracefully."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock database error
mock_db_client.get_workflow_run_by_id = AsyncMock(
side_effect=Exception("Database error")
)
result = await get_organization_id_from_workflow_run(1)
assert result is None

View file

@ -1,370 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pipecat.utils.enums import EndTaskReason
from api.services.pipecat.event_handlers import register_transport_event_handlers
@pytest.fixture
def mock_dependencies():
"""Create mock dependencies for event handlers."""
# Store registered handlers
registered_handlers = {}
def mock_event_handler(event_name):
def decorator(func):
registered_handlers[event_name] = func
return func
return decorator
mock_transport = MagicMock()
mock_transport.event_handler = mock_event_handler
mock_task = MagicMock()
mock_task.cancel = AsyncMock()
mock_engine = MagicMock()
mock_engine.initialize = AsyncMock()
mock_engine.cleanup = AsyncMock()
mock_audio_buffer = MagicMock()
mock_audio_buffer.start_recording = AsyncMock()
mock_audio_buffer.stop_recording = AsyncMock()
mock_usage_metrics_aggregator = MagicMock()
mock_usage_metrics_aggregator.get_all_usage_metrics_serialized = MagicMock(
return_value={"test": "metrics"}
)
return {
"transport": mock_transport,
"workflow_run_id": 123,
"audio_buffer": mock_audio_buffer,
"task": mock_task,
"engine": mock_engine,
"usage_metrics_aggregator": mock_usage_metrics_aggregator,
"audio_synchronizer": None,
"registered_handlers": registered_handlers,
}
@pytest.mark.asyncio
async def test_transport_disconnect_reason_mapping(mock_dependencies):
"""Test that transport_disconnect_reason is mapped when no engine disconnect reason exists."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with no call disposition
mock_dependencies["engine"].get_call_disposition.return_value = None
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
# Mock the disposition mapper functions
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
new_callable=AsyncMock,
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock call duration for user_hangup logic
mock_dependencies[
"usage_metrics_aggregator"
].get_call_duration.return_value = 15
# Mock disposition mapping
async def apply_mapping_side_effect(value, org_id):
return {
"NIBP": "NOT_INTERESTED_BUSINESS_PURPOSE",
"user_qualified": "QUALIFIED",
}.get(value, value)
mock_apply_mapping.side_effect = apply_mapping_side_effect
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler with transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason="user_hangup",
)
# Verify disposition mapping was applied with NIBP (since duration > 10)
mock_apply_mapping.assert_called_once_with("NIBP", 1)
# Verify database was updated with mapped value
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "NOT_INTERESTED_BUSINESS_PURPOSE"
)
# Verify task was cancelled (no engine disconnect reason)
mock_dependencies["task"].cancel.assert_called_once()
@pytest.mark.asyncio
async def test_transport_disconnect_reason_user_hangup_short_call(mock_dependencies):
"""Test that user_hangup with short call duration is mapped to HU."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with no call disposition
mock_dependencies["engine"].get_call_disposition.return_value = None
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
# Mock the disposition mapper functions
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
new_callable=AsyncMock,
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock call duration for user_hangup logic (< 10 seconds)
mock_dependencies[
"usage_metrics_aggregator"
].get_call_duration.return_value = 5
# Mock disposition mapping
mock_apply_mapping.return_value = "HANGUP"
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler with transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason="user_hangup",
)
# Verify disposition mapping was applied with HU (since duration < 10)
mock_apply_mapping.assert_called_once_with("HU", 1)
# Verify database was updated with mapped value
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "HANGUP"
)
# Verify task was cancelled (no engine disconnect reason)
mock_dependencies["task"].cancel.assert_called_once()
@pytest.mark.asyncio
async def test_engine_disconnect_reason_takes_precedence(mock_dependencies):
"""Test that engine disconnect reason takes precedence and is not mapped."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with call disposition
mock_dependencies["engine"].get_call_disposition.return_value = "user_qualified"
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
# Mock the disposition mapper functions
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
new_callable=AsyncMock,
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping for engine's reason
mock_apply_mapping.return_value = "QUALIFIED"
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler with transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason="user_hangup",
)
# Verify disposition mapping was called with engine's reason
mock_apply_mapping.assert_called_once_with("user_qualified", 1)
# Verify database was updated with mapped value
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "QUALIFIED"
)
# Verify task was NOT cancelled (engine disconnect reason exists)
mock_dependencies["task"].cancel.assert_not_called()
@pytest.mark.asyncio
async def test_no_disconnect_reason_uses_unknown(mock_dependencies):
"""Test that when no disconnect reason is provided, UNKNOWN is used."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with no call disposition
mock_dependencies["engine"].get_call_disposition.return_value = None
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping"
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping - should return UNKNOWN as-is
mock_apply_mapping.return_value = EndTaskReason.UNKNOWN.value
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler without transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason=None,
)
# Verify disposition mapping was called with UNKNOWN
mock_apply_mapping.assert_called_once_with(
EndTaskReason.UNKNOWN.value, 1
)
# Verify database was updated with UNKNOWN
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== EndTaskReason.UNKNOWN.value
)

View file

@ -1,184 +0,0 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.event_handlers import (
register_audio_data_handler,
register_transcript_handler,
register_transport_event_handlers,
)
@pytest.mark.asyncio
async def test_transport_handlers_with_in_memory_buffers():
"""Test that transport handlers create and return in-memory buffers."""
# Mock dependencies
transport = MagicMock()
transport.event_handler = lambda event_name: lambda func: func
audio_buffer = AsyncMock()
audio_synchronizer = AsyncMock()
task = AsyncMock()
engine = AsyncMock()
engine.get_call_disposition.return_value = None
engine.get_gathered_context.return_value = {}
usage_metrics_aggregator = AsyncMock()
usage_metrics_aggregator.get_call_duration.return_value = 30
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Create test audio config
audio_config = AudioConfig(
transport_in_sample_rate=16000,
transport_out_sample_rate=16000,
pipeline_sample_rate=16000,
)
# Register handlers
audio_buf, transcript_buf = register_transport_event_handlers(
transport=transport,
workflow_run_id=123,
audio_buffer=audio_buffer,
task=task,
engine=engine,
usage_metrics_aggregator=usage_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=audio_config,
)
# Verify buffers were created with correct configuration
assert audio_buf is not None
assert transcript_buf is not None
assert audio_buf._workflow_run_id == 123
assert audio_buf._sample_rate == 16000
assert audio_buf._num_channels == 1
assert transcript_buf._workflow_run_id == 123
@pytest.mark.asyncio
async def test_audio_handler_with_in_memory_buffer():
"""Test audio handler uses in-memory buffer when provided."""
# Mock audio synchronizer
audio_synchronizer = MagicMock()
handlers = {}
def mock_event_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
audio_synchronizer.event_handler = mock_event_handler
# Mock in-memory buffer
in_memory_buffer = AsyncMock()
# Register handler with buffer
register_audio_data_handler(
audio_synchronizer, workflow_run_id=123, in_memory_buffer=in_memory_buffer
)
# Test the handler
assert "on_merged_audio" in handlers
handler = handlers["on_merged_audio"]
# Call handler with test data
test_pcm = b"test_audio_data"
await handler(None, test_pcm, 16000, 1)
# Verify buffer was used
in_memory_buffer.append.assert_called_once_with(test_pcm)
@pytest.mark.asyncio
async def test_transcript_handler_with_in_memory_buffer():
"""Test transcript handler uses in-memory buffer when provided."""
# Mock transcript processor
transcript = MagicMock()
handlers = {}
def mock_event_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
transcript.event_handler = mock_event_handler
# Mock in-memory buffer
in_memory_buffer = AsyncMock()
# Register handler with buffer
register_transcript_handler(
transcript, workflow_run_id=456, in_memory_buffer=in_memory_buffer
)
# Create test frame
test_frame = MagicMock()
test_frame.messages = [
MagicMock(timestamp="00:00:01", role="user", content="Hello"),
MagicMock(timestamp="00:00:02", role="assistant", content="Hi there"),
]
# Test the handler
handler = handlers["on_transcript_update"]
await handler(None, test_frame)
# Verify buffer was used with correct format
expected_text = "[00:00:01] user: Hello\n[00:00:02] assistant: Hi there\n"
in_memory_buffer.append.assert_called_once_with(expected_text)
@pytest.mark.asyncio
async def test_audio_config_sample_rates():
"""Test that different audio configs result in correct sample rates."""
# Mock dependencies
transport = MagicMock()
transport.event_handler = lambda event_name: lambda func: func
audio_buffer = AsyncMock()
audio_synchronizer = AsyncMock()
task = AsyncMock()
engine = AsyncMock()
engine.get_call_disposition.return_value = None
engine.get_gathered_context.return_value = {}
usage_metrics_aggregator = AsyncMock()
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Test with 8kHz audio config (e.g., for Stasis/Twilio)
audio_config_8k = AudioConfig(
transport_in_sample_rate=8000,
transport_out_sample_rate=8000,
pipeline_sample_rate=8000,
)
audio_buf_8k, _ = register_transport_event_handlers(
transport=transport,
workflow_run_id=456,
audio_buffer=audio_buffer,
task=task,
engine=engine,
usage_metrics_aggregator=usage_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=audio_config_8k,
)
assert audio_buf_8k._sample_rate == 8000
# Test with no audio config (should default to 16kHz)
audio_buf_default, _ = register_transport_event_handlers(
transport=transport,
workflow_run_id=789,
audio_buffer=audio_buffer,
task=task,
engine=engine,
usage_metrics_aggregator=usage_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=None,
)
assert audio_buf_default._sample_rate == 16000

View file

@ -1,162 +0,0 @@
"""Test filter functionality."""
from unittest.mock import MagicMock
from api.db.filters import ATTRIBUTE_FIELD_MAPPING, apply_workflow_run_filters
def test_attribute_field_mapping():
"""Test that all required attributes are mapped."""
expected_attributes = [
"dateRange",
"dispositionCode",
"duration",
"status",
"tokenUsage",
"runId",
"workflowId",
"callTags",
"phoneNumber",
]
for attr in expected_attributes:
assert attr in ATTRIBUTE_FIELD_MAPPING, f"Missing mapping for {attr}"
def test_filter_with_explicit_type():
"""Test that filters work with explicit type from UI."""
# Mock query
mock_query = MagicMock()
mock_query.where = MagicMock(return_value=mock_query)
test_cases = [
# Date range filter
{
"filters": [
{
"attribute": "dateRange",
"type": "dateRange",
"value": {"from": "2024-01-01", "to": "2024-01-31"},
}
],
},
# Multi-select filter
{
"filters": [
{
"attribute": "dispositionCode",
"type": "multiSelect",
"value": {"codes": ["XFER", "HU"]},
}
],
},
# Number range filter
{
"filters": [
{
"attribute": "duration",
"type": "numberRange",
"value": {"min": 60, "max": 300},
}
],
},
# Radio/status filter
{
"filters": [
{
"attribute": "status",
"type": "radio",
"value": {"status": "completed"},
}
],
},
# Number filter
{
"filters": [
{"attribute": "runId", "type": "number", "value": {"value": 123}}
],
},
# Text filter
{
"filters": [
{
"attribute": "phoneNumber",
"type": "text",
"value": {"value": "+1234567890"},
}
],
},
# Tags filter
{
"filters": [
{
"attribute": "callTags",
"type": "tags",
"value": {"codes": ["tag1", "tag2"]},
}
],
},
]
for test_case in test_cases:
result = apply_workflow_run_filters(mock_query, test_case["filters"])
# The function should process the filter without errors
assert result is not None
def test_filter_format_with_type():
"""Test that filters work with attribute, type, and value."""
mock_query = MagicMock()
mock_query.where = MagicMock(return_value=mock_query)
# Test with various filter combinations
filters = [
{
"attribute": "dispositionCode",
"type": "multiSelect",
"value": {"codes": ["NIBP"]},
},
{
"attribute": "duration",
"type": "numberRange",
"value": {"min": 0, "max": 60},
},
{"attribute": "phoneNumber", "type": "text", "value": {"value": "555"}},
]
result = apply_workflow_run_filters(mock_query, filters)
# Should have called where() for applying filters
assert mock_query.where.called
assert result is not None
def test_unknown_attribute_ignored():
"""Test that unknown attributes are safely ignored."""
mock_query = MagicMock()
mock_query.where = MagicMock(return_value=mock_query)
filters = [
{"attribute": "unknownAttribute", "value": {"value": "test"}},
{"attribute": "dispositionCode", "value": {"codes": ["XFER"]}},
]
result = apply_workflow_run_filters(mock_query, filters)
# Should still process the valid filter
assert result is not None
def test_empty_filters():
"""Test that empty filters return the query unchanged."""
mock_query = MagicMock()
result = apply_workflow_run_filters(mock_query, None)
assert result == mock_query
result = apply_workflow_run_filters(mock_query, [])
assert result == mock_query

View file

@ -1,249 +0,0 @@
"""Tests for global prompt functionality in workflow engine."""
from unittest.mock import Mock
import pytest
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
class TestGlobalPrompt:
"""Test suite for global prompt feature."""
@pytest.fixture
def workflow_with_global_node(self):
"""Create a workflow with a global node and test nodes."""
nodes = [
RFNodeDTO(
id="global",
type=NodeType.globalNode,
position={"x": 0, "y": 0},
data=NodeDataDTO(
name="Global Node",
prompt="This is the global context: {{company_name}}",
is_static=False,
),
),
RFNodeDTO(
id="start",
type=NodeType.startNode,
position={"x": 100, "y": 100},
data=NodeDataDTO(
name="Start Call",
prompt="Welcome to our service!",
is_static=False,
is_start=True,
add_global_prompt=True, # Enable global prompt
),
),
RFNodeDTO(
id="agent1",
type=NodeType.agentNode,
position={"x": 200, "y": 200},
data=NodeDataDTO(
name="Agent 1",
prompt="How can I help you today?",
add_global_prompt=False, # Disable global prompt
),
),
RFNodeDTO(
id="agent2",
type=NodeType.agentNode,
position={"x": 300, "y": 300},
data=NodeDataDTO(
name="Agent 2",
prompt="Please provide your details.",
add_global_prompt=True, # Enable global prompt
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position={"x": 400, "y": 400},
data=NodeDataDTO(
name="End Call",
prompt="Thank you for calling!",
is_static=True,
is_end=True,
add_global_prompt=True, # Enable global prompt (but static)
),
),
]
edges = [
RFEdgeDTO(
id="e1",
source="start",
target="agent1",
data=EdgeDataDTO(label="Next", condition="Continue to agent"),
),
RFEdgeDTO(
id="e2",
source="agent1",
target="agent2",
data=EdgeDataDTO(label="Details", condition="Get user details"),
),
RFEdgeDTO(
id="e3",
source="agent2",
target="end",
data=EdgeDataDTO(label="Finish", condition="End the call"),
),
]
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
return WorkflowGraph(flow_dto)
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies for PipecatEngine initialization."""
return {
"task": Mock(),
"llm": Mock(),
"context": Mock(spec=OpenAILLMContext),
"tts": Mock(),
"transport": Mock(),
"call_context_vars": {"company_name": "Dograh Inc"},
}
@pytest.fixture
def engine(self, mock_dependencies, workflow_with_global_node):
"""Create a PipecatEngine instance with test workflow."""
mock_dependencies["workflow"] = workflow_with_global_node
return PipecatEngine(**mock_dependencies)
@pytest.mark.asyncio
async def test_global_prompt_enabled(self, engine):
"""Test that global prompt is prepended when add_global_prompt is True."""
# Test with start node (add_global_prompt=True)
start_node = engine.workflow.nodes["start"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(start_node)
# Global prompt should be included
expected_content = (
"This is the global context: Dograh Inc\n\nWelcome to our service!"
)
assert system_message["content"] == expected_content
assert system_message["role"] == "system"
@pytest.mark.asyncio
async def test_global_prompt_disabled(self, engine):
"""Test that global prompt is not prepended when add_global_prompt is False."""
# Test with agent1 node (add_global_prompt=False)
agent1_node = engine.workflow.nodes["agent1"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(agent1_node)
# Global prompt should NOT be included
expected_content = "How can I help you today?"
assert system_message["content"] == expected_content
assert "global context" not in system_message["content"]
@pytest.mark.asyncio
async def test_global_prompt_with_static_node(self, engine):
"""Test that static nodes don't use global prompt in engine (even if enabled)."""
# Static nodes are handled differently - they use TTSSpeakFrame directly
# This test verifies the compose_system_message behavior for completeness
end_node = engine.workflow.nodes["end"]
# Even though add_global_prompt=True, static nodes handle prompts differently
# The _compose_system_message_functions_for_node is still called for consistency
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(end_node)
# For static nodes, the global prompt would still be composed if enabled
expected_content = (
"This is the global context: Dograh Inc\n\nThank you for calling!"
)
assert system_message["content"] == expected_content
@pytest.mark.asyncio
async def test_global_prompt_variable_substitution(self, engine):
"""Test that variables in global prompt are properly substituted."""
agent2_node = engine.workflow.nodes["agent2"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(agent2_node)
# Verify variable substitution in global prompt
assert "Dograh Inc" in system_message["content"]
assert "{{company_name}}" not in system_message["content"]
# Full expected content
expected_content = (
"This is the global context: Dograh Inc\n\nPlease provide your details."
)
assert system_message["content"] == expected_content
@pytest.mark.asyncio
async def test_no_global_node_scenario(self, engine):
"""Test behavior when there's no global node in the workflow."""
# Remove global node from workflow
engine.workflow.global_node_id = None
start_node = engine.workflow.nodes["start"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(start_node)
# Should only have the node's own prompt
assert system_message["content"] == "Welcome to our service!"
@pytest.mark.asyncio
async def test_empty_global_prompt(self, engine):
"""Test behavior when global prompt is empty."""
# Set global prompt to empty string
engine.workflow.nodes["global"].prompt = ""
start_node = engine.workflow.nodes["start"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(start_node)
# Should only have the node's own prompt (empty global prompt is filtered out)
assert system_message["content"] == "Welcome to our service!"
def test_default_add_global_prompt_value(self):
"""Test that add_global_prompt defaults to True in NodeDataDTO."""
node_data = NodeDataDTO(name="Test", prompt="Test prompt")
assert node_data.add_global_prompt is True
@pytest.mark.asyncio
async def test_multiple_prompts_concatenation(self, engine):
"""Test proper concatenation of global and node prompts."""
# Test with agent2 node that has global prompt enabled
agent2_node = engine.workflow.nodes["agent2"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(agent2_node)
# Should have global and node prompts concatenated with double newlines
# (extraction prompt is no longer included in system message)
expected_parts = [
"This is the global context: Dograh Inc",
"Please provide your details.",
]
expected_content = "\n\n".join(expected_parts)
assert system_message["content"] == expected_content

View file

@ -1,175 +0,0 @@
"""Unit tests for global prompt functionality - no DB dependencies."""
import sys
from pathlib import Path
# Add the api directory to the Python path
api_path = Path(__file__).parent.parent
sys.path.insert(0, str(api_path))
from services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from services.workflow.workflow import WorkflowGraph
def test_node_data_dto_default_global_prompt():
"""Test that add_global_prompt defaults to True."""
node_data = NodeDataDTO(name="Test Node", prompt="Test prompt")
assert node_data.add_global_prompt is True
print("✓ NodeDataDTO defaults add_global_prompt to True")
def test_node_data_dto_explicit_global_prompt():
"""Test explicit setting of add_global_prompt."""
# Test with False
node_data_false = NodeDataDTO(
name="Test Node", prompt="Test prompt", add_global_prompt=False
)
assert node_data_false.add_global_prompt is False
# Test with True
node_data_true = NodeDataDTO(
name="Test Node", prompt="Test prompt", add_global_prompt=True
)
assert node_data_true.add_global_prompt is True
print("✓ NodeDataDTO respects explicit add_global_prompt values")
def test_workflow_node_inherits_global_prompt_setting():
"""Test that workflow Node inherits add_global_prompt from NodeDataDTO."""
nodes = [
RFNodeDTO(
id="start",
type=NodeType.startNode,
position={"x": 0, "y": 0},
data=NodeDataDTO(
name="Start",
prompt="Start prompt",
is_start=True,
add_global_prompt=True,
),
),
RFNodeDTO(
id="node1",
type=NodeType.agentNode,
position={"x": 100, "y": 0},
data=NodeDataDTO(
name="Node with global", prompt="Test prompt", add_global_prompt=True
),
),
RFNodeDTO(
id="node2",
type=NodeType.agentNode,
position={"x": 200, "y": 0},
data=NodeDataDTO(
name="Node without global",
prompt="Test prompt",
add_global_prompt=False,
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position={"x": 300, "y": 0},
data=NodeDataDTO(
name="End", prompt="End prompt", is_end=True, add_global_prompt=True
),
),
]
edges = [
RFEdgeDTO(
id="e1",
source="start",
target="node1",
data=EdgeDataDTO(label="Next", condition="Continue"),
),
RFEdgeDTO(
id="e2",
source="node1",
target="node2",
data=EdgeDataDTO(label="Next", condition="Continue"),
),
RFEdgeDTO(
id="e3",
source="node2",
target="end",
data=EdgeDataDTO(label="End", condition="Finish"),
),
]
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
workflow = WorkflowGraph(flow_dto)
assert workflow.nodes["start"].add_global_prompt is True
assert workflow.nodes["node1"].add_global_prompt is True
assert workflow.nodes["node2"].add_global_prompt is False
assert workflow.nodes["end"].add_global_prompt is True
print("✓ Workflow nodes correctly inherit add_global_prompt setting")
def test_compose_system_message_respects_global_prompt_flag():
"""Test that system message composition respects add_global_prompt flag."""
# This is a simplified version - in real tests we'd use the full engine
# But this demonstrates the logic
class MockNode:
def __init__(self, add_global_prompt, prompt):
self.add_global_prompt = add_global_prompt
self.prompt = prompt
self.out_edges = []
self.extraction_enabled = False
# Simulate the logic from _compose_system_message_functions_for_node
def compose_message(node, global_prompt):
prompts = []
# Only add global prompt if node.add_global_prompt is True
if global_prompt and node.add_global_prompt:
prompts.append(global_prompt)
prompts.append(node.prompt)
return "\n\n".join(p for p in prompts if p)
global_prompt = "This is the global context"
# Test with add_global_prompt=True
node_with_global = MockNode(add_global_prompt=True, prompt="Node prompt")
message_with = compose_message(node_with_global, global_prompt)
assert message_with == "This is the global context\n\nNode prompt"
# Test with add_global_prompt=False
node_without_global = MockNode(add_global_prompt=False, prompt="Node prompt")
message_without = compose_message(node_without_global, global_prompt)
assert message_without == "Node prompt"
print("✓ System message composition respects add_global_prompt flag")
def test_static_nodes_with_global_prompt():
"""Test static nodes can have add_global_prompt setting."""
static_node_data = NodeDataDTO(
name="Static Node", prompt="Static text", is_static=True, add_global_prompt=True
)
assert static_node_data.is_static is True
assert static_node_data.add_global_prompt is True
print("✓ Static nodes can have add_global_prompt setting")
if __name__ == "__main__":
# Run all tests
test_node_data_dto_default_global_prompt()
test_node_data_dto_explicit_global_prompt()
test_workflow_node_inherits_global_prompt_setting()
test_compose_system_message_respects_global_prompt_flag()
test_static_nodes_with_global_prompt()
print("\n✅ All unit tests passed!")

View file

@ -1,248 +0,0 @@
"""
Test cases for _leave_counter mechanism in transport clients.
This test suite verifies that the _leave_counter prevents premature disconnection
when both input and output transports are using the same client.
"""
import asyncio
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import EndFrame, StartFrame
from pipecat.transports.network.fastapi_websocket import (
FastAPIWebsocketCallbacks,
FastAPIWebsocketClient,
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
from pipecat.transports.network.small_webrtc import SmallWebRTCClient
from api.services.telephony.stasis_rtp_client import StasisRTPClient
class TestLeaveCounterFastAPIWebsocket:
"""Test the _leave_counter mechanism in FastAPIWebsocketClient."""
@pytest.mark.asyncio
async def test_leave_counter_prevents_early_disconnect(self):
"""Test that disconnect only happens when both transports have disconnected."""
# Create mock websocket
mock_websocket = Mock()
mock_websocket.close = AsyncMock()
# Set client_state directly to WebSocketState.CONNECTED value
from starlette.websockets import WebSocketState
mock_websocket.client_state = WebSocketState.CONNECTED
# Create callbacks
callbacks = FastAPIWebsocketCallbacks(
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
on_session_timeout=AsyncMock(),
)
# Create client
client = FastAPIWebsocketClient(
mock_websocket, is_binary=False, callbacks=callbacks
)
# Create StartFrame
start_frame = StartFrame()
# Simulate both input and output transports calling setup
await client.setup(start_frame) # Input transport
assert client._leave_counter == 1
await client.setup(start_frame) # Output transport
assert client._leave_counter == 2
# First disconnect - should not actually disconnect
await client.disconnect()
assert client._leave_counter == 1
mock_websocket.close.assert_not_called()
callbacks.on_client_disconnected.assert_not_called()
# Second disconnect - should actually disconnect
await client.disconnect()
assert client._leave_counter == 0
mock_websocket.close.assert_called_once()
callbacks.on_client_disconnected.assert_called_once()
class TestLeaveCounterStasisRTP:
"""Test the _leave_counter mechanism in StasisRTPClient."""
@pytest.mark.asyncio
async def test_leave_counter_prevents_early_disconnect(self):
"""Test that disconnect only happens when both transports have disconnected."""
# Create mock connection
mock_connection = Mock()
mock_connection.is_connected.return_value = True
mock_connection.disconnect = AsyncMock()
mock_connection.notify_sockets_closed = AsyncMock()
# Mock event_handler as a callable that acts as a decorator
def mock_event_handler(event_name):
def decorator(func):
return func
return decorator
mock_connection.event_handler = mock_event_handler
# Create callbacks
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
callbacks = StasisRTPCallbacks(
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
on_client_closed=AsyncMock(),
)
# Create client
client = StasisRTPClient(mock_connection, callbacks)
# Create StartFrame
start_frame = StartFrame()
# Simulate both input and output transports calling setup
await client.setup(start_frame) # Input transport
assert client._leave_counter == 1
await client.setup(start_frame) # Output transport
assert client._leave_counter == 2
# First disconnect - should not actually disconnect
await client.disconnect()
assert client._leave_counter == 1
mock_connection.disconnect.assert_not_called()
# Second disconnect - should actually disconnect
await client.disconnect()
assert client._leave_counter == 0
mock_connection.disconnect.assert_called_once()
class TestLeaveCounterSmallWebRTC:
"""Test the _leave_counter mechanism in SmallWebRTCClient."""
@pytest.mark.asyncio
async def test_leave_counter_prevents_early_disconnect(self):
"""Test that disconnect only happens when both transports have disconnected."""
# Create mock connection
mock_connection = Mock()
mock_connection.is_connected.return_value = True
mock_connection.disconnect = AsyncMock()
mock_connection.notify_sockets_closed = AsyncMock()
# Mock event_handler as a callable that acts as a decorator
def mock_event_handler(event_name):
def decorator(func):
return func
return decorator
mock_connection.event_handler = mock_event_handler
# Create callbacks
from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks
callbacks = SmallWebRTCCallbacks(
on_app_message=AsyncMock(),
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
)
# Create client
client = SmallWebRTCClient(mock_connection, callbacks)
# Create StartFrame with required attributes
start_frame = StartFrame()
# Create mock transport params
from pipecat.transports.base_transport import TransportParams
params = TransportParams(
audio_in_channels=1, audio_in_sample_rate=16000, audio_out_sample_rate=16000
)
# Simulate both input and output transports calling setup
await client.setup(params, start_frame) # Input transport
assert client._leave_counter == 1
await client.setup(params, start_frame) # Output transport
assert client._leave_counter == 2
# First disconnect - should not actually disconnect
await client.disconnect()
assert client._leave_counter == 1
mock_connection.disconnect.assert_not_called()
# Second disconnect - should actually disconnect
await client.disconnect()
assert client._leave_counter == 0
mock_connection.disconnect.assert_called_once()
@pytest.mark.skip(reason="Complex integration test - requires additional mocking")
@pytest.mark.asyncio
async def test_transport_lifecycle_with_leave_counter():
"""Test complete transport lifecycle with proper leave counter handling."""
# Create mock websocket
mock_websocket = Mock()
mock_websocket.close = AsyncMock()
# Set client_state directly to WebSocketState.CONNECTED value
from starlette.websockets import WebSocketState
mock_websocket.client_state = WebSocketState.CONNECTED
mock_websocket.iter_bytes = Mock(return_value=iter([]))
mock_websocket.send_bytes = AsyncMock()
# Create transport
params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True)
transport = FastAPIWebsocketTransport(mock_websocket, params)
# Get input and output transports
input_transport = transport.input()
output_transport = transport.output()
# Setup the transport with required components
from pipecat.clocks.system_clock import SystemClock
from pipecat.processors.frame_processor import FrameProcessorSetup
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
clock = SystemClock()
task_manager = TaskManager()
# Setup task manager with event loop
loop = asyncio.get_event_loop()
task_manager_params = TaskManagerParams(loop=loop)
task_manager.setup(task_manager_params)
setup = FrameProcessorSetup(clock=clock, task_manager=task_manager)
# Setup both input and output transports
await input_transport.setup(setup)
await output_transport.setup(setup)
# Start both transports
start_frame = StartFrame()
await input_transport.start(start_frame)
await output_transport.start(start_frame)
# Verify leave counter is 2
assert transport._client._leave_counter == 2
# Stop input transport
end_frame = EndFrame()
await input_transport.stop(end_frame)
# Verify websocket not closed yet
mock_websocket.close.assert_not_called()
# Stop output transport
await output_transport.stop(end_frame)
# Now websocket should be closed
mock_websocket.close.assert_called_once()

View file

@ -1,99 +0,0 @@
import unittest
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
LLMFullResponseStartFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.google.llm import (
GoogleAssistantContextAggregator,
GoogleLLMContext,
)
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
class TestReorderOpenAIAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reorder_function_messages_openai(self):
"""Ensure that after a text aggregation the function-call messages are moved
to appear immediately after the text response, maintaining chronological
order (assistant text -> function call -> tool response).
"""
context = OpenAILLMContext()
aggregator = OpenAIAssistantContextAggregator(context)
# Simulate the start of an LLM response so that the aggregator creates a
# response session ID that is later used for re-ordering.
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
# Simulate the model emitting a function call which the aggregator will
# record for potential re-ordering.
await aggregator._handle_function_call_in_progress(
FunctionCallInProgressFrame(
function_name="get_weather",
tool_call_id="1",
arguments={},
)
)
# Now push the textual part of the assistant response. This should
# trigger the re-ordering so that the two function-related messages
# appear *after* this text.
await aggregator.handle_aggregation("Hello!")
messages = context.get_messages()
# We expect exactly three messages after re-ordering.
self.assertEqual(len(messages), 3)
# 1. Assistant text
self.assertEqual(messages[0]["role"], "assistant")
self.assertEqual(messages[0]["content"], "Hello!")
# 2. Assistant function-call message
self.assertEqual(messages[1]["role"], "assistant")
self.assertIn("tool_calls", messages[1])
# 3. Tool response
self.assertEqual(messages[2]["role"], "tool")
self.assertEqual(messages[2]["tool_call_id"], "1")
class TestReorderGoogleAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reorder_function_messages_google(self):
context = GoogleLLMContext()
aggregator = GoogleAssistantContextAggregator(context)
# Start an LLM response session.
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
# Emit a function call.
await aggregator._handle_function_call_in_progress(
FunctionCallInProgressFrame(
function_name="get_weather",
tool_call_id="1",
arguments={},
)
)
# Push the textual content.
await aggregator.handle_aggregation("Hello!")
messages = context.messages # Google context stores Content objects.
self.assertEqual(len(messages), 3)
# The first message should be the model text.
first_msg = messages[0].to_json_dict()
self.assertEqual(first_msg["role"], "model")
self.assertEqual(first_msg["parts"][0]["text"], "Hello!")
# The second message contains the function call (also from the model).
second_msg = messages[1].to_json_dict()
self.assertEqual(second_msg["role"], "model")
self.assertIn("function_call", second_msg["parts"][0])
# The third message is the placeholder function response.
third_msg = messages[2].to_json_dict()
self.assertEqual(third_msg["role"], "user")
self.assertIn("function_response", third_msg["parts"][0])

View file

@ -1,506 +0,0 @@
"""
Tests for LoopTalk API routes and orchestration.
This module tests the LoopTalk testing functionality including test session creation,
pipeline orchestration, and agent-to-agent communication.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from fastapi import status
from api.db.db_client import DBClient
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
@pytest.fixture
def actor_workflow_definition():
"""Sample actor workflow definition for testing."""
return {
"nodes": [
{
"id": "1",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"prompt": "Hello, I'm the actor agent.",
"is_static": True,
"name": "Start Call",
"is_start": True,
"allow_interrupt": False,
},
},
{
"id": "2",
"type": "agentNode",
"position": {"x": 100, "y": 0},
"data": {
"prompt": "You are an actor agent testing the adversary. Ask probing questions.",
"name": "Actor Agent",
"allow_interrupt": True,
},
},
{
"id": "3",
"type": "endCall",
"position": {"x": 200, "y": 0},
"data": {
"prompt": "Goodbye!",
"name": "End Call",
"is_end": True,
},
},
],
"edges": [
{
"id": "e1",
"source": "1",
"target": "2",
"data": {"label": "Continue", "condition": "Always"},
},
{
"id": "e2",
"source": "2",
"target": "3",
"data": {"label": "End", "condition": "Always"},
},
],
"stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"},
"llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"},
"tts": {
"provider": "openai",
"api_key": "test-key",
"model": "tts-1",
"voice": "nova",
},
}
@pytest.fixture
def adversary_workflow_definition():
"""Sample adversary workflow definition for testing."""
return {
"nodes": [
{
"id": "1",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"prompt": "Hello, I'm the adversary agent.",
"is_static": True,
"name": "Start Call",
"is_start": True,
"allow_interrupt": False,
},
},
{
"id": "2",
"type": "agentNode",
"position": {"x": 100, "y": 0},
"data": {
"prompt": "You are an adversary agent being tested. Respond defensively.",
"name": "Adversary Agent",
"allow_interrupt": True,
},
},
{
"id": "3",
"type": "endCall",
"position": {"x": 200, "y": 0},
"data": {
"prompt": "Goodbye!",
"name": "End Call",
"is_end": True,
},
},
],
"edges": [
{
"id": "e1",
"source": "1",
"target": "2",
"data": {"label": "Continue", "condition": "Always"},
},
{
"id": "e2",
"source": "2",
"target": "3",
"data": {"label": "End", "condition": "Always"},
},
],
"stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"},
"llm": {
"provider": "groq",
"api_key": "test-key",
"model": "llama-3.1-70b-versatile",
},
"tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"},
}
from pipecat.processors.frame_processor import FrameProcessor
class MockSTTService(FrameProcessor):
"""Mock STT service for testing."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def run_stt(self, audio: bytes) -> str:
return "Mock transcription"
class MockLLMService(FrameProcessor):
"""Mock LLM service for testing."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def run_llm(self, messages) -> str:
return "Mock LLM response"
def create_context_aggregator(self, context):
"""Mock context aggregator creation."""
return MagicMock()
class MockTTSService(FrameProcessor):
"""Mock TTS service for testing."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def run_tts(self, text: str) -> bytes:
return b"Mock audio data"
@pytest_asyncio.fixture
async def test_user_with_org(db_session):
"""Create a test user with an organization set up."""
user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user")
org, _ = await db_session.get_or_create_organization_by_provider_id(
"test_looptalk_org"
)
user_id = user.id
org_id = org.id
await db_session.add_user_to_organization(user_id, org_id)
# Update user's selected organization
async with db_session.async_session() as session:
from sqlalchemy import update
from api.db.models import UserModel
await session.execute(
update(UserModel)
.where(UserModel.id == user_id)
.values(selected_organization_id=org_id)
)
await session.commit()
# Return fresh user object
return await db_session.get_user_by_id(user_id)
@pytest.mark.asyncio
async def test_create_test_session(
test_client_factory,
db_session,
test_user_with_org,
actor_workflow_definition,
adversary_workflow_definition,
):
"""Test creating a new LoopTalk test session."""
async with test_client_factory(test_user_with_org) as test_client:
# First create two workflows
actor_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Actor Workflow",
"workflow_definition": actor_workflow_definition,
},
)
assert actor_workflow_response.status_code == status.HTTP_200_OK
actor_workflow_id = actor_workflow_response.json()["id"]
adversary_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Adversary Workflow",
"workflow_definition": adversary_workflow_definition,
},
)
assert adversary_workflow_response.status_code == status.HTTP_200_OK
adversary_workflow_id = adversary_workflow_response.json()["id"]
# Create test session
response = await test_client.post(
"/api/v1/looptalk/test-sessions",
json={
"name": "Test Session 1",
"actor_workflow_id": actor_workflow_id,
"adversary_workflow_id": adversary_workflow_id,
"config": {"test_duration": 60},
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["name"] == "Test Session 1"
assert data["status"] == "pending"
assert data["actor_workflow_id"] == actor_workflow_id
assert data["adversary_workflow_id"] == adversary_workflow_id
assert data["config"]["test_duration"] == 60
@pytest.mark.asyncio
async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org):
"""Test listing LoopTalk test sessions."""
async with test_client_factory(test_user_with_org) as test_client:
response = await test_client.get(
"/api/v1/looptalk/test-sessions",
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
@pytest.mark.asyncio
async def test_looptalk_orchestrator_plumbing(
db_session: DBClient, actor_workflow_definition, adversary_workflow_definition
):
"""Test the LoopTalk orchestrator plumbing with mocked services."""
# Create test user and organization
user = await db_session.get_or_create_user_by_provider_id(
provider_id="test-user-123"
)
org, _ = await db_session.get_or_create_organization_by_provider_id(
org_provider_id="test-org-123"
)
# Get IDs before session closes
user_id = user.id
org_id = org.id
await db_session.add_user_to_organization(user_id, org_id)
# Update user's selected organization manually
async with db_session.async_session() as session:
from sqlalchemy import update
from api.db.models import UserModel
await session.execute(
update(UserModel)
.where(UserModel.id == user_id)
.values(selected_organization_id=org_id)
)
await session.commit()
actor_workflow = await db_session.create_workflow(
name="Actor Workflow",
workflow_definition=actor_workflow_definition,
user_id=user_id,
)
adversary_workflow = await db_session.create_workflow(
name="Adversary Workflow",
workflow_definition=adversary_workflow_definition,
user_id=user_id,
)
# Create test session
test_session = await db_session.create_test_session(
organization_id=org_id,
name="Test Session",
actor_workflow_id=actor_workflow.id,
adversary_workflow_id=adversary_workflow.id,
config={"test_duration": 10},
)
# Mock the service factories - patch at the actual import location in pipeline_builder
with (
patch(
"api.services.looptalk.core.pipeline_builder.create_stt_service"
) as mock_stt_factory,
patch(
"api.services.looptalk.core.pipeline_builder.create_llm_service"
) as mock_llm_factory,
patch(
"api.services.looptalk.core.pipeline_builder.create_tts_service"
) as mock_tts_factory,
patch(
"api.services.workflow.pipecat_engine.PipecatEngine"
) as mock_engine_class,
patch(
"api.services.pipecat.pipeline_builder.build_pipeline"
) as mock_build_pipeline,
patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class,
):
# Configure mocks
mock_stt_factory.return_value = MockSTTService()
mock_llm_factory.return_value = MockLLMService()
mock_tts_factory.return_value = MockTTSService()
mock_engine = MagicMock()
mock_engine.initialize = AsyncMock()
mock_engine.get_callback_processor = MagicMock(return_value=MagicMock())
mock_engine_class.return_value = mock_engine
# Mock pipeline and task
mock_pipeline = MagicMock()
mock_task = MagicMock()
mock_task.run = AsyncMock()
mock_task.cancel = AsyncMock() # Make cancel async
mock_build_pipeline.return_value = mock_pipeline
mock_task_class.return_value = mock_task
# Create orchestrator
orchestrator = LoopTalkTestOrchestrator(db_client=db_session)
# Start test session (in a separate task to avoid blocking)
start_task = asyncio.create_task(
orchestrator.start_test_session(
test_session_id=test_session.id, organization_id=org_id
)
)
# Give it a moment to start
await asyncio.sleep(0.5)
# Verify the session is running through session manager
session_info = orchestrator.session_manager.get_session(test_session.id)
assert session_info is not None
assert session_info["test_session"].id == test_session.id
assert "actor_task" in session_info
assert "adversary_task" in session_info
# Verify service factories were called
assert mock_stt_factory.call_count == 2 # Once for each agent
assert mock_llm_factory.call_count == 2
assert mock_tts_factory.call_count == 2
# Verify pipelines were created with PipelineTask
assert mock_task_class.call_count == 2
# Stop the test session
await orchestrator.stop_test_session(test_session_id=test_session.id)
# Verify session was cleaned up
assert orchestrator.session_manager.get_session(test_session.id) is None
# Cancel the start task
start_task.cancel()
try:
await start_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_load_test_creation(
test_client_factory,
db_session,
test_user_with_org,
actor_workflow_definition,
adversary_workflow_definition,
):
"""Test creating a load test with multiple sessions."""
async with test_client_factory(test_user_with_org) as test_client:
# First create two workflows
actor_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Actor Workflow",
"workflow_definition": actor_workflow_definition,
},
)
actor_workflow_id = actor_workflow_response.json()["id"]
adversary_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Adversary Workflow",
"workflow_definition": adversary_workflow_definition,
},
)
adversary_workflow_id = adversary_workflow_response.json()["id"]
# Create load test
response = await test_client.post(
"/api/v1/looptalk/load-tests",
json={
"name_prefix": "Load Test",
"actor_workflow_id": actor_workflow_id,
"adversary_workflow_id": adversary_workflow_id,
"test_count": 3,
"config": {"test_duration": 30},
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 3
assert "load_test_group_id" in data
assert len(data["test_session_ids"]) == 3
@pytest.mark.asyncio
async def test_invalid_workflow_ids(
test_client_factory, db_session, test_user_with_org
):
"""Test creating test session with invalid workflow IDs."""
async with test_client_factory(test_user_with_org) as test_client:
response = await test_client.post(
"/api/v1/looptalk/test-sessions",
json={
"name": "Invalid Test",
"actor_workflow_id": 99999,
"adversary_workflow_id": 99999,
"config": {},
},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "workflow not found" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_transport_manager():
"""Test the internal transport manager functionality."""
from pipecat.transports import InternalTransportManager, TransportParams
manager = InternalTransportManager()
# Create transport pair
params = TransportParams(
audio_out_enabled=True,
audio_in_enabled=True,
audio_out_sample_rate=16000,
audio_in_sample_rate=16000,
)
actor_transport, adversary_transport = manager.create_transport_pair(
test_session_id="test-123", actor_params=params, adversary_params=params
)
# Verify transports are connected
assert actor_transport._output._partner == adversary_transport._input
assert adversary_transport._output._partner == actor_transport._input
# Verify transport pair is tracked
assert manager.get_active_test_count() == 1
assert manager.get_transport_pair("test-123") is not None
# Remove transport pair
manager.remove_transport_pair("test-123")
assert manager.get_active_test_count() == 0
assert manager.get_transport_pair("test-123") is None

View file

@ -1,142 +0,0 @@
### - The test gets stuck. Need to figure out a way to run the test
# import asyncio
# import unittest
# from loguru import logger
# from pipecat.frames.frames import (
# FunctionCallFromLLM,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# FunctionCallsStartedFrame,
# LLMFullResponseEndFrame,
# LLMFullResponseStartFrame,
# LLMTextFrame,
# )
# from pipecat.processors.aggregators.openai_llm_context import (
# OpenAILLMContext,
# OpenAILLMContextFrame,
# )
# from pipecat.processors.frame_processor import FrameDirection
# from pipecat.services.llm_service import (
# FunctionCallParams,
# FunctionCallResultProperties,
# LLMService,
# )
# from pipecat.tests.utils import run_test
# class MockLLMService(LLMService):
# """A very small mocked LLM service that, upon receiving an
# ``OpenAILLMContextFrame``, streams a text completion followed by the
# execution of the supplied tools (function calls).
# """
# def __init__(self, *, content: str, tools: list[dict[str, dict]], **kwargs):
# # Run function calls sequentially so that frame ordering is deterministic.
# super().__init__(run_in_parallel=False, **kwargs)
# self._content = content
# self._tools = tools
# async def process_frame(self, frame, direction: FrameDirection):
# await super().process_frame(frame, direction)
# if isinstance(frame, OpenAILLMContextFrame) and direction == FrameDirection.DOWNSTREAM:
# # Simulate the start of a streamed completion.
# await self.push_frame(LLMFullResponseStartFrame())
# await self.push_frame(LLMTextFrame(self._content))
# # Convert tool specs into FunctionCallFromLLM objects.
# function_calls = []
# for idx, tool in enumerate(self._tools):
# function_calls.append(
# FunctionCallFromLLM(
# function_name=tool["function_name"],
# tool_call_id=f"tool_{idx}",
# arguments=tool.get("arguments", {}),
# context=frame.context,
# )
# )
# # Ask the LLM service base class to execute the calls.
# await self.run_function_calls(function_calls)
# # Finish the streamed response.
# await self.push_frame(LLMFullResponseEndFrame())
# async def _run_function_call(self, runner_item): # type: ignore[override] narrow signature
# # Ensure run_llm=True so that downstream processors know they can
# # immediately trigger another LLM call after the result is committed.
# runner_item.run_llm = True
# await super()._run_function_call(runner_item)
# class TestMockLLMPipeline(unittest.IsolatedAsyncioTestCase):
# async def test_mock_llm_pipeline_with_tools(self):
# # ------------------------------------------------------------------
# # 1. Create mocked LLM service with completion text and tools
# # ------------------------------------------------------------------
# completion_text = "Hello from mocked LLM!"
# tools = [
# {"function_name": "tool_one", "arguments": {"a": 1}},
# {"function_name": "tool_two", "arguments": {"b": 2}},
# ]
# llm = MockLLMService(content=completion_text, tools=tools)
# # ------------------------------------------------------------------
# # 2. Register the tool functions they simply log & sleep briefly.
# # Each of them marks that it has run so that we can assert later.
# # ------------------------------------------------------------------
# executed: dict[str, bool] = {t["function_name"]: False for t in tools}
# def make_handler(name: str):
# async def _handler(params: FunctionCallParams):
# logger.debug(f"Executing {name} with args {params.arguments}")
# executed[name] = True
# await asyncio.sleep(0.01)
# await params.result_callback(
# {"status": "ok"},
# properties=FunctionCallResultProperties(run_llm=True),
# )
# return _handler
# for t in tools:
# llm.register_function(t["function_name"], make_handler(t["function_name"]))
# # ------------------------------------------------------------------
# # 3. Build the pipeline and send the initial context frame that
# # triggers the completion.
# # ------------------------------------------------------------------
# context = OpenAILLMContext()
# context.add_message({"role": "user", "content": "Hi!"})
# frames_to_send = [OpenAILLMContextFrame(context)]
# expected_down_frames = [
# LLMFullResponseStartFrame,
# LLMTextFrame,
# FunctionCallsStartedFrame,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# LLMFullResponseEndFrame,
# ]
# # Run the test pipeline.
# received_down_frames, _ = await run_test(
# llm,
# frames_to_send=frames_to_send,
# expected_down_frames=expected_down_frames,
# )
# # ------------------------------------------------------------------
# # 4. Verify that both tool functions executed and that run_llm=True
# # in all FunctionCallResultFrame instances.
# # ------------------------------------------------------------------
# self.assertTrue(all(executed.values()))
# for frame in received_down_frames:
# if isinstance(frame, FunctionCallResultFrame):
# self.assertTrue(frame.run_llm)

View file

@ -1,236 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
def create_disposition_mapping_side_effect(mapping_dict):
"""Helper to create a side effect function for disposition mapping."""
async def side_effect(value, org_id):
return mapping_dict.get(value, value)
return side_effect
@pytest.fixture
def mock_dependencies():
"""Create mock dependencies for PipecatEngine."""
mock_task = MagicMock()
mock_task.queue_frame = AsyncMock()
mock_llm = MagicMock()
mock_context = MagicMock()
mock_workflow = MagicMock()
return {
"task": mock_task,
"llm": mock_llm,
"context": mock_context,
"workflow": mock_workflow,
"call_context_vars": {},
"workflow_run_id": 123,
}
@pytest.mark.asyncio
async def test_apply_disposition_mapping_with_call_disposition(mock_dependencies):
"""Test disposition mapping when call_disposition is present."""
engine = PipecatEngine(**mock_dependencies)
# Setup gathered context
engine._gathered_context = {
"call_disposition": "XFER",
"agent_name": "Alex",
"total_debt": "$15000",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
{
"XFER": "TRANSFERRED",
"ND": "NOT_QUALIFIED",
}
)
# Call send_end_task_frame
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with mapped values
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check metadata contains mapped values
assert frame.metadata["reason"] == "user_qualified" # No mapping for this
assert (
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
)
# Check gathered context was updated
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
@pytest.mark.asyncio
async def test_apply_disposition_mapping_with_disconnect_reason(mock_dependencies):
"""Test disposition mapping for disconnect_reason when no call_disposition exists."""
engine = PipecatEngine(**mock_dependencies)
# Setup gathered context without call_disposition
engine._gathered_context = {
"agent_name": "Alex",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
{
"user_qualified": "QUALIFIED",
"user_disqualified": "NOT_QUALIFIED",
"user_hangup": "HANGUP",
}
)
# Call send_end_task_frame with a mappable reason
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with mapped disposition
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check metadata contains original reason
assert frame.metadata["reason"] == "user_qualified"
# Check call_transfer_context has mapped disconnect_reason as disposition
assert frame.metadata["call_transfer_context"]["disposition"] == "QUALIFIED"
# Check gathered context was updated with mapped call_disposition
assert engine._gathered_context["call_disposition"] == "QUALIFIED"
# Check internal call_disposition stores mapped value
assert engine._call_disposition == "QUALIFIED"
@pytest.mark.asyncio
async def test_call_disposition_takes_precedence(mock_dependencies):
"""Test that call_disposition is used when both call_disposition and reason could be mapped."""
engine = PipecatEngine(**mock_dependencies)
# Setup gathered context with call_disposition
engine._gathered_context = {
"call_disposition": "XFER",
"agent_name": "Alex",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
{
"XFER": "TRANSFERRED",
"user_qualified": "QUALIFIED",
}
)
# Call send_end_task_frame with a reason that could also be mapped
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check that call_disposition mapping was used, not reason mapping
assert (
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
)
# Check only call_disposition was updated in gathered context
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
assert "disconnect_reason" not in engine._gathered_context
@pytest.mark.asyncio
async def test_disposition_mapping_no_organization_id(mock_dependencies):
"""Test when organization_id cannot be retrieved."""
# Set workflow_run_id to None
mock_dependencies["workflow_run_id"] = None
engine = PipecatEngine(**mock_dependencies)
engine._gathered_context = {
"call_disposition": "XFER",
}
# Call send_end_task_frame
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with original values (no mapping)
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check values remain unchanged
assert frame.metadata["reason"] == "user_qualified"
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
# Gathered context should remain unchanged
assert engine._gathered_context["call_disposition"] == "XFER"
@pytest.mark.asyncio
async def test_disposition_mapping_no_configuration(mock_dependencies):
"""Test when no disposition mapping is configured."""
engine = PipecatEngine(**mock_dependencies)
engine._gathered_context = {
"call_disposition": "XFER",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock no disposition mapping (return original value)
mock_apply_mapping.side_effect = lambda value, org_id: value
# Call send_end_task_frame
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with original values
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check values remain unchanged
assert frame.metadata["reason"] == "user_qualified"
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"

View file

@ -1,206 +0,0 @@
from unittest.mock import Mock
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
class TestPipecatEngine:
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies for PipecatEngine initialization."""
return {
"task": Mock(),
"llm": Mock(),
"context": Mock(),
"tts": Mock(),
"transport": Mock(),
"workflow": Mock(spec=WorkflowGraph),
"call_context_vars": {},
}
@pytest.fixture
def engine_with_context(self, mock_dependencies):
"""Create a PipecatEngine instance with test context variables."""
context_vars = {
"first_name": "John",
"last_name": "Doe",
"age": 25,
"email": "john.doe@example.com",
"empty_var": "",
"zero_var": 0,
"false_var": False,
}
mock_dependencies["call_context_vars"] = context_vars
return PipecatEngine(**mock_dependencies)
@pytest.fixture
def engine_empty_context(self, mock_dependencies):
"""Create a PipecatEngine instance with empty context variables."""
mock_dependencies["call_context_vars"] = {}
return PipecatEngine(**mock_dependencies)
def test_format_prompt_simple_variable_replacement(self, engine_with_context):
"""Test simple variable replacement without filters."""
prompt = "Hello {{ first_name }}, welcome!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, welcome!"
def test_format_prompt_multiple_variables(self, engine_with_context):
"""Test multiple variable replacements in a single prompt."""
prompt = "Hello {{ first_name }} {{ last_name }}, you are {{ age }} years old."
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John Doe, you are 25 years old."
def test_format_prompt_with_fallback_existing_value(self, engine_with_context):
"""Test fallback filter when value exists."""
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, nice to meet you!"
def test_format_prompt_with_fallback_missing_value(self, engine_empty_context):
"""Test fallback filter when value is missing."""
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
result = engine_empty_context._format_prompt(prompt)
assert result == "Hello First_Name, nice to meet you!"
def test_format_prompt_with_custom_fallback_missing_value(
self, engine_empty_context
):
"""Test fallback filter with custom fallback value when variable is missing."""
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
result = engine_empty_context._format_prompt(prompt)
assert result == "Hello Guest, welcome!"
def test_format_prompt_with_custom_fallback_existing_value(
self, engine_with_context
):
"""Test fallback filter with custom fallback value when variable exists."""
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, welcome!"
def test_format_prompt_empty_string_variable(self, engine_with_context):
"""Test variable with empty string value."""
prompt = "Value: '{{ empty_var | fallback:No Value }}'"
result = engine_with_context._format_prompt(prompt)
assert result == "Value: 'No Value'"
def test_format_prompt_zero_value(self, engine_with_context):
"""Test variable with zero value (should not trigger fallback)."""
prompt = "Count: {{ zero_var | fallback:None }}"
result = engine_with_context._format_prompt(prompt)
assert result == "Count: 0"
def test_format_prompt_false_value(self, engine_with_context):
"""Test variable with False value (should not trigger fallback)."""
prompt = "Status: {{ false_var | fallback:Unknown }}"
result = engine_with_context._format_prompt(prompt)
assert result == "Status: False"
def test_format_prompt_missing_variable_no_fallback(self, engine_empty_context):
"""Test missing variable without fallback filter."""
prompt = "Hello {{ missing_var }}, welcome!"
result = engine_empty_context._format_prompt(prompt)
assert result == "Hello , welcome!"
def test_format_prompt_complex_mixed_scenario(self, engine_with_context):
"""Test complex scenario with multiple variables, some with fallbacks."""
prompt = (
"Dear {{ first_name | fallback:Customer }}, "
"your email {{ email }} is confirmed. "
"{{ missing_info | fallback:Additional information }} will be sent later. "
"You are {{ age }} years old."
)
result = engine_with_context._format_prompt(prompt)
expected = (
"Dear John, "
"your email john.doe@example.com is confirmed. "
"Additional information will be sent later. "
"You are 25 years old."
)
assert result == expected
def test_format_prompt_whitespace_handling(self, engine_with_context):
"""Test handling of whitespace in template variables."""
prompt = "Hello {{ first_name | fallback : Default }}, welcome!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, welcome!"
def test_format_prompt_no_variables(self, engine_with_context):
"""Test prompt with no template variables."""
prompt = "This is a regular prompt with no variables."
result = engine_with_context._format_prompt(prompt)
assert result == "This is a regular prompt with no variables."
def test_format_prompt_empty_prompt(self, engine_with_context):
"""Test empty prompt."""
prompt = ""
result = engine_with_context._format_prompt(prompt)
assert result == ""
def test_format_prompt_none_prompt(self, engine_with_context):
"""Test None prompt."""
prompt = None
result = engine_with_context._format_prompt(prompt)
assert result is None
def test_format_prompt_nested_braces(self, engine_with_context):
"""Test handling of nested or malformed braces."""
prompt = "Hello {{ first_name }}, this {is not a template} variable."
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, this {is not a template} variable."
def test_format_prompt_special_characters_in_value(self):
"""Test variables containing special characters."""
mock_deps = {
"task": Mock(),
"llm": Mock(),
"context": Mock(),
"tts": Mock(),
"transport": Mock(),
"workflow": Mock(spec=WorkflowGraph),
"call_context_vars": {
"special_name": "John & Jane's Company",
"email": "test@domain.com",
},
}
engine = PipecatEngine(**mock_deps)
prompt = "Company: {{ special_name }}, Contact: {{ email }}"
result = engine._format_prompt(prompt)
assert result == "Company: John & Jane's Company, Contact: test@domain.com"
def test_format_prompt_numeric_and_boolean_conversion(self):
"""Test conversion of different data types to strings."""
mock_deps = {
"task": Mock(),
"llm": Mock(),
"context": Mock(),
"tts": Mock(),
"transport": Mock(),
"workflow": Mock(spec=WorkflowGraph),
"call_context_vars": {
"count": 42,
"price": 99.99,
"is_active": True,
"items": ["apple", "banana"],
},
}
engine = PipecatEngine(**mock_deps)
prompt = "Count: {{ count }}, Price: ${{ price }}, Active: {{ is_active }}, Items: {{ items }}"
result = engine._format_prompt(prompt)
assert (
result
== "Count: 42, Price: $99.99, Active: True, Items: ['apple', 'banana']"
)
def test_format_prompt_case_sensitivity(self, engine_with_context):
"""Test that variable names are case sensitive."""
prompt = (
"Hello {{ First_Name | fallback }}, welcome!" # Note the capitalization
)
result = engine_with_context._format_prompt(prompt)
assert result == "Hello First_Name, welcome!" # Should use fallback

View file

@ -1,295 +0,0 @@
"""
Test scenarios for provider switching and billing integrity.
This test suite validates that the multi-provider telephony system
handles provider switches correctly without losing billing data.
"""
import asyncio
# Test scenarios to validate
async def test_scenario_1_mid_call_provider_switch():
"""
Test: What happens if provider is switched while a call is active?
Expected behavior:
- Active call continues with original provider
- Call is billed to original provider
- New calls use new provider
"""
print("Test 1: Mid-call provider switching")
# Simulate workflow run with Twilio
twilio_run = {
"id": 1,
"mode": "twilio",
"cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"},
"is_completed": False,
}
# Provider switch happens here (in real scenario, user changes config)
# But the call continues...
# When cost calculation runs, it should:
# 1. Use the provider stored in cost_info
# 2. Fetch cost from Twilio using twilio_call_sid
# 3. Store cost with provider attribution
result = {
"test": "mid_call_switch",
"status": "PASS",
"reason": "Call continues with original provider, billing intact",
}
print(f"{result['reason']}")
return result
async def test_scenario_2_pending_cost_calculation():
"""
Test: Calls that ended but cost not yet calculated when provider switches.
Expected behavior:
- Background job should use the provider info stored in cost_info
- Cost should be fetched from correct provider
"""
print("\nTest 2: Pending cost calculation during switch")
# Workflow runs that ended but cost job hasn't run yet
pending_runs = [
{
"id": 2,
"mode": "twilio",
"cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"},
"is_completed": True,
},
{
"id": 3,
"mode": "vonage",
"cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"},
"is_completed": True,
},
]
# Provider switch happens here
# Cost calculation jobs run after switch
# Each job should:
# 1. Check the provider field in cost_info
# 2. Use appropriate provider API to fetch cost
# 3. Handle gracefully if credentials changed
result = {
"test": "pending_cost_calculation",
"status": "PASS",
"reason": "Cost jobs use stored provider info correctly",
}
print(f"{result['reason']}")
return result
async def test_scenario_3_mixed_provider_history():
"""
Test: Organization has calls from both Twilio and Vonage.
Expected behavior:
- Historical costs remain intact
- Reports show correct attribution
- Total costs aggregate correctly
"""
print("\nTest 3: Mixed provider history")
historical_runs = [
{"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"},
{"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"},
{"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"},
{"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"},
]
# Calculate totals
total_cost = sum(run["cost_usd"] for run in historical_runs)
twilio_cost = sum(
run["cost_usd"] for run in historical_runs if run["provider"] == "twilio"
)
vonage_cost = sum(
run["cost_usd"] for run in historical_runs if run["provider"] == "vonage"
)
result = {
"test": "mixed_provider_history",
"status": "PASS",
"total_cost": total_cost,
"twilio_cost": twilio_cost,
"vonage_cost": vonage_cost,
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})",
}
print(f"{result['reason']}")
return result
async def test_scenario_4_cost_api_failure():
"""
Test: Provider API fails when fetching cost.
Expected behavior:
- Error logged but system continues
- Call record preserved
- Cost marked as 0 or unknown
"""
print("\nTest 4: Cost API failure handling")
# Simulate API failure scenarios
failure_scenarios = [
{
"provider": "twilio",
"error": "401 Unauthorized - credentials changed",
"expected": "Cost set to 0, error logged",
},
{
"provider": "vonage",
"error": "404 Not Found - call record deleted",
"expected": "Cost set to 0, error logged",
},
{
"provider": "twilio",
"error": "500 Internal Server Error",
"expected": "Cost set to 0, retry possible",
},
]
for scenario in failure_scenarios:
print(f" - {scenario['provider']}: {scenario['error']}")
print(f" Expected: {scenario['expected']}")
result = {
"test": "cost_api_failure",
"status": "PASS",
"reason": "All failure scenarios handled gracefully",
}
print(f"{result['reason']}")
return result
async def test_scenario_5_configuration_migration():
"""
Test: Database migration from single to multi-provider format.
Expected behavior:
- Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION
- Single provider config wrapped in multi-provider structure
- Existing cost_info gets provider field added
"""
print("\nTest 5: Configuration migration")
# Old format
old_config = {
"account_sid": "AC123",
"auth_token": "token123",
"from_numbers": ["+1234567890"],
"provider": "twilio",
}
# New format after migration
new_config = {
"active_provider": "twilio",
"providers": {
"twilio": {
"account_sid": "AC123",
"auth_token": "token123",
"from_numbers": ["+1234567890"],
}
},
}
# Validate migration
assert new_config["active_provider"] == "twilio"
assert "providers" in new_config
assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"]
result = {
"test": "configuration_migration",
"status": "PASS",
"reason": "Configuration migrated to multi-provider format correctly",
}
print(f"{result['reason']}")
return result
async def test_scenario_6_provider_cost_discrepancy():
"""
Test: Webhook cost vs API cost discrepancy.
Expected behavior:
- Webhook cost stored immediately if available
- API cost fetched later for verification
- Both costs stored for auditing
"""
print("\nTest 6: Provider cost discrepancy handling")
# Vonage webhook provides immediate cost
webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120}
# API call provides authoritative cost
api_cost = {
"cost_usd": 0.14, # Slight difference
"duration": 120,
}
# Both should be stored
final_cost_info = {
**webhook_cost,
"cost_breakdown": {"telephony_call": api_cost["cost_usd"]},
"provider": "vonage",
}
result = {
"test": "cost_discrepancy",
"status": "PASS",
"reason": "Both webhook and API costs stored for auditing",
}
print(f"{result['reason']}")
return result
async def run_all_tests():
"""Run all test scenarios."""
print("=" * 60)
print("PROVIDER SWITCHING TEST SUITE")
print("=" * 60)
tests = [
test_scenario_1_mid_call_provider_switch,
test_scenario_2_pending_cost_calculation,
test_scenario_3_mixed_provider_history,
test_scenario_4_cost_api_failure,
test_scenario_5_configuration_migration,
test_scenario_6_provider_cost_discrepancy,
]
results = []
for test in tests:
result = await test()
results.append(result)
print("\n" + "=" * 60)
print("TEST SUMMARY")
print("=" * 60)
passed = sum(1 for r in results if r["status"] == "PASS")
failed = sum(1 for r in results if r["status"] == "FAIL")
print(f"Total Tests: {len(results)}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if failed == 0:
print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!")
else:
print("\n❌ Some tests failed - Review the implementation")
return results
if __name__ == "__main__":
# Run the test suite
asyncio.run(run_all_tests())

View file

@ -1,266 +0,0 @@
"""Tests for run_integrations with new DB client methods."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.enums import WorkflowRunMode
from api.tasks.run_integrations import run_integrations_post_workflow_run
@pytest.fixture(autouse=True)
def mock_logger():
"""Mock the logger for all tests."""
with patch("api.tasks.run_integrations.logger") as mock_logger:
mock_logger.bind.return_value = mock_logger
yield mock_logger
@pytest.fixture
def mock_workflow_run():
"""Create a mock workflow run with all required attributes."""
workflow_run = MagicMock()
workflow_run.id = 1
workflow_run.mode = "browser"
workflow_run.gathered_context = {
"call_disposition": "XFER",
"mapped_call_disposition": "XFER", # Required for Slack integration
"call_duration": "120",
"agent_name": "TestAgent",
}
workflow_run.initial_context = {"vendor_id": "123"}
# Setup workflow and user chain
workflow_run.workflow = MagicMock()
workflow_run.workflow.user = MagicMock()
workflow_run.workflow.user.selected_organization_id = 100
return workflow_run
@pytest.fixture
def mock_integration():
"""Create a mock integration."""
integration = MagicMock()
integration.id = 1
integration.organisation_id = 100
integration.provider = "slack"
integration.is_active = True
integration.connection_details = {
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
}
return integration
@pytest.mark.asyncio
async def test_run_integrations_with_db_client_methods(
mock_workflow_run, mock_integration
):
"""Test that run_integrations uses the new DB client methods correctly."""
with patch("api.tasks.run_integrations.set_current_run_id") as mock_set_run_id:
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock the new DB client methods
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
mock_db_client.get_active_integrations_by_organization = AsyncMock(
return_value=[mock_integration]
)
mock_db_client.get_configuration_value = AsyncMock(
return_value={
"slack": {
"DISPOSITION_CODE": "Disposition: {{mapped_call_disposition}}"
}
}
)
# Mock the aiohttp session for Slack webhook
with patch(
"api.tasks.run_integrations.aiohttp.ClientSession"
) as mock_session_class:
mock_response = MagicMock()
mock_response.status = 200
mock_session = MagicMock()
mock_session.__aenter__.return_value = mock_session
mock_session.__aexit__.return_value = AsyncMock()
mock_post = MagicMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = AsyncMock()
mock_session.post.return_value = mock_post
mock_session_class.return_value = mock_session
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify the correct DB client methods were called
mock_set_run_id.assert_called_once_with(1)
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
mock_db_client.get_active_integrations_by_organization.assert_called_once_with(
100
)
# Verify the Slack webhook was called
mock_session.post.assert_called_once()
assert (
mock_session.post.call_args[0][0] == "https://hooks.slack.com/test"
)
@pytest.mark.asyncio
async def test_run_integrations_no_workflow_run():
"""Test handling when workflow run is not found."""
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock workflow run not found
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(None, None)
)
# Call the function
await run_integrations_post_workflow_run(None, 999)
# Verify it returns early and doesn't call other DB methods
mock_db_client.get_workflow_run_with_context.assert_called_once_with(999)
mock_db_client.get_active_integrations_by_organization.assert_not_called()
@pytest.mark.asyncio
async def test_run_integrations_no_organization():
"""Test handling when user has no organization."""
mock_workflow_run = MagicMock()
mock_workflow_run.id = 1
mock_workflow_run.gathered_context = {"test": "data"}
mock_workflow_run.workflow = MagicMock()
mock_workflow_run.workflow.user = MagicMock()
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock workflow run found but no organization
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, None)
)
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify it returns early after checking organization
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
mock_db_client.get_active_integrations_by_organization.assert_not_called()
@pytest.mark.asyncio
async def test_run_integrations_no_gathered_context(mock_workflow_run):
"""Test handling when workflow run has no gathered context."""
mock_workflow_run.gathered_context = None
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock workflow run with no gathered context
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify it returns early after checking gathered_context
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
mock_db_client.get_active_integrations_by_organization.assert_not_called()
@pytest.mark.asyncio
async def test_run_integrations_stasis_mode(mock_workflow_run):
"""Test that stasis mode triggers vendor sync."""
mock_workflow_run.mode = WorkflowRunMode.STASIS.value
mock_workflow_run.initial_context = {
"vendor": "test_vendor",
"vendor_base_url": "https://api.vendor.com",
"vendor_id": "123",
}
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
with patch("api.tasks.run_integrations._sync_vendor_data") as mock_sync:
mock_sync.return_value = None
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
mock_db_client.get_active_integrations_by_organization = AsyncMock(
return_value=[]
)
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify vendor sync was called
mock_sync.assert_called_once_with(
mock_workflow_run.initial_context,
mock_workflow_run.gathered_context,
)
@pytest.mark.asyncio
async def test_run_integrations_multiple_integrations(mock_workflow_run):
"""Test processing multiple integrations."""
# Create multiple mock integrations
slack_integration = MagicMock()
slack_integration.provider = "slack"
slack_integration.connection_details = {
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test1"}
}
slack_integration2 = MagicMock()
slack_integration2.provider = "slack"
slack_integration2.connection_details = {
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test2"}
}
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
mock_db_client.get_active_integrations_by_organization = AsyncMock(
return_value=[slack_integration, slack_integration2]
)
mock_db_client.get_configuration_value = AsyncMock(
return_value={"slack": {"DISPOSITION_CODE": "Test message"}}
)
with patch(
"api.tasks.run_integrations.aiohttp.ClientSession"
) as mock_session_class:
mock_response = MagicMock()
mock_response.status = 200
mock_session = MagicMock()
mock_session.__aenter__.return_value = mock_session
mock_session.__aexit__.return_value = AsyncMock()
mock_post = MagicMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = AsyncMock()
mock_session.post.return_value = mock_post
mock_session_class.return_value = mock_session
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify both integrations were processed
assert mock_session.post.call_count == 2
# Check that both webhooks were called
call_urls = [call[0][0] for call in mock_session.post.call_args_list]
assert "https://hooks.slack.com/test1" in call_urls
assert "https://hooks.slack.com/test2" in call_urls

View file

@ -1,330 +0,0 @@
"""Tests for webhook execution in run_integrations.py."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.tasks.run_integrations import (
_build_auth_header,
_build_render_context,
_execute_webhook_node,
)
@pytest.fixture(autouse=True)
def mock_logger():
"""Mock the logger for all tests."""
with patch("api.tasks.run_integrations.logger") as mock_log:
mock_log.bind.return_value = mock_log
yield mock_log
class TestBuildAuthHeader:
"""Tests for _build_auth_header function."""
def test_bearer_token(self):
"""Test bearer token auth header."""
credential = MagicMock()
credential.credential_type = "bearer_token"
credential.credential_data = {"token": "my-secret-token"}
result = _build_auth_header(credential)
assert result == {"Authorization": "Bearer my-secret-token"}
def test_api_key(self):
"""Test API key auth header."""
credential = MagicMock()
credential.credential_type = "api_key"
credential.credential_data = {"header_name": "X-API-Key", "api_key": "key123"}
result = _build_auth_header(credential)
assert result == {"X-API-Key": "key123"}
def test_api_key_default_header(self):
"""Test API key with default header name."""
credential = MagicMock()
credential.credential_type = "api_key"
credential.credential_data = {"api_key": "key123"}
result = _build_auth_header(credential)
assert result == {"X-API-Key": "key123"}
def test_basic_auth(self):
"""Test basic auth header."""
credential = MagicMock()
credential.credential_type = "basic_auth"
credential.credential_data = {"username": "user", "password": "pass"}
result = _build_auth_header(credential)
# base64 of "user:pass" is "dXNlcjpwYXNz"
assert result == {"Authorization": "Basic dXNlcjpwYXNz"}
def test_custom_header(self):
"""Test custom header auth."""
credential = MagicMock()
credential.credential_type = "custom_header"
credential.credential_data = {
"header_name": "X-Custom-Auth",
"header_value": "custom-value",
}
result = _build_auth_header(credential)
assert result == {"X-Custom-Auth": "custom-value"}
def test_unknown_type(self):
"""Test unknown credential type returns empty dict."""
credential = MagicMock()
credential.credential_type = "unknown"
credential.credential_data = {}
result = _build_auth_header(credential)
assert result == {}
class TestBuildRenderContext:
"""Tests for _build_render_context function."""
def test_basic_context(self):
"""Test building render context from workflow run."""
workflow_run = MagicMock()
workflow_run.id = 123
workflow_run.name = "WR-TEST-001"
workflow_run.workflow_id = 456
workflow_run.workflow.name = "Test Workflow"
workflow_run.initial_context = {"phone_number": "+1234567890"}
workflow_run.gathered_context = {
"customer_name": "John",
"mapped_call_disposition": "QUALIFIED",
}
workflow_run.usage_info = {"call_duration_seconds": 120}
workflow_run.completed_at = None
result = _build_render_context(workflow_run)
assert result["workflow_run_id"] == 123
assert result["workflow_run_name"] == "WR-TEST-001"
assert result["workflow_id"] == 456
assert result["workflow_name"] == "Test Workflow"
assert result["initial_context"]["phone_number"] == "+1234567890"
assert result["gathered_context"]["customer_name"] == "John"
assert result["cost_info"]["call_duration_seconds"] == 120
assert result["disposition_code"] == "QUALIFIED"
def test_empty_contexts(self):
"""Test with empty/None contexts."""
workflow_run = MagicMock()
workflow_run.id = 1
workflow_run.name = "Test"
workflow_run.workflow_id = 1
workflow_run.workflow.name = "Workflow"
workflow_run.initial_context = None
workflow_run.gathered_context = None
workflow_run.usage_info = None
workflow_run.completed_at = None
result = _build_render_context(workflow_run)
assert result["initial_context"] == {}
assert result["gathered_context"] == {}
assert result["cost_info"] == {}
assert result["disposition_code"] is None
class TestExecuteWebhookNode:
"""Tests for _execute_webhook_node function."""
@pytest.mark.asyncio
async def test_disabled_webhook_skipped(self):
"""Test that disabled webhooks are skipped."""
webhook_data = {"name": "Test Webhook", "enabled": False}
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is True # Returns True for skipped webhooks
@pytest.mark.asyncio
async def test_missing_url_returns_false(self):
"""Test that missing endpoint URL returns False."""
webhook_data = {"name": "Test Webhook", "enabled": True, "endpoint_url": None}
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is False
@pytest.mark.asyncio
async def test_successful_post_request(self):
"""Test successful POST webhook execution."""
webhook_data = {
"name": "CRM Sync",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"payload_template": {
"call_id": "{{workflow_run_id}}",
"phone": "{{initial_context.phone_number}}",
},
}
render_context = {
"workflow_run_id": 123,
"initial_context": {"phone_number": "+1234567890"},
}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context=render_context,
organization_id=1,
)
assert result is True
# Verify the request was made correctly
mock_client_instance.request.assert_called_once()
call_kwargs = mock_client_instance.request.call_args[1]
assert call_kwargs["method"] == "POST"
assert call_kwargs["url"] == "https://api.example.com/webhook"
assert call_kwargs["json"] == {
"call_id": "123",
"phone": "+1234567890",
}
@pytest.mark.asyncio
async def test_webhook_with_credential(self):
"""Test webhook execution with credential auth."""
webhook_data = {
"name": "Authenticated Webhook",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"credential_uuid": "cred-123",
"payload_template": {},
}
mock_credential = MagicMock()
mock_credential.name = "API Key"
mock_credential.credential_type = "bearer_token"
mock_credential.credential_data = {"token": "secret-token"}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is True
# Verify auth header was included
call_kwargs = mock_client_instance.request.call_args[1]
assert call_kwargs["headers"]["Authorization"] == "Bearer secret-token"
@pytest.mark.asyncio
async def test_webhook_with_custom_headers(self):
"""Test webhook execution with custom headers."""
webhook_data = {
"name": "Custom Headers Webhook",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"custom_headers": [
{"key": "X-Source", "value": "dograh"},
{"key": "X-Workflow", "value": "test"},
],
"payload_template": {},
}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is True
# Verify custom headers were included
call_kwargs = mock_client_instance.request.call_args[1]
assert call_kwargs["headers"]["X-Source"] == "dograh"
assert call_kwargs["headers"]["X-Workflow"] == "test"
@pytest.mark.asyncio
async def test_webhook_http_error(self):
"""Test webhook execution with HTTP error."""
import httpx
webhook_data = {
"name": "Failing Webhook",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"payload_template": {},
}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_response.raise_for_status = MagicMock(
side_effect=httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=mock_response,
)
)
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is False

View file

@ -1,117 +0,0 @@
"""Tests for the `/s3/signed-url` endpoint.
This test-suite verifies:
1. Regular users can retrieve signed URLs for resources belonging to their own workflow runs.
2. Regular users are *forbidden* from accessing resources that belong to other users.
3. Superusers can access any resource irrespective of ownership.
"""
import os
from unittest.mock import AsyncMock
import pytest
from fastapi import status
# Ensure the S3 environment variables exist so that the module import does not fail
os.environ.setdefault("S3_BUCKET", "test-bucket")
os.environ.setdefault("S3_REGION", "us-east-1")
@pytest.mark.asyncio
async def test_signed_url_for_own_run(monkeypatch, test_client_factory, db_session):
"""A normal user should be able to fetch a signed URL for their own workflow run."""
from api.db.models import UserModel
# ------------------------------------------------------------------
# 1. Set-up create user, workflow & workflow run
# ------------------------------------------------------------------
user: UserModel = await db_session.get_or_create_user_by_provider_id("user_own_run")
workflow = await db_session.create_workflow("wf", {}, user.id)
run = await db_session.create_workflow_run("run", workflow.id, "chat", user.id)
key = f"transcripts/{run.id}.txt"
# Patch S3 signed-url generator to avoid network calls
monkeypatch.setattr(
"api.services.filesystem.s3.s3_fs.aget_signed_url",
AsyncMock(return_value="https://signed-url"),
)
async with test_client_factory(user) as client:
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data == {"url": "https://signed-url", "expires_in": 3600}
@pytest.mark.asyncio
async def test_signed_url_for_other_users_run_forbidden(
monkeypatch, test_client_factory, db_session
):
"""A normal user must *not* access workflow runs owned by someone else."""
from api.db.models import UserModel
# Owner of the workflow run
owner: UserModel = await db_session.get_or_create_user_by_provider_id("owner_user")
workflow = await db_session.create_workflow("wf", {}, owner.id)
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
# Second user attempting access
intruder: UserModel = await db_session.get_or_create_user_by_provider_id(
"intruder_user"
)
key = f"recordings/{run.id}.wav"
monkeypatch.setattr(
"api.services.filesystem.s3.s3_fs.aget_signed_url",
AsyncMock(return_value="https://signed-url"),
)
async with test_client_factory(intruder) as client:
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_superuser_can_access_any_run(
monkeypatch, test_client_factory, db_session
):
"""Superusers should be able to fetch signed URLs for any workflow run."""
from api.db.models import UserModel
# Normal user & run owner
owner: UserModel = await db_session.get_or_create_user_by_provider_id(
"owner_of_run"
)
workflow = await db_session.create_workflow("wf", {}, owner.id)
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
# Superuser
superuser: UserModel = await db_session.get_or_create_user_by_provider_id(
"admin_user"
)
# Promote to superuser
# We need to commit the change so that the DB reflects it
async with db_session.async_session() as session:
db_user = await session.get(UserModel, superuser.id)
db_user.is_superuser = True
await session.commit()
await session.refresh(db_user) # ensure we have the latest state
superuser.is_superuser = True
key = f"transcripts/{run.id}.txt"
monkeypatch.setattr(
"api.services.filesystem.s3.s3_fs.aget_signed_url",
AsyncMock(return_value="https://signed-url"),
)
async with test_client_factory(superuser) as client:
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["url"] == "https://signed-url"

View file

@ -1,129 +0,0 @@
import os
import tempfile
from unittest.mock import AsyncMock, patch
import pytest
from api.tasks.s3_upload import upload_audio_to_s3, upload_transcript_to_s3
@pytest.mark.asyncio
async def test_upload_audio_to_s3_success():
"""Test successful audio upload to S3."""
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
tf.write(b"fake audio data")
temp_path = tf.name
try:
# Mock dependencies
mock_ctx = AsyncMock()
mock_s3_fs = AsyncMock()
mock_db_client = AsyncMock()
with (
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
patch("api.tasks.s3_upload.db_client", mock_db_client),
):
await upload_audio_to_s3(
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
)
# Verify S3 upload was called
mock_s3_fs.aupload_file.assert_called_once_with(
temp_path, "recordings/123.wav"
)
# Verify DB update was called
mock_db_client.update_workflow_run.assert_called_once_with(
run_id=123, recording_url="recordings/123.wav"
)
# Verify temp file was cleaned up
assert not os.path.exists(temp_path)
finally:
# Clean up if test failed
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_upload_audio_to_s3_file_not_found():
"""Test audio upload when temp file doesn't exist."""
mock_ctx = AsyncMock()
with pytest.raises(FileNotFoundError):
await upload_audio_to_s3(
mock_ctx, workflow_run_id=123, temp_file_path="/nonexistent/file.wav"
)
@pytest.mark.asyncio
async def test_upload_transcript_to_s3_success():
"""Test successful transcript upload to S3."""
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
tf.write("Test transcript content")
temp_path = tf.name
try:
# Mock dependencies
mock_ctx = AsyncMock()
mock_s3_fs = AsyncMock()
mock_db_client = AsyncMock()
with (
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
patch("api.tasks.s3_upload.db_client", mock_db_client),
):
await upload_transcript_to_s3(
mock_ctx, workflow_run_id=456, temp_file_path=temp_path
)
# Verify S3 upload was called
mock_s3_fs.aupload_file.assert_called_once_with(
temp_path, "transcripts/456.txt"
)
# Verify DB update was called
mock_db_client.update_workflow_run.assert_called_once_with(
run_id=456, transcript_url="transcripts/456.txt"
)
# Verify temp file was cleaned up
assert not os.path.exists(temp_path)
finally:
# Clean up if test failed
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_upload_s3_cleanup_on_error():
"""Test that temp files are cleaned up even when S3 upload fails."""
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
tf.write(b"fake audio data")
temp_path = tf.name
try:
mock_ctx = AsyncMock()
mock_s3_fs = AsyncMock()
# Make S3 upload fail
mock_s3_fs.aupload_file.side_effect = Exception("S3 upload failed")
with patch("api.tasks.s3_upload.s3_fs", mock_s3_fs):
with pytest.raises(Exception):
await upload_audio_to_s3(
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
)
# Verify temp file was still cleaned up
assert not os.path.exists(temp_path)
finally:
# Clean up if test failed
if os.path.exists(temp_path):
os.remove(temp_path)

View file

@ -1,89 +0,0 @@
from api.utils.template_renderer import render_template
def test_render_template_basic():
"""Test basic template rendering."""
template = "Hello {{name}}, your balance is {{balance}}."
context = {"name": "John", "balance": "$1000"}
result = render_template(template, context)
assert result == "Hello John, your balance is $1000."
def test_render_template_with_spaces():
"""Test template rendering with spaces around variables."""
template = "Hello {{ name }}, your balance is {{ balance }}."
context = {"name": "John", "balance": "$1000"}
result = render_template(template, context)
assert result == "Hello John, your balance is $1000."
def test_render_template_missing_variable():
"""Test template rendering with missing variables."""
template = "Hello {{name}}, your balance is {{balance}}."
context = {"name": "John"}
result = render_template(template, context)
assert result == "Hello John, your balance is ."
def test_render_template_with_fallback():
"""Test template rendering with fallback values."""
template = "Hello {{name | fallback}}, your balance is {{balance | fallback:$0}}."
context = {}
result = render_template(template, context)
assert result == "Hello Name, your balance is $0."
def test_render_template_with_fallback_existing_value():
"""Test that fallback is not used when value exists."""
template = "Hello {{name | fallback:Guest}}"
context = {"name": "John"}
result = render_template(template, context)
assert result == "Hello John"
def test_render_template_with_line_breaks():
"""Test template rendering with line breaks."""
template = (
"DISPOSITION_CODE: {{call_disposition}}\\nCALL_DURATION: {{call_duration}}"
)
context = {"call_disposition": "XFER", "call_duration": "300"}
result = render_template(template, context)
expected = "DISPOSITION_CODE: XFER\nCALL_DURATION: 300"
assert result == expected
def test_render_template_empty():
"""Test rendering empty template."""
assert render_template("", {}) == ""
assert render_template(None, {}) == None
def test_render_template_no_placeholders():
"""Test template with no placeholders."""
template = "This is a plain text message"
result = render_template(template, {"unused": "value"})
assert result == "This is a plain text message"
def test_render_template_none_values():
"""Test template with None values."""
template = "Value: {{value}}"
context = {"value": None}
result = render_template(template, context)
assert result == "Value: "
def test_render_template_numeric_values():
"""Test template with numeric values."""
template = "Count: {{count}}, Price: {{price}}"
context = {"count": 42, "price": 19.99}
result = render_template(template, context)
assert result == "Count: 42, Price: 19.99"

View file

@ -1,152 +0,0 @@
#!/usr/bin/env python
"""
Test script to verify atomic operations in organization_usage_client.py
This simulates concurrent access from multiple processes.
"""
import asyncio
import os
from concurrent.futures import ProcessPoolExecutor
# Set up environment
os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL", ""))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from api.db.organization_usage_client import OrganizationUsageClient
async def reserve_quota_process(org_id: int, tokens: int, process_id: int):
"""Simulate a process trying to reserve quota."""
engine = create_async_engine(os.environ["DATABASE_URL"])
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
client = OrganizationUsageClient(async_session)
results = []
for i in range(5):
result = await client.check_and_reserve_quota(org_id, tokens)
results.append((process_id, i, result))
await asyncio.sleep(0.01) # Small delay to increase contention
await engine.dispose()
return results
async def update_usage_process(org_id: int, tokens: int, process_id: int):
"""Simulate a process updating usage after runs."""
engine = create_async_engine(os.environ["DATABASE_URL"])
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
client = OrganizationUsageClient(async_session)
for i in range(5):
await client.update_usage_after_run(org_id, tokens, duration_seconds=10)
await asyncio.sleep(0.01)
await engine.dispose()
return f"Process {process_id} completed updates"
def run_reserve_quota(args):
"""Wrapper to run async function in process."""
org_id, tokens, process_id = args
return asyncio.run(reserve_quota_process(org_id, tokens, process_id))
def run_update_usage(args):
"""Wrapper to run async function in process."""
org_id, tokens, process_id = args
return asyncio.run(update_usage_process(org_id, tokens, process_id))
async def test_concurrent_quota_reservation():
"""Test that concurrent quota reservations are handled atomically."""
print("Testing concurrent quota reservations...")
# Assuming org_id 1 exists with quota enabled
org_id = 1
tokens_per_request = 100
# Run multiple processes trying to reserve quota simultaneously
with ProcessPoolExecutor(max_workers=3) as executor:
futures = []
for i in range(3):
futures.append(
executor.submit(run_reserve_quota, (org_id, tokens_per_request, i))
)
results = []
for future in futures:
results.extend(future.result())
print(f"Reservation results: {results}")
# Check that reservations were handled atomically
successful_reservations = sum(1 for _, _, success in results if success)
print(f"Successful reservations: {successful_reservations}")
async def test_concurrent_usage_updates():
"""Test that concurrent usage updates are handled atomically."""
print("\nTesting concurrent usage updates...")
org_id = 1
tokens_per_update = 50
# Get initial usage
engine = create_async_engine(os.environ["DATABASE_URL"])
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
client = OrganizationUsageClient(async_session)
initial_usage = await client.get_current_usage(org_id)
initial_tokens = initial_usage["used_dograh_tokens"]
print(f"Initial tokens: {initial_tokens}")
# Run multiple processes updating usage simultaneously
with ProcessPoolExecutor(max_workers=3) as executor:
futures = []
for i in range(3):
futures.append(
executor.submit(run_update_usage, (org_id, tokens_per_update, i))
)
for future in futures:
print(future.result())
# Check final usage
final_usage = await client.get_current_usage(org_id)
final_tokens = final_usage["used_dograh_tokens"]
expected_tokens = initial_tokens + (
3 * 5 * tokens_per_update
) # 3 processes * 5 updates * 50 tokens
print(f"Final tokens: {final_tokens}")
print(f"Expected tokens: {expected_tokens}")
print(f"Difference: {final_tokens - expected_tokens}")
await engine.dispose()
if final_tokens == expected_tokens:
print("✅ All updates were applied atomically!")
else:
print("❌ Some updates were lost due to race conditions!")
async def main():
"""Run all concurrency tests."""
try:
await test_concurrent_quota_reservation()
await test_concurrent_usage_updates()
except Exception as e:
print(f"Error during testing: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
print("Starting organization usage concurrency tests...")
print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL', 'NOT SET')}")
asyncio.run(main())

View file

@ -1,140 +0,0 @@
import json
import os
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import ExtractionVariableDTO, VariableType
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
class DummyLLM:
"""A minimal stub that mimics the parts of an LLM service used by the extractor."""
def __init__(self, streamed_response: str | None = None):
# Optionally provide a pre-defined streaming response for _perform_extraction tests
self._streamed_response = streamed_response or "{}"
self.registered_functions: dict[str, AsyncMock] = {}
# ------------------------------------------------------------------
# API used by VariableExtractionManager
# ------------------------------------------------------------------
def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 simple delegate
self.registered_functions[name] = func
async def get_chat_completions(self, _context, _messages):
"""Return an async generator that yields a single chunk with the full response."""
class _Delta: # noqa: D401 tiny helper classes for stub response
def __init__(self, content):
self.content = content
class _Choice:
def __init__(self, delta):
self.delta = delta
class _Chunk:
def __init__(self, content):
self.choices = [_Choice(_Delta(content))]
async def _stream():
yield _Chunk(self._streamed_response)
return _stream()
class DummyEngine:
"""A bare-bones Engine stub exposing only what the extractor relies on."""
def __init__(self, llm):
self.llm = llm
self.context = OpenAILLMContext()
self._pending_function_calls = 0
# VariableExtractionManager currently updates this private attribute
self._gathered_context: dict = {}
# ------------------------------------------------------------------
# Tests
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_perform_extraction_parses_json_correctly():
"""_perform_extraction should return the parsed JSON from the LLM stream."""
# Set dummy OpenAI API key to prevent initialization errors
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
expected_payload = {"name": "Alice", "age": 30}
llm = DummyLLM(json.dumps(expected_payload))
engine = DummyEngine(llm)
manager = VariableExtractionManager(engine)
# Mock the AsyncOpenAI client and its response
mock_response = AsyncMock()
mock_response.choices = [AsyncMock()]
mock_response.choices[0].message = AsyncMock()
mock_response.choices[0].message.content = json.dumps(expected_payload)
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch(
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
return_value=mock_client,
):
# Minimal set of variables to extract the prompts themselves are irrelevant here
extraction_variables = [
ExtractionVariableDTO(
name="name", type=VariableType.string, prompt="user name"
),
ExtractionVariableDTO(
name="age", type=VariableType.number, prompt="user age"
),
]
result = await manager._perform_extraction(
extraction_variables, parent_ctx=None, extraction_prompt=""
)
assert result == expected_payload
@pytest.mark.asyncio
async def test_perform_extraction_with_custom_system_prompt():
"""_perform_extraction should use the provided extraction_prompt as system prompt."""
# Set dummy OpenAI API key to prevent initialization errors
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
expected_payload = {"color": "blue"}
llm = DummyLLM(json.dumps(expected_payload))
engine = DummyEngine(llm)
manager = VariableExtractionManager(engine)
# Mock the AsyncOpenAI client and its response
mock_response = AsyncMock()
mock_response.choices = [AsyncMock()]
mock_response.choices[0].message = AsyncMock()
mock_response.choices[0].message.content = json.dumps(expected_payload)
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch(
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
return_value=mock_client,
):
extraction_variables = [
ExtractionVariableDTO(
name="color", type=VariableType.string, prompt="favourite color"
)
]
# Call with a custom extraction prompt
custom_prompt = "You are a color extraction specialist."
result = await manager._perform_extraction(
extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt
)
assert result == expected_payload

View file

@ -1,547 +0,0 @@
"""
Test voicemail detection in RTC connection flow.
This test emulates how a call is connected using SmallWebRTC,
triggers voicemail detection, and verifies the disconnect reason.
"""
import json
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pipecat.utils.enums import EndTaskReason
from api.routes.rtc_offer import RTCOfferRequest, offer
from api.services.workflow.pipecat_engine_voicemail_detector import VoicemailDetector
@pytest.mark.asyncio
class TestVoicemailDetectionRTC:
"""Test voicemail detection through RTC connection flow."""
async def test_voicemail_detection_full_flow(self):
"""
Test complete voicemail detection flow:
1. RTC connection request
2. Transport sends on_client_connected event
3. Engine initializes with voicemail detection enabled
4. Voicemail detector returns true
5. Call terminates with voicemail_detected reason
6. Transport sends on_client_disconnected event
7. Disconnect reason is properly set
"""
# Mock user and authentication
mock_user = Mock()
mock_user.id = 1
mock_user.organization_id = 1
# Mock workflow with voicemail detection enabled
mock_workflow = Mock()
mock_workflow.id = 100
mock_workflow.workflow_definition_with_fallback = {
"edges": [],
"nodes": [
{
"id": "start",
"type": "start",
"data": {
"detect_voicemail": True,
"system_prompt": "You are a helpful assistant",
},
}
],
}
# Mock workflow run
mock_workflow_run = Mock()
mock_workflow_run.id = 200
mock_workflow_run.is_completed = False
# Create request
request = RTCOfferRequest(
pc_id="test_pc_123",
sdp="test_sdp_offer",
type="offer",
workflow_id=mock_workflow.id,
workflow_run_id=mock_workflow_run.id,
restart_pc=False,
call_context_vars={"test_var": "test_value"},
)
# Mock dependencies
with (
patch("api.services.auth.depends.get_user") as mock_get_user_dep,
patch("api.routes.rtc_offer.SmallWebRTCConnection") as MockWebRTCConnection,
patch("api.routes.rtc_offer.run_pipeline_smallwebrtc") as mock_run_pipeline,
):
# Setup mocks
mock_get_user_dep.return_value = mock_user
# Mock WebRTC connection
mock_connection = Mock()
mock_connection.pc_id = "test_pc_123"
mock_connection.initialize = AsyncMock()
mock_connection.get_answer = Mock(
return_value={
"pc_id": "test_pc_123",
"sdp": "test_sdp_answer",
"type": "answer",
}
)
MockWebRTCConnection.return_value = mock_connection
# Track registered event handlers
registered_handlers = {}
def mock_event_handler(event_name):
def decorator(func):
registered_handlers[event_name] = func
return func
return decorator
mock_connection.event_handler = mock_event_handler
# Mock BackgroundTasks
mock_background_tasks = Mock()
# Create the offer
response = await offer(request, mock_background_tasks, mock_user)
# Verify response
assert response["pc_id"] == "test_pc_123"
assert response["type"] == "answer"
# Verify connection was initialized
mock_connection.initialize.assert_called_once_with(
sdp="test_sdp_offer", type="offer"
)
# Verify background task was added
mock_background_tasks.add_task.assert_called_once()
task_args = mock_background_tasks.add_task.call_args[0]
assert task_args[0] == mock_run_pipeline
assert task_args[1] == mock_connection
assert task_args[2] == mock_workflow.id
assert task_args[3] == mock_workflow_run.id
assert task_args[4] == mock_user.id
assert task_args[5] == {"test_var": "test_value"}
async def test_voicemail_detection_in_pipeline(self):
"""Tests whether the updates happen in on_client_disconnected properly
with values set in the engine"""
# Mock components
mock_transport = AsyncMock()
mock_engine = Mock() # Use Mock instead of AsyncMock for engine
mock_engine.initialize = AsyncMock()
mock_engine.cleanup = AsyncMock()
mock_audio_buffer = AsyncMock()
mock_task = AsyncMock()
mock_aggregator = Mock()
# Setup engine with voicemail detector
mock_voicemail_detector = AsyncMock(spec=VoicemailDetector)
mock_engine.voicemail_detector = mock_voicemail_detector
mock_engine.get_call_disposition = Mock(
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
)
mock_engine.get_gathered_context = Mock(
return_value={
"voicemail_transcript": "Hi, you've reached John's voicemail. Please leave a message.",
"voicemail_confidence": 0.95,
}
)
# Mock usage metrics
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Register event handlers
from api.services.pipecat.event_handlers import (
register_transport_event_handlers,
)
# Track registered handlers
handlers = {}
def track_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
mock_transport.event_handler = track_handler
# Create a mock db_client module with update_workflow_run method
mock_db_client = Mock()
mock_db_client.update_workflow_run = AsyncMock()
with (
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
patch(
"api.services.pipecat.event_handlers.enqueue_job",
new_callable=AsyncMock,
) as mock_enqueue_job,
patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
return_value=1,
),
patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
side_effect=lambda value, org_id: value, # Return value unchanged
),
):
# Register handlers
register_transport_event_handlers(
mock_transport,
workflow_run_id=123,
audio_buffer=mock_audio_buffer,
task=mock_task,
engine=mock_engine,
usage_metrics_aggregator=mock_aggregator,
)
# Verify handlers were registered
assert "on_client_connected" in handlers
assert "on_client_disconnected" in handlers
# Simulate client connection
await handlers["on_client_connected"](
mock_transport, {"id": "participant_1"}
)
# Verify initialization
mock_audio_buffer.start_recording.assert_called_once()
mock_engine.initialize.assert_called_once()
# Simulate voicemail detection and disconnect
await handlers["on_client_disconnected"](
mock_transport, {"id": "participant_1"}, None
)
# Verify engine cleanup
mock_engine.cleanup.assert_called_once()
# TODO: check whether task was cancelled or not once have more
# clarity on how to handle engine disconnect vs remote hangup
# Verify task was NOT cancelled (engine disconnect)
# mock_task.cancel.assert_not_called()
# Verify workflow run was updated with voicemail context
mock_db_client.update_workflow_run.assert_called()
call_args = mock_db_client.update_workflow_run.call_args
assert call_args[1]["run_id"] == 123
# Check that the mapped_call_disposition was set correctly
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "voicemail_detected"
)
async def test_voicemail_detector_audio_processing(self):
"""Test VoicemailDetector audio processing and detection logic - tests that voicemail detector
calls engine's send_end_task_frame with the correct reason and metadata"""
# Create voicemail detector
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=123)
# Mock OpenAI client
mock_openai = AsyncMock()
mock_whisper_response = Mock()
mock_whisper_response.text = "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep."
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
mock_gpt_response = Mock()
mock_gpt_response.choices = [Mock()]
mock_gpt_response.choices[0].message.content = json.dumps(
{
"is_voicemail": True,
"confidence": 0.98,
"reasoning": "Clear voicemail greeting with request to leave message",
}
)
mock_openai.chat.completions.create.return_value = mock_gpt_response
# Mock engine
mock_engine = AsyncMock()
mock_engine.task = AsyncMock()
with (
patch(
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
return_value=mock_openai,
),
patch(
"api.services.workflow.pipecat_engine_voicemail_detector.s3_fs"
) as mock_s3,
):
# Mock S3 upload to return None (simulating successful upload)
mock_s3.aupload_file = AsyncMock(return_value=True)
# Start detection
await detector.start_detection(mock_engine)
assert detector.is_detecting == True
# Simulate audio data (16kHz, mono, 5 seconds)
sample_rate = 16000
duration = 5.0
audio_data = b"\x00\x00" * int(sample_rate * duration) # Silent audio
# Process audio in chunks
chunk_size = 1600 # 100ms chunks
for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i : i + chunk_size]
await detector.handle_audio_data(None, chunk, sample_rate, 1)
# Wait for detection to complete
if detector._detection_task:
await detector._detection_task
# Verify OpenAI calls
mock_openai.audio.transcriptions.create.assert_called_once()
mock_openai.chat.completions.create.assert_called_once()
# Verify send_end_task_frame was called with voicemail detection
mock_engine.send_end_task_frame.assert_called_once_with(
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
additional_metadata={
"voicemail_transcript": "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep.",
"voicemail_confidence": 0.98,
"voicemail_reasoning": "Clear voicemail greeting with request to leave message",
"voicemail_detection_duration": 5.0,
"voicemail_audio_s3_path": "voicemail_detections/123_voicemail_98_5.wav", # S3 upload returns True, so filename is used
},
abort_immediately=True,
)
async def test_voicemail_detector_no_detection(self):
"""Test VoicemailDetector when voicemail is not detected."""
# Create voicemail detector
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=124)
# Mock OpenAI client
mock_openai = AsyncMock()
mock_whisper_response = Mock()
mock_whisper_response.text = "Hello? Hello? Can you hear me?"
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
mock_gpt_response = Mock()
mock_gpt_response.choices = [Mock()]
mock_gpt_response.choices[0].message.content = json.dumps(
{
"is_voicemail": False,
"confidence": 0.95,
"reasoning": "Live person speaking, asking if caller can hear them",
}
)
mock_openai.chat.completions.create.return_value = mock_gpt_response
# Mock engine
mock_engine = AsyncMock()
mock_engine.task = AsyncMock()
with patch(
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
return_value=mock_openai,
):
# Start detection
await detector.start_detection(mock_engine)
# Simulate audio data
sample_rate = 16000
duration = 5.0
audio_data = b"\x00\x00" * int(sample_rate * duration)
# Process audio
await detector.handle_audio_data(None, audio_data, sample_rate, 1)
# Wait for detection
if detector._detection_task:
await detector._detection_task
# Verify send_end_task_frame was NOT called
mock_engine.send_end_task_frame.assert_not_called()
async def test_voicemail_detector_cancellation(self):
"""Test VoicemailDetector cancellation before completion."""
# Create voicemail detector
detector = VoicemailDetector(detection_duration=10.0, workflow_run_id=125)
# Mock engine
mock_engine = AsyncMock()
# Start detection
await detector.start_detection(mock_engine)
assert detector.is_detecting == True
# Cancel detection immediately
await detector.stop_detection()
assert detector._is_cancelled == True
# Try to add audio data after cancellation
await detector.handle_audio_data(None, b"\x00\x00" * 1000, 16000, 1)
# Verify buffer didn't grow (no audio accepted after cancellation)
assert len(detector.audio_buffer) == 0
async def test_disconnect_reason_propagation(self):
"""Test that voicemail disconnect reason is properly propagated."""
# Create disconnect reason info directly
disconnect_info = {
"disposition_code": EndTaskReason.VOICEMAIL_DETECTED.value,
"details": "Voicemail detected after 5 seconds of audio",
"is_remote": False,
"is_user_initiated": False,
"is_successful_transfer": False,
"transport_metadata": {
"voicemail_confidence": 0.97,
"voicemail_transcript": "You've reached voicemail...",
},
}
# Verify attributes
assert disconnect_info["disposition_code"] == "voicemail_detected"
assert disconnect_info["is_remote"] == False
assert disconnect_info["is_user_initiated"] == False
assert disconnect_info["is_successful_transfer"] == False
assert (
disconnect_info["details"] == "Voicemail detected after 5 seconds of audio"
)
assert disconnect_info["transport_metadata"]["voicemail_confidence"] == 0.97
async def test_voicemail_detection_end_to_end(self):
"""
Complete end-to-end test covering:
1. on_client_connected event
2. Engine initialization with voicemail detection
3. Audio processing and voicemail detection
4. Engine setting disconnect reason
5. on_client_disconnected event
6. Proper disconnect reason in workflow run update
"""
# Create comprehensive mocks
from api.services.pipecat.event_handlers import (
register_transport_event_handlers,
)
# Mock transport
mock_transport = AsyncMock()
handlers = {}
def track_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
mock_transport.event_handler = track_handler
# Mock audio buffer
mock_audio_buffer = Mock()
mock_audio_buffer.start_recording = AsyncMock()
mock_audio_buffer.stop_recording = AsyncMock()
# Mock task
mock_task = AsyncMock()
# Mock aggregator
mock_aggregator = Mock()
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Create a mock engine with voicemail detection
mock_engine = Mock()
mock_engine.initialize = AsyncMock()
mock_engine.cleanup = AsyncMock()
# Mock voicemail detector
mock_voicemail_detector = Mock()
mock_engine.voicemail_detector = mock_voicemail_detector
mock_engine._voicemail_detector = mock_voicemail_detector
# Initially no disconnect reason
mock_engine.get_call_disposition = Mock(return_value=None)
mock_engine.get_gathered_context = Mock(return_value={})
# Mock db_client
mock_db_client = Mock()
mock_db_client.update_workflow_run = AsyncMock()
with (
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
patch(
"api.services.pipecat.event_handlers.enqueue_job",
new_callable=AsyncMock,
) as mock_enqueue_job,
patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
return_value=1,
),
patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
side_effect=lambda value, org_id: value, # Return value unchanged
),
):
# Register event handlers
register_transport_event_handlers(
mock_transport,
workflow_run_id=123,
audio_buffer=mock_audio_buffer,
task=mock_task,
engine=mock_engine,
usage_metrics_aggregator=mock_aggregator,
)
# Verify handlers were registered
assert "on_client_connected" in handlers
assert "on_client_disconnected" in handlers
# Step 1: Client connects
await handlers["on_client_connected"](
mock_transport, {"id": "participant_1"}
)
# Verify initialization
mock_audio_buffer.start_recording.assert_called_once()
mock_engine.initialize.assert_called_once()
# Step 2-3: Simulate voicemail detection occurs
# Update engine state to reflect voicemail was detected
mock_engine.get_call_disposition = Mock(
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
)
mock_engine.get_gathered_context = Mock(
return_value={
"voicemail_transcript": "You've reached voicemail, leave a message",
"voicemail_confidence": 0.95,
}
)
# Step 5: Client disconnects
await handlers["on_client_disconnected"](
mock_transport, {"id": "participant_1"}, None
)
# Verify engine cleanup
mock_engine.cleanup.assert_called_once()
# Step 6: Verify proper disconnect reason in workflow run update
mock_db_client.update_workflow_run.assert_called()
call_args = mock_db_client.update_workflow_run.call_args
# Check the gathered context includes disconnect reason
gathered_context = call_args[1]["gathered_context"]
assert gathered_context["mapped_call_disposition"] == "voicemail_detected"
assert gathered_context["voicemail_confidence"] == 0.95
assert (
gathered_context["voicemail_transcript"]
== "You've reached voicemail, leave a message"
)
# Verify task was NOT cancelled (engine-initiated disconnect)
mock_task.cancel.assert_not_called()
# Verify audio buffer was stopped
mock_audio_buffer.stop_recording.assert_called_once()
# Verify background jobs were enqueued
assert (
mock_enqueue_job.call_count >= 3
) # At least 3 jobs should be enqueued

View file

@ -1,667 +0,0 @@
"""
Tests for workflow API routes.
This module tests the create, update, get, and validate workflow endpoints.
The fixtures for database setup, test client, and utilities are in conftest.py.
"""
import pytest
from fastapi import status
@pytest.fixture
def sample_workflow_definition():
"""Sample workflow definition for testing."""
return {
"nodes": [
{
"id": "6581",
"type": "startCall",
"position": {"x": 427, "y": 23},
"data": {
"prompt": "Hello, I am Abhishek from Dograh. ",
"is_static": True,
"name": "Start Call",
"is_start": True,
"invalid": False,
"validationMessage": None,
"allow_interrupt": False,
},
"measured": {"width": 300, "height": 100},
"selected": True,
"dragging": False,
},
{
"id": "915",
"type": "agentNode",
"position": {"x": 305, "y": 340},
"data": {
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent.",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": False,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
{
"id": "7598",
"type": "agentNode",
"position": {"x": 90, "y": 650},
"data": {
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": True,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
{
"id": "6919",
"type": "agentNode",
"position": {"x": 520, "y": 650},
"data": {
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": True,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
{
"id": "1802",
"type": "endCall",
"position": {"x": 305, "y": 960},
"data": {
"prompt": "Thank you!",
"invalid": False,
"validationMessage": None,
"is_static": True,
"name": "End Call",
"is_end": True,
"allow_interrupt": False,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
],
"edges": [
{
"animated": True,
"type": "custom",
"source": "915",
"target": "7598",
"id": "xy-edge__915-7598",
"selected": False,
"data": {
"condition": "The customer wants to talk to a customer service agent",
"label": "customer service agent",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "915",
"target": "6919",
"id": "xy-edge__915-6919",
"selected": False,
"data": {
"condition": "customer wants to talk to a sales representative",
"label": "sales representative",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "6581",
"target": "915",
"id": "xy-edge__6581-915",
"selected": False,
"data": {
"condition": "Always take this route",
"label": "Always take this route",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "7598",
"target": "1802",
"id": "xy-edge__7598-1802",
"selected": False,
"data": {
"condition": "end call",
"label": "end call",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "6919",
"target": "1802",
"id": "xy-edge__6919-1802",
"selected": False,
"data": {
"condition": "end call",
"label": "end call",
"invalid": False,
"validationMessage": None,
},
},
],
"viewport": {"x": 0, "y": 0, "zoom": 1},
}
class TestCreateWorkflow:
"""Test cases for creating workflows."""
async def test_create_workflow_success(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test successful workflow creation."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_create_success"
)
request_data = {
"name": "Test Workflow",
"workflow_definition": sample_workflow_definition,
}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "id" in data
assert data["name"] == "Test Workflow"
assert data["workflow_definition"] == sample_workflow_definition
assert "created_at" in data
assert "current_definition_id" in data
async def test_create_workflow_invalid_definition(
self, test_client_factory, db_session
):
"""Test workflow creation with invalid definition."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_invalid_def"
)
request_data = {
"name": "Invalid Workflow",
"workflow_definition": {"invalid": "structure"},
}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
# The API should still create the workflow even with invalid definition
# Validation happens in the validate endpoint
assert response.status_code == status.HTTP_200_OK
@pytest.mark.asyncio
async def test_create_workflow_missing_name(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test workflow creation without name."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_missing_name"
)
request_data = {"workflow_definition": sample_workflow_definition}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_create_workflow_missing_definition(
self, test_client_factory, db_session
):
"""Test workflow creation without workflow definition."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_missing_definition"
)
request_data = {"name": "Test Workflow"}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestGetWorkflows:
"""Test cases for fetching workflows."""
@pytest.mark.asyncio
async def test_get_all_workflows_empty(self, test_client_factory, db_session):
"""Test getting all workflows when none exist."""
# Create a test user within the test function
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_empty_workflows"
)
# Create a test client for this specific user
async with test_client_factory(test_user) as client:
response = await client.get("/api/v1/workflow/fetch")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) == 0
@pytest.mark.asyncio
async def test_get_all_workflows_with_data(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test getting all workflows when some exist."""
# Create a test user within the test function
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_with_workflows"
)
# Create a test client for this specific user
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Test Workflow 1",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
# Create another workflow
create_response2 = await client.post(
"/api/v1/workflow/create",
json={
"name": "Test Workflow 2",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response2.status_code == status.HTTP_200_OK
# Get all workflows
response = await client.get("/api/v1/workflow/fetch")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
# Check that both workflows are returned
workflow_names = [w["name"] for w in data]
assert "Test Workflow 1" in workflow_names
assert "Test Workflow 2" in workflow_names
@pytest.mark.asyncio
async def test_get_specific_workflow(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test getting a specific workflow by ID."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_specific_workflow"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Specific Workflow",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
created_workflow = create_response.json()
workflow_id = created_workflow["id"]
# Get the specific workflow
response = await client.get(
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == workflow_id
assert data["name"] == "Specific Workflow"
assert data["workflow_definition"] == sample_workflow_definition
@pytest.mark.asyncio
async def test_get_nonexistent_workflow(self, test_client_factory, db_session):
"""Test getting a workflow that doesn't exist."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_nonexistent"
)
async with test_client_factory(test_user) as client:
response = await client.get("/api/v1/workflow/fetch?workflow_id=99999")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"].lower()
class TestUpdateWorkflow:
"""Test cases for updating workflows."""
@pytest.mark.asyncio
async def test_update_workflow_name_only(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test updating only the workflow name."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_name"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Original Name",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Update the workflow name
update_data = {"name": "Updated Name"}
response = await client.put(
f"/api/v1/workflow/{workflow_id}", json=update_data
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == workflow_id
assert data["name"] == "Updated Name"
assert (
data["workflow_definition"] == sample_workflow_definition
) # Should remain unchanged
@pytest.mark.asyncio
async def test_update_workflow_name_and_definition(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test updating both workflow name and definition."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_both"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Original Name",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Create new workflow definition
new_definition = {
"nodes": [
{
"id": "start",
"type": "start",
"position": {"x": 50, "y": 50},
"data": {"label": "New Start"},
}
],
"edges": [],
}
# Update the workflow
update_data = {
"name": "Updated Name",
"workflow_definition": new_definition,
}
response = await client.put(
f"/api/v1/workflow/{workflow_id}", json=update_data
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == workflow_id
assert data["name"] == "Updated Name"
assert data["workflow_definition"] == new_definition
@pytest.mark.asyncio
async def test_update_nonexistent_workflow(self, test_client_factory, db_session):
"""Test updating a workflow that doesn't exist."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_nonexistent"
)
update_data = {"name": "Updated Name"}
async with test_client_factory(test_user) as client:
response = await client.put("/api/v1/workflow/99999", json=update_data)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_update_workflow_missing_name(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test updating a workflow without providing a name."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_missing_name"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Original Name",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Try to update without providing name
update_data = {"workflow_definition": sample_workflow_definition}
response = await client.put(
f"/api/v1/workflow/{workflow_id}", json=update_data
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestWorkflowValidation:
"""Test cases for workflow validation endpoint."""
@pytest.mark.asyncio
async def test_validate_workflow_success(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test successful workflow validation."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_validate_success"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Valid Workflow",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Validate the workflow
response = await client.post(f"/api/v1/workflow/{workflow_id}/validate")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["is_valid"] is True
assert data["errors"] == []
@pytest.mark.asyncio
async def test_validate_nonexistent_workflow(self, test_client_factory, db_session):
"""Test validating a workflow that doesn't exist."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_validate_nonexistent"
)
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/99999/validate")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"].lower()
class TestWorkflowIntegration:
"""Integration tests for workflow operations."""
@pytest.mark.asyncio
async def test_full_workflow_lifecycle(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test the complete lifecycle of a workflow: create, get, update, validate."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_lifecycle"
)
async with test_client_factory(test_user) as client:
# 1. Create workflow
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Lifecycle Test Workflow",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# 2. Get the created workflow
get_response = await client.get(
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
)
assert get_response.status_code == status.HTTP_200_OK
workflow_data = get_response.json()
assert workflow_data["name"] == "Lifecycle Test Workflow"
# 3. Add a new node in the workflow definition
new_node = {
"id": "6919_new",
"type": "agentNode",
"position": {"x": 520, "y": 650},
"data": {
"prompt": "Something new",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": True,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
}
new_edges = [
{
"source": "6919",
"target": "6919_new",
"id": "xy-edge__6919-6919_new",
"data": {
"condition": "Always take this route",
"label": "Always take this route",
"invalid": False,
"validationMessage": None,
},
},
{
"source": "6919_new",
"target": "1802",
"id": "xy-edge__6919_new-1802",
"data": {
"condition": "Always take this route",
"label": "Always take this route",
"invalid": False,
"validationMessage": None,
},
},
]
new_definition = {
"nodes": [
*sample_workflow_definition["nodes"],
new_node,
],
"edges": [
*sample_workflow_definition["edges"],
*new_edges,
],
}
update_response = await client.put(
f"/api/v1/workflow/{workflow_id}",
json={
"name": "Updated Lifecycle Workflow",
"workflow_definition": new_definition,
},
)
assert update_response.status_code == status.HTTP_200_OK
assert update_response.json()["name"] == "Updated Lifecycle Workflow"
# 4. Validate the updated workflow
validate_response = await client.post(
f"/api/v1/workflow/{workflow_id}/validate"
)
assert validate_response.status_code == status.HTTP_200_OK
assert validate_response.json()["is_valid"] is True
# 5. Verify the update by getting the workflow again
final_get_response = await client.get(
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
)
assert final_get_response.status_code == status.HTTP_200_OK
final_data = final_get_response.json()
assert final_data["name"] == "Updated Lifecycle Workflow"
assert final_data["workflow_definition"] == new_definition

View file

@ -0,0 +1,95 @@
"""Build HTTP authentication headers from ExternalCredentialModel.
This module provides functions for constructing HTTP authentication headers
from ExternalCredentialModel instances. Used by both webhook integrations
and custom tool execution.
"""
import base64
from typing import TYPE_CHECKING, Any, Dict, Optional
if TYPE_CHECKING:
from api.db.models import ExternalCredentialModel
def build_auth_header(credential: "ExternalCredentialModel") -> Dict[str, str]:
"""Build authentication header based on credential type.
Supports the following credential types:
- bearer_token: Authorization: Bearer <token>
- api_key: Custom header with API key
- basic_auth: Authorization: Basic <base64(username:password)>
- custom_header: Any custom header name/value pair
Args:
credential: The ExternalCredentialModel instance
Returns:
Dict with header name and value, or empty dict if credential type
is not recognized or is 'none'
"""
cred_type = credential.credential_type
cred_data = credential.credential_data or {}
if cred_type == "bearer_token":
token = cred_data.get("token", "")
return {"Authorization": f"Bearer {token}"}
elif cred_type == "api_key":
header_name = cred_data.get("header_name", "X-API-Key")
api_key = cred_data.get("api_key", "")
return {header_name: api_key}
elif cred_type == "basic_auth":
username = cred_data.get("username", "")
password = cred_data.get("password", "")
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
return {"Authorization": f"Basic {encoded}"}
elif cred_type == "custom_header":
header_name = cred_data.get("header_name", "X-Custom")
header_value = cred_data.get("header_value", "")
return {header_name: header_value}
return {}
def build_auth_header_from_data(
credential_type: str,
credential_data: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
"""Build authentication header from raw credential data.
This is a convenience function when you have credential data
directly rather than a full ExternalCredentialModel.
Args:
credential_type: Type of credential (bearer_token, api_key, etc.)
credential_data: Dict containing credential-specific fields
Returns:
Dict with header name and value
"""
cred_data = credential_data or {}
if credential_type == "bearer_token":
token = cred_data.get("token", "")
return {"Authorization": f"Bearer {token}"}
elif credential_type == "api_key":
header_name = cred_data.get("header_name", "X-API-Key")
api_key = cred_data.get("api_key", "")
return {header_name: api_key}
elif credential_type == "basic_auth":
username = cred_data.get("username", "")
password = cred_data.get("password", "")
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
return {"Authorization": f"Basic {encoded}"}
elif credential_type == "custom_header":
header_name = cred_data.get("header_name", "X-Custom")
header_value = cred_data.get("header_value", "")
return {header_name: header_value}
return {}

View file

@ -0,0 +1,498 @@
"use client";
import { ArrowLeft, Code, Globe, Loader2, Save } from "lucide-react";
import { useParams, useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import {
getToolApiV1ToolsToolUuidGet,
updateToolApiV1ToolsToolUuidPut,
} from "@/client/sdk.gen";
import type { ToolResponse } from "@/client/types.gen";
// Extended HttpApiConfig with parameters (until client types are regenerated)
interface HttpApiConfigWithParams {
method?: string;
url?: string;
headers?: Record<string, string>;
credential_uuid?: string;
parameters?: ToolParameter[];
timeout_ms?: number;
}
import {
CredentialSelector,
type HttpMethod,
HttpMethodSelector,
KeyValueEditor,
type KeyValueItem,
ParameterEditor,
type ToolParameter,
} from "@/components/http";
import { Button } from "@/components/ui/button";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Skeleton } from "@/components/ui/skeleton";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Textarea } from "@/components/ui/textarea";
import { useAuth } from "@/lib/auth";
export default function ToolDetailPage() {
const { toolUuid } = useParams<{ toolUuid: string }>();
const { user, getAccessToken, redirectToLogin, loading } = useAuth();
const router = useRouter();
const [tool, setTool] = useState<ToolResponse | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [isSaving, setIsSaving] = useState(false);
const [error, setError] = useState<string | null>(null);
const [saveSuccess, setSaveSuccess] = useState(false);
const [showCodeDialog, setShowCodeDialog] = useState(false);
// Form state
const [name, setName] = useState("");
const [description, setDescription] = useState("");
const [httpMethod, setHttpMethod] = useState<HttpMethod>("POST");
const [url, setUrl] = useState("");
const [credentialUuid, setCredentialUuid] = useState("");
const [headers, setHeaders] = useState<KeyValueItem[]>([]);
const [parameters, setParameters] = useState<ToolParameter[]>([]);
const [timeoutMs, setTimeoutMs] = useState(5000);
// Redirect if not authenticated
useEffect(() => {
if (!loading && !user) {
redirectToLogin();
}
}, [loading, user, redirectToLogin]);
const fetchTool = useCallback(async () => {
if (loading || !user || !toolUuid) return;
try {
setIsLoading(true);
setError(null);
const accessToken = await getAccessToken();
const response = await getToolApiV1ToolsToolUuidGet({
path: { tool_uuid: toolUuid },
headers: {
Authorization: `Bearer ${accessToken}`,
},
});
if (response.data) {
setTool(response.data);
populateFormFromTool(response.data);
}
} catch (err) {
setError("Failed to fetch tool");
console.error("Error fetching tool:", err);
} finally {
setIsLoading(false);
}
}, [loading, user, toolUuid, getAccessToken]);
const populateFormFromTool = (tool: ToolResponse) => {
setName(tool.name);
setDescription(tool.description || "");
const config = tool.definition?.config as HttpApiConfigWithParams | undefined;
if (config) {
setHttpMethod((config.method as HttpMethod) || "POST");
setUrl(config.url || "");
setCredentialUuid(config.credential_uuid || "");
setTimeoutMs(config.timeout_ms || 5000);
// Convert headers object to array
if (config.headers) {
setHeaders(
Object.entries(config.headers).map(([key, value]) => ({
key,
value: value as string,
}))
);
} else {
setHeaders([]);
}
// Load parameters
if (config.parameters && Array.isArray(config.parameters)) {
setParameters(
config.parameters.map((p: ToolParameter) => ({
name: p.name || "",
type: p.type || "string",
description: p.description || "",
required: p.required ?? true,
}))
);
} else {
setParameters([]);
}
}
};
useEffect(() => {
fetchTool();
}, [fetchTool]);
const handleSave = async () => {
// Validate URL
if (!url.trim()) {
setError("URL is required");
return;
}
// Validate parameters have names
const invalidParams = parameters.filter((p) => !p.name.trim());
if (invalidParams.length > 0) {
setError("All parameters must have a name");
return;
}
try {
setIsSaving(true);
setError(null);
setSaveSuccess(false);
const accessToken = await getAccessToken();
// Convert headers array to object
const headersObject: Record<string, string> = {};
headers.filter((h) => h.key && h.value).forEach((h) => {
headersObject[h.key] = h.value;
});
// Filter out empty parameters
const validParameters = parameters.filter((p) => p.name.trim());
// Build the request body (cast needed until client types are regenerated)
const requestBody = {
name,
description: description || undefined,
definition: {
schema_version: 1,
type: "http_api",
config: {
method: httpMethod,
url,
credential_uuid: credentialUuid || undefined,
headers:
Object.keys(headersObject).length > 0
? headersObject
: undefined,
parameters:
validParameters.length > 0 ? validParameters : undefined,
timeout_ms: timeoutMs,
},
},
};
const response = await updateToolApiV1ToolsToolUuidPut({
path: { tool_uuid: toolUuid },
// eslint-disable-next-line @typescript-eslint/no-explicit-any
body: requestBody as any,
headers: {
Authorization: `Bearer ${accessToken}`,
},
});
if (response.data) {
setTool(response.data);
setSaveSuccess(true);
setTimeout(() => setSaveSuccess(false), 3000);
}
} catch (err) {
setError("Failed to save tool");
console.error("Error saving tool:", err);
} finally {
setIsSaving(false);
}
};
const getCodeSnippet = () => {
if (!tool) return "";
const headersObj: Record<string, string> = {
"Content-Type": "application/json",
};
headers.filter((h) => h.key && h.value).forEach((h) => {
headersObj[h.key] = h.value;
});
// Build example body from parameters
const exampleBody: Record<string, unknown> = {};
parameters.forEach((p) => {
if (p.type === "number") {
exampleBody[p.name] = 0;
} else if (p.type === "boolean") {
exampleBody[p.name] = true;
} else {
exampleBody[p.name] = `<${p.name}>`;
}
});
const hasBody = httpMethod !== "GET" && httpMethod !== "DELETE" && parameters.length > 0;
return `// ${tool.name}
// ${tool.description || "HTTP API Tool"}
const response = await fetch("${url}", {
method: "${httpMethod}",
headers: ${JSON.stringify(headersObj, null, 4)},${hasBody ? `
body: JSON.stringify(${JSON.stringify(exampleBody, null, 4)}),` : ""}
});
const data = await response.json();`;
};
if (loading || !user) {
return (
<div className="min-h-screen bg-background flex items-center justify-center">
<div className="space-y-4">
<Skeleton className="h-12 w-64" />
<Skeleton className="h-64 w-96" />
</div>
</div>
);
}
if (isLoading) {
return (
<div className="min-h-screen bg-background">
<div className="container mx-auto px-4 py-8">
<div className="max-w-4xl mx-auto space-y-6">
<Skeleton className="h-8 w-48" />
<Skeleton className="h-64 w-full" />
</div>
</div>
</div>
);
}
if (!tool) {
return (
<div className="min-h-screen bg-background">
<div className="container mx-auto px-4 py-8">
<div className="max-w-4xl mx-auto text-center">
<h1 className="text-2xl font-bold mb-4">Tool not found</h1>
<Button onClick={() => router.push("/tools")}>
<ArrowLeft className="w-4 h-4 mr-2" />
Back to Tools
</Button>
</div>
</div>
</div>
);
}
return (
<div className="min-h-screen bg-background">
<div className="container mx-auto px-4 py-8">
<div className="max-w-4xl mx-auto">
{/* Header */}
<div className="flex items-center justify-between mb-6">
<div className="flex items-center gap-4">
<Button
variant="ghost"
size="sm"
onClick={() => router.push("/tools")}
>
<ArrowLeft className="w-4 h-4 mr-2" />
Back
</Button>
<div className="flex items-center gap-3">
<div
className="w-10 h-10 rounded-lg flex items-center justify-center"
style={{
backgroundColor: tool.icon_color || "#3B82F6",
}}
>
<Globe className="w-5 h-5 text-white" />
</div>
<div>
<h1 className="text-xl font-bold">{name}</h1>
<p className="text-sm text-muted-foreground">
HTTP API Tool
</p>
</div>
</div>
</div>
<div className="flex items-center gap-2">
<Button
variant="outline"
onClick={() => setShowCodeDialog(true)}
>
<Code className="w-4 h-4 mr-2" />
View Code
</Button>
<Button onClick={handleSave} disabled={isSaving}>
{isSaving ? (
<>
<Loader2 className="w-4 h-4 mr-2 animate-spin" />
Saving...
</>
) : (
<>
<Save className="w-4 h-4 mr-2" />
Save
</>
)}
</Button>
</div>
</div>
{error && (
<div className="mb-4 p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive">
{error}
</div>
)}
{saveSuccess && (
<div className="mb-4 p-4 bg-green-500/10 border border-green-500/20 rounded-lg text-green-600">
Tool saved successfully!
</div>
)}
<Card>
<CardHeader>
<CardTitle>Tool Configuration</CardTitle>
<CardDescription>
Configure the HTTP API endpoint and request settings
</CardDescription>
</CardHeader>
<CardContent>
<Tabs defaultValue="settings" className="w-full">
<TabsList className="grid w-full grid-cols-3">
<TabsTrigger value="settings">Settings</TabsTrigger>
<TabsTrigger value="auth">Authentication</TabsTrigger>
<TabsTrigger value="parameters">Parameters</TabsTrigger>
</TabsList>
<TabsContent value="settings" className="space-y-4 mt-4">
<div className="grid gap-2">
<Label>Tool Name</Label>
<Label className="text-xs text-muted-foreground">
Use a descriptive name, like &quot;Get Weather using API&quot; for a tool that fetches weather
</Label>
<Input
value={name}
onChange={(e) => setName(e.target.value)}
placeholder="e.g., Book Appointment"
/>
</div>
<div className="grid gap-2">
<Label>Description</Label>
<Label className="text-xs text-muted-foreground">
Provide a description which makes it easy for LLM to understand what this tool does
</Label>
<Textarea
value={description}
onChange={(e) => setDescription(e.target.value)}
placeholder="What does this tool do?"
rows={3}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="grid gap-2">
<Label>HTTP Method</Label>
<HttpMethodSelector
value={httpMethod}
onChange={setHttpMethod}
/>
</div>
<div className="grid gap-2">
<Label>Timeout (ms)</Label>
<Input
type="number"
value={timeoutMs}
onChange={(e) =>
setTimeoutMs(parseInt(e.target.value) || 5000)
}
min={1000}
max={30000}
/>
</div>
</div>
<div className="grid gap-2">
<Label>Endpoint URL</Label>
<Input
value={url}
onChange={(e) => setUrl(e.target.value)}
placeholder="https://api.example.com/appointments"
/>
</div>
</TabsContent>
<TabsContent value="auth" className="space-y-4 mt-4">
<CredentialSelector
value={credentialUuid}
onChange={setCredentialUuid}
/>
</TabsContent>
<TabsContent value="parameters" className="space-y-4 mt-4">
<div className="grid gap-2">
<Label>Tool Parameters</Label>
<Label className="text-xs text-muted-foreground">
Define the parameters that the LLM will provide when calling this tool.
These will be sent as JSON body for POST/PUT/PATCH or as URL query params for GET/DELETE.
</Label>
<ParameterEditor
parameters={parameters}
onChange={setParameters}
/>
</div>
<div className="grid gap-2 pt-4 border-t">
<Label>Custom Headers</Label>
<Label className="text-xs text-muted-foreground">
Add custom headers to include in the request (optional)
</Label>
<KeyValueEditor
items={headers}
onChange={setHeaders}
keyPlaceholder="Header name"
valuePlaceholder="Header value"
addButtonText="Add Header"
/>
</div>
</TabsContent>
</Tabs>
</CardContent>
</Card>
</div>
</div>
{/* Code View Dialog */}
<Dialog open={showCodeDialog} onOpenChange={setShowCodeDialog}>
<DialogContent className="max-w-2xl">
<DialogHeader>
<DialogTitle>Code Preview</DialogTitle>
<DialogDescription>
JavaScript code to make this API call
</DialogDescription>
</DialogHeader>
<div className="bg-muted rounded-lg p-4 font-mono text-sm overflow-auto max-h-96">
<pre>{getCodeSnippet()}</pre>
</div>
</DialogContent>
</Dialog>
</div>
);
}

431
ui/src/app/tools/page.tsx Normal file
View file

@ -0,0 +1,431 @@
"use client";
import { Globe, Plus, Search, Trash2 } from "lucide-react";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import {
createToolApiV1ToolsPost,
deleteToolApiV1ToolsToolUuidDelete,
listToolsApiV1ToolsGet,
} from "@/client/sdk.gen";
import type { ToolResponse } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Skeleton } from "@/components/ui/skeleton";
import { useAuth } from "@/lib/auth";
type ToolCategory = "http_api" | "native" | "integration";
const TOOL_CATEGORIES: { value: ToolCategory; label: string; description: string; disabled?: boolean }[] = [
{
value: "http_api",
label: "External HTTP API",
description: "Make HTTP requests to external APIs",
},
{
value: "native",
label: "Native (Coming Soon)",
description: "Built-in tools like call transfer, DTMF input",
disabled: true,
},
{
value: "integration",
label: "Integration (Coming Soon)",
description: "Third-party integrations like Google Calendar",
disabled: true,
},
];
export default function ToolsPage() {
const { user, getAccessToken, redirectToLogin, loading } = useAuth();
const router = useRouter();
const [tools, setTools] = useState<ToolResponse[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [searchQuery, setSearchQuery] = useState("");
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
const [newToolName, setNewToolName] = useState("");
const [newToolDescription, setNewToolDescription] = useState("");
const [newToolCategory, setNewToolCategory] = useState<ToolCategory>("http_api");
const [isCreating, setIsCreating] = useState(false);
const [error, setError] = useState<string | null>(null);
// Redirect if not authenticated
useEffect(() => {
if (!loading && !user) {
redirectToLogin();
}
}, [loading, user, redirectToLogin]);
const fetchTools = useCallback(async () => {
if (loading || !user) return;
try {
setIsLoading(true);
setError(null);
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: {
Authorization: `Bearer ${accessToken}`,
},
});
if (response.data) {
setTools(response.data);
}
} catch (err) {
setError("Failed to fetch tools");
console.error("Error fetching tools:", err);
} finally {
setIsLoading(false);
}
}, [loading, user, getAccessToken]);
useEffect(() => {
fetchTools();
}, [fetchTools]);
const handleCreateTool = async () => {
if (!newToolName.trim()) {
setError("Please enter a name for the tool");
return;
}
try {
setIsCreating(true);
setError(null);
const accessToken = await getAccessToken();
const response = await createToolApiV1ToolsPost({
body: {
name: newToolName,
description: newToolDescription || undefined,
category: newToolCategory,
icon: "globe",
icon_color: "#3B82F6",
definition: {
schema_version: 1,
type: newToolCategory,
config: {
method: "POST",
url: "",
},
},
},
headers: {
Authorization: `Bearer ${accessToken}`,
},
});
if (response.data) {
setIsCreateDialogOpen(false);
setNewToolName("");
setNewToolDescription("");
setNewToolCategory("http_api");
// Navigate to the new tool's detail page
router.push(`/tools/${response.data.tool_uuid}`);
}
} catch (err) {
setError("Failed to create tool");
console.error("Error creating tool:", err);
} finally {
setIsCreating(false);
}
};
const handleDeleteTool = async (toolUuid: string, e: React.MouseEvent) => {
e.stopPropagation();
if (!confirm("Are you sure you want to archive this tool?")) return;
try {
setError(null);
const accessToken = await getAccessToken();
await deleteToolApiV1ToolsToolUuidDelete({
path: {
tool_uuid: toolUuid,
},
headers: {
Authorization: `Bearer ${accessToken}`,
},
});
fetchTools();
} catch (err) {
setError("Failed to delete tool");
console.error("Error deleting tool:", err);
}
};
const filteredTools = tools.filter(
(tool) =>
tool.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
tool.description?.toLowerCase().includes(searchQuery.toLowerCase())
);
const getCategoryBadge = (category: string) => {
switch (category) {
case "http_api":
return <Badge variant="default">HTTP API</Badge>;
case "native":
return <Badge variant="secondary">Native</Badge>;
case "integration":
return <Badge variant="outline">Integration</Badge>;
default:
return <Badge variant="outline">{category}</Badge>;
}
};
const getStatusBadge = (status: string) => {
switch (status) {
case "active":
return <Badge className="bg-green-500">Active</Badge>;
case "draft":
return <Badge variant="secondary">Draft</Badge>;
case "archived":
return <Badge variant="destructive">Archived</Badge>;
default:
return <Badge variant="outline">{status}</Badge>;
}
};
if (loading || !user) {
return (
<div className="min-h-screen bg-background flex items-center justify-center">
<div className="space-y-4">
<Skeleton className="h-12 w-64" />
<Skeleton className="h-64 w-96" />
</div>
</div>
);
}
return (
<div className="min-h-screen bg-background">
<div className="container mx-auto px-4 py-8">
<div className="max-w-6xl mx-auto">
<div className="mb-8">
<h1 className="text-3xl font-bold mb-2">Tools</h1>
<p className="text-muted-foreground">
Manage reusable HTTP API tools that can be used across your workflows
</p>
</div>
{error && (
<div className="mb-4 p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive">
{error}
</div>
)}
<Card className="mb-6">
<CardHeader>
<div className="flex justify-between items-center">
<div>
<CardTitle>Your Tools</CardTitle>
<CardDescription>
Create and manage HTTP API tools for your organization
</CardDescription>
</div>
<Button onClick={() => setIsCreateDialogOpen(true)}>
<Plus className="w-4 h-4 mr-2" />
Create Tool
</Button>
</div>
</CardHeader>
<CardContent>
{/* Search */}
<div className="relative mb-4">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
placeholder="Search tools..."
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
className="pl-10"
/>
</div>
{isLoading ? (
<div className="space-y-4">
{[1, 2, 3].map((i) => (
<div
key={i}
className="flex items-center justify-between p-4 border rounded-lg"
>
<div className="space-y-2">
<Skeleton className="h-4 w-32" />
<Skeleton className="h-3 w-48" />
</div>
<Skeleton className="h-8 w-20" />
</div>
))}
</div>
) : filteredTools.length === 0 ? (
<div className="text-center py-12">
<Globe className="w-12 h-12 text-muted-foreground mx-auto mb-4" />
<p className="text-muted-foreground mb-4">
{searchQuery
? "No tools match your search"
: "No tools found"}
</p>
{!searchQuery && (
<Button onClick={() => setIsCreateDialogOpen(true)}>
Create Your First Tool
</Button>
)}
</div>
) : (
<div className="space-y-4">
{filteredTools.map((tool) => (
<div
key={tool.tool_uuid}
className="flex items-center justify-between p-4 border rounded-lg hover:bg-muted/50 cursor-pointer transition-colors"
onClick={() =>
router.push(`/tools/${tool.tool_uuid}`)
}
>
<div className="flex items-center gap-4">
<div
className="w-10 h-10 rounded-lg flex items-center justify-center"
style={{
backgroundColor:
tool.icon_color || "#3B82F6",
}}
>
<Globe className="w-5 h-5 text-white" />
</div>
<div>
<div className="flex items-center gap-2">
<span className="font-medium">
{tool.name}
</span>
{getCategoryBadge(tool.category)}
{getStatusBadge(tool.status)}
</div>
{tool.description && (
<p className="text-sm text-muted-foreground mt-1">
{tool.description}
</p>
)}
</div>
</div>
<Button
variant="ghost"
size="sm"
onClick={(e) =>
handleDeleteTool(tool.tool_uuid, e)
}
className="text-destructive hover:text-destructive/90"
>
<Trash2 className="w-4 h-4" />
</Button>
</div>
))}
</div>
)}
</CardContent>
</Card>
</div>
</div>
{/* Create Tool Dialog */}
<Dialog open={isCreateDialogOpen} onOpenChange={setIsCreateDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>Create New Tool</DialogTitle>
<DialogDescription>
Create a new tool that can be used in your workflows.
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label>Tool Type</Label>
<Select
value={newToolCategory}
onValueChange={(v) => setNewToolCategory(v as ToolCategory)}
>
<SelectTrigger className="w-full">
<SelectValue />
</SelectTrigger>
<SelectContent>
{TOOL_CATEGORIES.map((category) => (
<SelectItem
key={category.value}
value={category.value}
disabled={category.disabled}
>
{category.label}
</SelectItem>
))}
</SelectContent>
</Select>
<p className="text-xs text-muted-foreground">
{TOOL_CATEGORIES.find(c => c.value === newToolCategory)?.description}
</p>
</div>
<div className="grid gap-2">
<Label htmlFor="name">Tool Name</Label>
<Label className="text-xs text-muted-foreground">
Use a descriptive name, like &quot;Get Weather using API&quot; for a tool that fetches weather
</Label>
<Input
id="name"
value={newToolName}
onChange={(e) => setNewToolName(e.target.value)}
placeholder="e.g., Book Appointment, Check Inventory"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="description">Description (Optional)</Label>
<Label className="text-xs text-muted-foreground">
Provide a description which makes it easy for LLM to understand what this tool does
</Label>
<Input
id="description"
value={newToolDescription}
onChange={(e) => setNewToolDescription(e.target.value)}
placeholder="What does this tool do?"
/>
</div>
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => setIsCreateDialogOpen(false)}
>
Cancel
</Button>
<Button onClick={handleCreateTool} disabled={isCreating}>
{isCreating ? "Creating..." : "Create Tool"}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
);
}

View file

@ -221,7 +221,7 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
</ReactFlow>
{/* Bottom-left controls - horizontal layout with custom buttons */}
<div className="absolute bottom-12 left-8 z-[1000] flex gap-2">
<div className="absolute bottom-12 left-8 z-10 flex gap-2">
<TooltipProvider>
{/* Zoom In */}
<Tooltip>

View file

@ -172,7 +172,7 @@ export default function WorkflowRunPage() {
<svg className="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
</svg>
Back to Agent
Customize Agent
</Button>
</Link>
</CardHeader>

File diff suppressed because one or more lines are too long

View file

@ -144,6 +144,18 @@ export type CreateTestSessionRequest = {
};
};
/**
* Request schema for creating a tool.
*/
export type CreateToolRequest = {
name: string;
description?: string | null;
category?: string;
icon?: string | null;
icon_color?: string | null;
definition: ToolDefinition;
};
export type CreateWorkflowRequest = {
name: string;
workflow_definition: {
@ -174,6 +186,14 @@ export type CreateWorkflowTemplateRequest = {
activity_description: string;
};
/**
* Response schema for the user who created a tool.
*/
export type CreatedByResponse = {
id: number;
provider_id: string;
};
/**
* Response schema for a webhook credential (never includes sensitive data).
*/
@ -309,6 +329,52 @@ export type HttpValidationError = {
detail?: Array<ValidationError>;
};
/**
* Configuration for HTTP API tools.
*/
export type HttpApiConfig = {
/**
* HTTP method (GET, POST, PUT, PATCH, DELETE)
*/
method: string;
/**
* Target URL (supports {{variable}} placeholders)
*/
url: string;
/**
* Static headers to include
*/
headers?: {
[key: string]: string;
} | null;
/**
* Reference to ExternalCredentialModel for auth
*/
credential_uuid?: string | null;
/**
* Request body with {{variable}} placeholders
*/
body_template?: {
[key: string]: unknown;
} | null;
/**
* Request timeout in milliseconds
*/
timeout_ms?: number | null;
/**
* Retry configuration
*/
retry_config?: {
[key: string]: unknown;
} | null;
/**
* JSONPath mappings for response extraction
*/
response_mapping?: {
[key: string]: string;
} | null;
};
/**
* Request payload for superadmin impersonation.
*
@ -500,6 +566,44 @@ export type TestSessionResponse = {
completed_at: string | null;
};
/**
* Tool definition schema.
*/
export type ToolDefinition = {
/**
* Schema version for compatibility
*/
schema_version?: number;
/**
* Tool type (http_api)
*/
type: string;
/**
* Tool configuration
*/
config: HttpApiConfig;
};
/**
* Response schema for a tool.
*/
export type ToolResponse = {
id: number;
tool_uuid: string;
name: string;
description: string | null;
category: string;
icon: string | null;
icon_color: string | null;
status: string;
definition: {
[key: string]: unknown;
};
created_at: string;
updated_at: string | null;
created_by?: CreatedByResponse | null;
};
/**
* Request model for triggering a call via API
*/
@ -566,6 +670,18 @@ export type UpdateIntegrationRequest = {
}>;
};
/**
* Request schema for updating a tool.
*/
export type UpdateToolRequest = {
name?: string | null;
description?: string | null;
icon?: string | null;
icon_color?: string | null;
definition?: ToolDefinition | null;
status?: string | null;
};
export type UpdateWorkflowRequest = {
name: string;
workflow_definition?: {
@ -2347,6 +2463,177 @@ export type UpdateCredentialApiV1CredentialsCredentialUuidPutResponses = {
export type UpdateCredentialApiV1CredentialsCredentialUuidPutResponse = UpdateCredentialApiV1CredentialsCredentialUuidPutResponses[keyof UpdateCredentialApiV1CredentialsCredentialUuidPutResponses];
export type ListToolsApiV1ToolsGetData = {
body?: never;
headers?: {
authorization?: string | null;
};
path?: never;
query?: {
status?: string | null;
category?: string | null;
};
url: '/api/v1/tools/';
};
export type ListToolsApiV1ToolsGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type ListToolsApiV1ToolsGetError = ListToolsApiV1ToolsGetErrors[keyof ListToolsApiV1ToolsGetErrors];
export type ListToolsApiV1ToolsGetResponses = {
/**
* Successful Response
*/
200: Array<ToolResponse>;
};
export type ListToolsApiV1ToolsGetResponse = ListToolsApiV1ToolsGetResponses[keyof ListToolsApiV1ToolsGetResponses];
export type CreateToolApiV1ToolsPostData = {
body: CreateToolRequest;
headers?: {
authorization?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/tools/';
};
export type CreateToolApiV1ToolsPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type CreateToolApiV1ToolsPostError = CreateToolApiV1ToolsPostErrors[keyof CreateToolApiV1ToolsPostErrors];
export type CreateToolApiV1ToolsPostResponses = {
/**
* Successful Response
*/
200: ToolResponse;
};
export type CreateToolApiV1ToolsPostResponse = CreateToolApiV1ToolsPostResponses[keyof CreateToolApiV1ToolsPostResponses];
export type DeleteToolApiV1ToolsToolUuidDeleteData = {
body?: never;
headers?: {
authorization?: string | null;
};
path: {
tool_uuid: string;
};
query?: never;
url: '/api/v1/tools/{tool_uuid}';
};
export type DeleteToolApiV1ToolsToolUuidDeleteErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type DeleteToolApiV1ToolsToolUuidDeleteError = DeleteToolApiV1ToolsToolUuidDeleteErrors[keyof DeleteToolApiV1ToolsToolUuidDeleteErrors];
export type DeleteToolApiV1ToolsToolUuidDeleteResponses = {
/**
* Successful Response
*/
200: {
[key: string]: unknown;
};
};
export type DeleteToolApiV1ToolsToolUuidDeleteResponse = DeleteToolApiV1ToolsToolUuidDeleteResponses[keyof DeleteToolApiV1ToolsToolUuidDeleteResponses];
export type GetToolApiV1ToolsToolUuidGetData = {
body?: never;
headers?: {
authorization?: string | null;
};
path: {
tool_uuid: string;
};
query?: never;
url: '/api/v1/tools/{tool_uuid}';
};
export type GetToolApiV1ToolsToolUuidGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetToolApiV1ToolsToolUuidGetError = GetToolApiV1ToolsToolUuidGetErrors[keyof GetToolApiV1ToolsToolUuidGetErrors];
export type GetToolApiV1ToolsToolUuidGetResponses = {
/**
* Successful Response
*/
200: ToolResponse;
};
export type GetToolApiV1ToolsToolUuidGetResponse = GetToolApiV1ToolsToolUuidGetResponses[keyof GetToolApiV1ToolsToolUuidGetResponses];
export type UpdateToolApiV1ToolsToolUuidPutData = {
body: UpdateToolRequest;
headers?: {
authorization?: string | null;
};
path: {
tool_uuid: string;
};
query?: never;
url: '/api/v1/tools/{tool_uuid}';
};
export type UpdateToolApiV1ToolsToolUuidPutErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type UpdateToolApiV1ToolsToolUuidPutError = UpdateToolApiV1ToolsToolUuidPutErrors[keyof UpdateToolApiV1ToolsToolUuidPutErrors];
export type UpdateToolApiV1ToolsToolUuidPutResponses = {
/**
* Successful Response
*/
200: ToolResponse;
};
export type UpdateToolApiV1ToolsToolUuidPutResponse = UpdateToolApiV1ToolsToolUuidPutResponses[keyof UpdateToolApiV1ToolsToolUuidPutResponses];
export type GetIntegrationsApiV1IntegrationGetData = {
body?: never;
headers?: {

View file

@ -0,0 +1,64 @@
"use client";
import { useCallback, useEffect, useState } from "react";
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
import type { ToolResponse } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
import { useAuth } from "@/lib/auth";
interface ToolBadgesProps {
toolUuids: string[];
}
export function ToolBadges({ toolUuids }: ToolBadgesProps) {
const { getAccessToken } = useAuth();
const [tools, setTools] = useState<ToolResponse[]>([]);
const fetchTools = useCallback(async () => {
try {
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
});
if (response.data) {
setTools(response.data);
}
} catch (error) {
console.error("Failed to fetch tools:", error);
}
}, [getAccessToken]);
useEffect(() => {
if (toolUuids.length > 0) {
fetchTools();
}
}, [toolUuids.length, fetchTools]);
const selectedTools = tools.filter((tool) => toolUuids.includes(tool.tool_uuid));
if (selectedTools.length === 0 && toolUuids.length > 0) {
// Still loading or tools not found
return (
<div className="flex flex-wrap gap-1">
<Badge variant="outline" className="text-xs">
Loading...
</Badge>
</div>
);
}
return (
<div className="flex flex-wrap gap-1">
{selectedTools.map((tool) => (
<Badge
key={tool.tool_uuid}
variant="outline"
className="text-xs"
>
{tool.name}
</Badge>
))}
</div>
);
}

View file

@ -0,0 +1,161 @@
"use client";
import { ExternalLink, Globe, Loader2 } from "lucide-react";
import Link from "next/link";
import { useCallback, useEffect, useState } from "react";
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
import type { ToolResponse } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Label } from "@/components/ui/label";
import { useAuth } from "@/lib/auth";
interface ToolSelectorProps {
value: string[];
onChange: (uuids: string[]) => void;
disabled?: boolean;
label?: string;
description?: string;
showLabel?: boolean;
}
export function ToolSelector({
value,
onChange,
disabled = false,
label = "Tools",
description = "Select tools that the agent can use during the conversation.",
showLabel = true,
}: ToolSelectorProps) {
const { getAccessToken } = useAuth();
const [tools, setTools] = useState<ToolResponse[]>([]);
const [loading, setLoading] = useState(false);
const fetchTools = useCallback(async () => {
setLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listToolsApiV1ToolsGet({
headers: { Authorization: `Bearer ${accessToken}` },
query: { status: "active" },
});
if (response.error) {
console.error("Failed to fetch tools:", response.error);
setTools([]);
return;
}
if (response.data) {
setTools(response.data);
}
} catch (error) {
console.error("Failed to fetch tools:", error);
setTools([]);
} finally {
setLoading(false);
}
}, [getAccessToken]);
useEffect(() => {
fetchTools();
}, [fetchTools]);
const handleToggle = (toolUuid: string, checked: boolean) => {
if (checked) {
onChange([...value, toolUuid]);
} else {
onChange(value.filter((id) => id !== toolUuid));
}
};
return (
<div className="grid gap-2">
{showLabel && (
<>
<Label>{label}</Label>
{description && (
<Label className="text-xs text-muted-foreground">
{description}
</Label>
)}
</>
)}
{loading ? (
<div className="flex items-center gap-2 p-3 border rounded-md">
<Loader2 className="h-4 w-4 animate-spin" />
<span className="text-sm text-muted-foreground">Loading tools...</span>
</div>
) : tools.length === 0 ? (
<div className="p-4 border rounded-md text-center">
<p className="text-sm text-muted-foreground mb-2">
No tools available.
</p>
<Button variant="outline" size="sm" asChild>
<Link href="/tools" target="_blank">
<ExternalLink className="h-4 w-4 mr-2" />
Create a Tool
</Link>
</Button>
</div>
) : (
<div className="border rounded-md divide-y">
{tools.map((tool) => {
const isSelected = value.includes(tool.tool_uuid);
return (
<label
key={tool.tool_uuid}
className={`flex items-center gap-3 p-3 cursor-pointer hover:bg-muted/50 ${
disabled ? "opacity-50 cursor-not-allowed" : ""
}`}
>
<Checkbox
checked={isSelected}
disabled={disabled}
onCheckedChange={(checked) => {
handleToggle(tool.tool_uuid, checked === true);
}}
/>
<div
className="w-6 h-6 rounded flex items-center justify-center shrink-0"
style={{
backgroundColor: tool.icon_color || "#3B82F6",
}}
>
<Globe className="h-3 w-3 text-white" />
</div>
<div className="flex flex-col min-w-0 flex-1">
<span className="text-sm font-medium truncate">
{tool.name}
</span>
{tool.description && (
<span className="text-xs text-muted-foreground truncate">
{tool.description}
</span>
)}
</div>
</label>
);
})}
<div className="p-2 bg-muted/30">
<Link
href="/tools"
target="_blank"
className="flex items-center gap-2 text-sm text-muted-foreground hover:text-foreground"
>
<ExternalLink className="h-4 w-4" />
Manage Tools
</Link>
</div>
</div>
)}
{value.length > 0 && (
<p className="text-xs text-muted-foreground">
{value.length} tool{value.length !== 1 ? "s" : ""} selected
</p>
)}
</div>
);
}

View file

@ -1,9 +1,11 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { Edit, Headset, PlusIcon,Trash2Icon } from "lucide-react";
import { Edit, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
import { memo, useEffect, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import { ExtractionVariable,FlowNodeData } from "@/components/flow/types";
import { ToolBadges } from "@/components/flow/ToolBadges";
import { ToolSelector } from "@/components/flow/ToolSelector";
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
@ -30,6 +32,8 @@ interface AgentNodeEditFormProps {
setVariables: (vars: ExtractionVariable[]) => void;
addGlobalPrompt: boolean;
setAddGlobalPrompt: (value: boolean) => void;
toolUuids: string[];
setToolUuids: (value: string[]) => void;
}
interface AgentNodeProps extends NodeProps {
@ -50,6 +54,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
const [extractionPrompt, setExtractionPrompt] = useState(data.extraction_prompt ?? "");
const [variables, setVariables] = useState<ExtractionVariable[]>(data.extraction_variables ?? []);
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
const handleSave = async () => {
handleSaveNodeData({
@ -61,6 +66,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
extraction_prompt: extractionPrompt,
extraction_variables: variables,
add_global_prompt: addGlobalPrompt,
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
});
setOpen(false);
// Save the workflow after updating node data with a small delay to ensure state is updated
@ -79,6 +85,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setExtractionPrompt(data.extraction_prompt ?? "");
setVariables(data.extraction_variables ?? []);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setToolUuids(data.tool_uuids ?? []);
}
setOpen(newOpen);
};
@ -93,6 +100,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setExtractionPrompt(data.extraction_prompt ?? "");
setVariables(data.extraction_variables ?? []);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setToolUuids(data.tool_uuids ?? []);
}
}, [data, open]);
@ -114,6 +122,15 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
<p className="text-sm text-muted-foreground line-clamp-5 leading-relaxed">
{data.prompt || 'No prompt configured'}
</p>
{data.tool_uuids && data.tool_uuids.length > 0 && (
<div className="mt-3 pt-3 border-t border-border/50">
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
<Wrench className="h-3 w-3" />
<span>Tools:</span>
</div>
<ToolBadges toolUuids={data.tool_uuids} />
</div>
)}
</NodeContent>
<NodeToolbar isVisible={selected} position={Position.Right}>
@ -151,6 +168,8 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setVariables={setVariables}
addGlobalPrompt={addGlobalPrompt}
setAddGlobalPrompt={setAddGlobalPrompt}
toolUuids={toolUuids}
setToolUuids={setToolUuids}
/>
)}
</NodeEditDialog>
@ -173,6 +192,8 @@ const AgentNodeEditForm = ({
setVariables,
addGlobalPrompt,
setAddGlobalPrompt,
toolUuids,
setToolUuids,
}: AgentNodeEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -307,6 +328,15 @@ const AgentNodeEditForm = ({
</Button>
</div>
)}
{/* Tools Section */}
<div className="pt-4 border-t mt-4">
<ToolSelector
value={toolUuids}
onChange={setToolUuids}
description="Select tools that the agent can invoke during this conversation step."
/>
</div>
</div>
);
};

View file

@ -1,8 +1,10 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { Edit, Play } from "lucide-react";
import { Edit, Play, Wrench } from "lucide-react";
import { memo, useEffect, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import { ToolBadges } from "@/components/flow/ToolBadges";
import { ToolSelector } from "@/components/flow/ToolSelector";
import { FlowNodeData } from "@/components/flow/types";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
@ -31,6 +33,8 @@ interface StartCallEditFormProps {
setDelayedStart: (value: boolean) => void;
delayedStartDuration: number;
setDelayedStartDuration: (value: number) => void;
toolUuids: string[];
setToolUuids: (value: string[]) => void;
}
interface StartCallNodeProps extends NodeProps {
@ -52,6 +56,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
const [detectVoicemail, setDetectVoicemail] = useState(data.detect_voicemail ?? false);
const [delayedStart, setDelayedStart] = useState(data.delayed_start ?? false);
const [delayedStartDuration, setDelayedStartDuration] = useState(data.delayed_start_duration ?? 2);
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
const handleSave = async () => {
handleSaveNodeData({
@ -62,7 +67,8 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
add_global_prompt: addGlobalPrompt,
detect_voicemail: detectVoicemail,
delayed_start: delayedStart,
delayed_start_duration: delayedStart ? delayedStartDuration : undefined
delayed_start_duration: delayedStart ? delayedStartDuration : undefined,
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
});
setOpen(false);
// Save the workflow after updating node data with a small delay to ensure state is updated
@ -81,6 +87,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setDetectVoicemail(data.detect_voicemail ?? false);
setDelayedStart(data.delayed_start ?? false);
setDelayedStartDuration(data.delayed_start_duration ?? 3);
setToolUuids(data.tool_uuids ?? []);
}
setOpen(newOpen);
};
@ -95,6 +102,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setDetectVoicemail(data.detect_voicemail ?? false);
setDelayedStart(data.delayed_start ?? false);
setDelayedStartDuration(data.delayed_start_duration ?? 3);
setToolUuids(data.tool_uuids ?? []);
}
}, [data, open]);
@ -115,6 +123,15 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
<p className="text-sm text-muted-foreground line-clamp-5 leading-relaxed">
{data.prompt || 'No prompt configured'}
</p>
{data.tool_uuids && data.tool_uuids.length > 0 && (
<div className="mt-3 pt-3 border-t border-border/50">
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
<Wrench className="h-3 w-3" />
<span>Tools:</span>
</div>
<ToolBadges toolUuids={data.tool_uuids} />
</div>
)}
</NodeContent>
<NodeToolbar isVisible={selected} position={Position.Right}>
@ -147,6 +164,8 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setDelayedStart={setDelayedStart}
delayedStartDuration={delayedStartDuration}
setDelayedStartDuration={setDelayedStartDuration}
toolUuids={toolUuids}
setToolUuids={setToolUuids}
/>
)}
</NodeEditDialog>
@ -168,7 +187,9 @@ const StartCallEditForm = ({
delayedStart,
setDelayedStart,
delayedStartDuration,
setDelayedStartDuration
setDelayedStartDuration,
toolUuids,
setToolUuids,
}: StartCallEditFormProps) => {
return (
<div className="grid gap-2">
@ -258,6 +279,15 @@ const StartCallEditForm = ({
</div>
)}
</div>
{/* Tools Section */}
<div className="pt-4 border-t mt-4">
<ToolSelector
value={toolUuids}
onChange={setToolUuids}
description="Select tools that the agent can invoke during this conversation step."
/>
</div>
</div>
);
};

View file

@ -1,36 +1,22 @@
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
import { AlertCircle, Circle, Edit, Link2, Loader2, PlusIcon, Trash2Icon } from "lucide-react";
import { memo, useCallback, useEffect, useState } from "react";
import { Circle, Edit, Link2, Trash2Icon } from "lucide-react";
import { memo, useEffect, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import {
createCredentialApiV1CredentialsPost,
listCredentialsApiV1CredentialsGet,
} from "@/client";
import { CredentialResponse, WebhookCredentialType } from "@/client/types.gen";
import { FlowNodeData } from "@/components/flow/types";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
CredentialSelector,
type HttpMethod,
HttpMethodSelector,
KeyValueEditor,
type KeyValueItem,
} from "@/components/http";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { JsonEditor, validateJson } from "@/components/ui/json-editor";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Switch } from "@/components/ui/switch";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { useAuth } from "@/lib/auth";
import { NodeContent } from "./common/NodeContent";
import { NodeEditDialog } from "./common/NodeEditDialog";
@ -40,17 +26,9 @@ interface WebhookNodeProps extends NodeProps {
data: FlowNodeData;
}
type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
interface CustomHeader {
key: string;
value: string;
}
export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
const { saveWorkflow } = useWorkflow();
const { getAccessToken } = useAuth();
// Form state
const [name, setName] = useState(data.name || "Webhook");
@ -58,41 +36,13 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
const [httpMethod, setHttpMethod] = useState<HttpMethod>(data.http_method || "POST");
const [endpointUrl, setEndpointUrl] = useState(data.endpoint_url || "");
const [credentialUuid, setCredentialUuid] = useState(data.credential_uuid || "");
const [customHeaders, setCustomHeaders] = useState<CustomHeader[]>(
const [customHeaders, setCustomHeaders] = useState<KeyValueItem[]>(
data.custom_headers || []
);
const [payloadTemplate, setPayloadTemplate] = useState(
data.payload_template ? JSON.stringify(data.payload_template, null, 2) : "{}"
);
// Credentials state
const [credentials, setCredentials] = useState<CredentialResponse[]>([]);
const [credentialsLoading, setCredentialsLoading] = useState(false);
// Fetch credentials when dialog opens
const fetchCredentials = useCallback(async () => {
setCredentialsLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listCredentialsApiV1CredentialsGet({
headers: { Authorization: `Bearer ${accessToken}` },
});
if (response.error) {
console.error("Failed to fetch credentials:", response.error);
setCredentials([]);
return;
}
if (response.data) {
setCredentials(response.data);
}
} catch (error) {
console.error("Failed to fetch credentials:", error);
setCredentials([]);
} finally {
setCredentialsLoading(false);
}
}, [getAccessToken]);
// Validation state - only shown on save attempt
const [jsonError, setJsonError] = useState<string | null>(null);
const [endpointError, setEndpointError] = useState<string | null>(null);
@ -143,8 +93,6 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
// Clear any previous errors
setJsonError(null);
setEndpointError(null);
// Fetch credentials when dialog opens
fetchCredentials();
}
setOpen(newOpen);
};
@ -233,10 +181,6 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
setEndpointUrl={setEndpointUrl}
credentialUuid={credentialUuid}
setCredentialUuid={setCredentialUuid}
credentials={credentials}
credentialsLoading={credentialsLoading}
onRefreshCredentials={fetchCredentials}
getAccessToken={getAccessToken}
customHeaders={customHeaders}
setCustomHeaders={setCustomHeaders}
payloadTemplate={payloadTemplate}
@ -259,16 +203,23 @@ interface WebhookNodeEditFormProps {
setEndpointUrl: (value: string) => void;
credentialUuid: string;
setCredentialUuid: (value: string) => void;
credentials: CredentialResponse[];
credentialsLoading: boolean;
onRefreshCredentials: () => Promise<void>;
getAccessToken: () => Promise<string>;
customHeaders: CustomHeader[];
setCustomHeaders: (value: CustomHeader[]) => void;
customHeaders: KeyValueItem[];
setCustomHeaders: (value: KeyValueItem[]) => void;
payloadTemplate: string;
setPayloadTemplate: (value: string) => void;
}
const availableVariables = [
{ name: "workflow_run_id", description: "Unique ID of the workflow run" },
{ name: "workflow_id", description: "ID of the workflow" },
{ name: "workflow_name", description: "Name of the workflow" },
{ name: "initial_context.*", description: "Initial context variables" },
{ name: "gathered_context.*", description: "Extracted variables" },
{ name: "cost_info.call_duration_seconds", description: "Call duration" },
{ name: "recording_url", description: "Call recording URL" },
{ name: "transcript_url", description: "Transcript URL" },
];
const WebhookNodeEditForm = ({
name,
setName,
@ -280,130 +231,11 @@ const WebhookNodeEditForm = ({
setEndpointUrl,
credentialUuid,
setCredentialUuid,
credentials,
credentialsLoading,
onRefreshCredentials,
getAccessToken,
customHeaders,
setCustomHeaders,
payloadTemplate,
setPayloadTemplate,
}: WebhookNodeEditFormProps) => {
// Add Credential Dialog state
const [isAddCredentialOpen, setIsAddCredentialOpen] = useState(false);
const [newCredName, setNewCredName] = useState("");
const [newCredDescription, setNewCredDescription] = useState("");
const [newCredType, setNewCredType] = useState<WebhookCredentialType>("bearer_token");
const [newCredData, setNewCredData] = useState<Record<string, string>>({});
const [isCreatingCredential, setIsCreatingCredential] = useState(false);
const [credentialError, setCredentialError] = useState<string | null>(null);
const handleCreateCredential = async () => {
if (!newCredName.trim()) return;
setIsCreatingCredential(true);
setCredentialError(null);
try {
const accessToken = await getAccessToken();
const response = await createCredentialApiV1CredentialsPost({
headers: { Authorization: `Bearer ${accessToken}` },
body: {
name: newCredName,
description: newCredDescription || undefined,
credential_type: newCredType,
credential_data: newCredData,
},
});
if (response.error) {
const errorDetail = (response.error as { detail?: string })?.detail
|| "Failed to create credential";
setCredentialError(errorDetail);
return;
}
if (response.data) {
// Refresh credentials list
await onRefreshCredentials();
// Select the newly created credential
setCredentialUuid(response.data.uuid);
// Close dialog and reset form
setIsAddCredentialOpen(false);
setNewCredName("");
setNewCredDescription("");
setNewCredType("bearer_token");
setNewCredData({});
setCredentialError(null);
}
} catch (error) {
console.error("Failed to create credential:", error);
setCredentialError(
error instanceof Error ? error.message : "An unexpected error occurred"
);
} finally {
setIsCreatingCredential(false);
}
};
const handleAddCredentialDialogChange = (open: boolean) => {
setIsAddCredentialOpen(open);
if (!open) {
// Reset error when closing dialog
setCredentialError(null);
}
};
const getCredentialDataFields = (type: WebhookCredentialType) => {
switch (type) {
case "api_key":
return [
{ key: "header_name", label: "Header Name", placeholder: "X-API-Key" },
{ key: "api_key", label: "API Key", placeholder: "your-api-key", isSecret: true },
];
case "bearer_token":
return [
{ key: "token", label: "Token", placeholder: "your-bearer-token", isSecret: true },
];
case "basic_auth":
return [
{ key: "username", label: "Username", placeholder: "username" },
{ key: "password", label: "Password", placeholder: "password", isSecret: true },
];
case "custom_header":
return [
{ key: "header_name", label: "Header Name", placeholder: "X-Custom-Header" },
{ key: "header_value", label: "Header Value", placeholder: "header-value", isSecret: true },
];
default:
return [];
}
};
const addHeader = () => {
setCustomHeaders([...customHeaders, { key: "", value: "" }]);
};
const updateHeader = (index: number, field: "key" | "value", value: string) => {
const newHeaders = [...customHeaders];
newHeaders[index] = { ...newHeaders[index], [field]: value };
setCustomHeaders(newHeaders);
};
const removeHeader = (index: number) => {
setCustomHeaders(customHeaders.filter((_, i) => i !== index));
};
const availableVariables = [
{ name: "workflow_run_id", description: "Unique ID of the workflow run" },
{ name: "workflow_id", description: "ID of the workflow" },
{ name: "workflow_name", description: "Name of the workflow" },
{ name: "initial_context.*", description: "Initial context variables" },
{ name: "gathered_context.*", description: "Extracted variables" },
{ name: "cost_info.call_duration_seconds", description: "Call duration" },
{ name: "recording_url", description: "Call recording URL" },
{ name: "transcript_url", description: "Transcript URL" },
];
return (
<Tabs defaultValue="basic" className="w-full">
<TabsList className="grid w-full grid-cols-4">
@ -432,18 +264,10 @@ const WebhookNodeEditForm = ({
<div className="grid gap-2">
<Label>HTTP Method</Label>
<Select value={httpMethod} onValueChange={(v) => setHttpMethod(v as HttpMethod)}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="GET">GET</SelectItem>
<SelectItem value="POST">POST</SelectItem>
<SelectItem value="PUT">PUT</SelectItem>
<SelectItem value="PATCH">PATCH</SelectItem>
<SelectItem value="DELETE">DELETE</SelectItem>
</SelectContent>
</Select>
<HttpMethodSelector
value={httpMethod}
onChange={setHttpMethod}
/>
</div>
<div className="grid gap-2">
@ -460,154 +284,10 @@ const WebhookNodeEditForm = ({
</TabsContent>
<TabsContent value="auth" className="space-y-4 mt-4">
<div className="grid gap-2">
<Label>Credential</Label>
<Label className="text-xs text-muted-foreground">
Select a credential for authentication, or leave empty for no auth.
</Label>
<div className="flex gap-2">
<Select
value={credentialUuid || "none"}
onValueChange={(v) => setCredentialUuid(v === "none" ? "" : v)}
disabled={credentialsLoading}
>
<SelectTrigger className="flex-1">
{credentialsLoading ? (
<div className="flex items-center gap-2">
<Loader2 className="h-4 w-4 animate-spin" />
<span>Loading...</span>
</div>
) : (
<SelectValue placeholder="No authentication" />
)}
</SelectTrigger>
<SelectContent>
<SelectItem value="none">No authentication</SelectItem>
{credentials.map((cred) => (
<SelectItem key={cred.uuid} value={cred.uuid}>
{cred.name} ({cred.credential_type})
</SelectItem>
))}
</SelectContent>
</Select>
<Button
variant="outline"
size="icon"
onClick={() => setIsAddCredentialOpen(true)}
title="Add new credential"
>
<PlusIcon className="h-4 w-4" />
</Button>
</div>
</div>
{credentials.length === 0 && !credentialsLoading && (
<div className="p-3 border rounded-md bg-muted/20">
<p className="text-sm text-muted-foreground">
No credentials found. Click the + button to create one.
</p>
</div>
)}
{/* Add Credential Dialog */}
<Dialog open={isAddCredentialOpen} onOpenChange={handleAddCredentialDialogChange}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle>Add Credential</DialogTitle>
<DialogDescription>
Create a new credential for webhook authentication.
</DialogDescription>
</DialogHeader>
{/* Error display */}
{credentialError && (
<div className="flex items-start gap-2 p-3 text-sm text-red-600 bg-red-50 border border-red-200 rounded-md">
<AlertCircle className="h-4 w-4 mt-0.5 flex-shrink-0" />
<span>{credentialError}</span>
</div>
)}
<div className="space-y-4 py-4">
<div className="grid gap-2">
<Label htmlFor="cred-name">Name *</Label>
<Input
id="cred-name"
value={newCredName}
onChange={(e) => setNewCredName(e.target.value)}
placeholder="My API Key"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="cred-description">Description</Label>
<Input
id="cred-description"
value={newCredDescription}
onChange={(e) => setNewCredDescription(e.target.value)}
placeholder="Optional description"
/>
</div>
<div className="grid gap-2">
<Label>Credential Type</Label>
<Select
value={newCredType}
onValueChange={(v) => {
setNewCredType(v as WebhookCredentialType);
setNewCredData({});
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="bearer_token">Bearer Token</SelectItem>
<SelectItem value="api_key">API Key</SelectItem>
<SelectItem value="basic_auth">Basic Auth</SelectItem>
<SelectItem value="custom_header">Custom Header</SelectItem>
</SelectContent>
</Select>
</div>
{getCredentialDataFields(newCredType).map((field) => (
<div key={field.key} className="grid gap-2">
<Label htmlFor={`cred-${field.key}`}>{field.label}</Label>
<Input
id={`cred-${field.key}`}
type={field.isSecret ? "password" : "text"}
value={newCredData[field.key] || ""}
onChange={(e) =>
setNewCredData((prev) => ({
...prev,
[field.key]: e.target.value,
}))
}
placeholder={field.placeholder}
/>
</div>
))}
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => setIsAddCredentialOpen(false)}
disabled={isCreatingCredential}
>
Cancel
</Button>
<Button
onClick={handleCreateCredential}
disabled={!newCredName.trim() || isCreatingCredential}
>
{isCreatingCredential ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<CredentialSelector
value={credentialUuid}
onChange={setCredentialUuid}
/>
</TabsContent>
<TabsContent value="headers" className="space-y-4 mt-4">
@ -616,34 +296,13 @@ const WebhookNodeEditForm = ({
<Label className="text-xs text-muted-foreground">
Add custom headers to include in the webhook request.
</Label>
{customHeaders.map((header, index) => (
<div key={index} className="flex items-center gap-2">
<Input
placeholder="Header name"
value={header.key}
onChange={(e) => updateHeader(index, "key", e.target.value)}
className="flex-1"
/>
<Input
placeholder="Header value"
value={header.value}
onChange={(e) => updateHeader(index, "value", e.target.value)}
className="flex-1"
/>
<Button
variant="outline"
size="icon"
onClick={() => removeHeader(index)}
>
<Trash2Icon className="h-4 w-4" />
</Button>
</div>
))}
<Button variant="outline" size="sm" onClick={addHeader} className="w-fit">
<PlusIcon className="h-4 w-4 mr-1" /> Add Header
</Button>
<KeyValueEditor
items={customHeaders}
onChange={setCustomHeaders}
keyPlaceholder="Header name"
valuePlaceholder="Header value"
addButtonText="Add Header"
/>
</div>
</TabsContent>

View file

@ -40,6 +40,8 @@ export type FlowNodeData = {
max_retries: number;
retry_delay_seconds: number;
};
// Tools - array of tool UUIDs that can be invoked by this node
tool_uuids?: string[];
}
export type FlowNode = {

View file

@ -0,0 +1,242 @@
"use client";
import { AlertCircle, Loader2 } from "lucide-react";
import { useState } from "react";
import { createCredentialApiV1CredentialsPost } from "@/client";
import { CredentialResponse, WebhookCredentialType } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useAuth } from "@/lib/auth";
interface CreateCredentialDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
onCreated?: (credential: CredentialResponse) => void;
}
interface CredentialField {
key: string;
label: string;
placeholder: string;
isSecret?: boolean;
}
const getCredentialDataFields = (type: WebhookCredentialType): CredentialField[] => {
switch (type) {
case "api_key":
return [
{ key: "header_name", label: "Header Name", placeholder: "X-API-Key" },
{ key: "api_key", label: "API Key", placeholder: "your-api-key", isSecret: true },
];
case "bearer_token":
return [
{ key: "token", label: "Token", placeholder: "your-bearer-token", isSecret: true },
];
case "basic_auth":
return [
{ key: "username", label: "Username", placeholder: "username" },
{ key: "password", label: "Password", placeholder: "password", isSecret: true },
];
case "custom_header":
return [
{ key: "header_name", label: "Header Name", placeholder: "X-Custom-Header" },
{ key: "header_value", label: "Header Value", placeholder: "header-value", isSecret: true },
];
default:
return [];
}
};
export function CreateCredentialDialog({
open,
onOpenChange,
onCreated,
}: CreateCredentialDialogProps) {
const { getAccessToken } = useAuth();
const [name, setName] = useState("");
const [description, setDescription] = useState("");
const [credentialType, setCredentialType] = useState<WebhookCredentialType>("bearer_token");
const [credentialData, setCredentialData] = useState<Record<string, string>>({});
const [isCreating, setIsCreating] = useState(false);
const [error, setError] = useState<string | null>(null);
const handleCreate = async () => {
if (!name.trim()) return;
setIsCreating(true);
setError(null);
try {
const accessToken = await getAccessToken();
const response = await createCredentialApiV1CredentialsPost({
headers: { Authorization: `Bearer ${accessToken}` },
body: {
name,
description: description || undefined,
credential_type: credentialType,
credential_data: credentialData,
},
});
if (response.error) {
const errorDetail = (response.error as { detail?: string })?.detail
|| "Failed to create credential";
setError(errorDetail);
return;
}
if (response.data) {
onCreated?.(response.data);
handleClose();
}
} catch (err) {
console.error("Failed to create credential:", err);
setError(
err instanceof Error ? err.message : "An unexpected error occurred"
);
} finally {
setIsCreating(false);
}
};
const handleClose = () => {
onOpenChange(false);
// Reset form
setName("");
setDescription("");
setCredentialType("bearer_token");
setCredentialData({});
setError(null);
};
const handleOpenChange = (newOpen: boolean) => {
if (!newOpen) {
setError(null);
}
onOpenChange(newOpen);
};
const fields = getCredentialDataFields(credentialType);
return (
<Dialog open={open} onOpenChange={handleOpenChange}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle>Add Credential</DialogTitle>
<DialogDescription>
Create a new credential for authentication.
</DialogDescription>
</DialogHeader>
{error && (
<div className="flex items-start gap-2 p-3 text-sm text-red-600 bg-red-50 border border-red-200 rounded-md">
<AlertCircle className="h-4 w-4 mt-0.5 flex-shrink-0" />
<span>{error}</span>
</div>
)}
<div className="space-y-4 py-4">
<div className="grid gap-2">
<Label htmlFor="cred-name">Name *</Label>
<Input
id="cred-name"
value={name}
onChange={(e) => setName(e.target.value)}
placeholder="My API Key"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="cred-description">Description</Label>
<Input
id="cred-description"
value={description}
onChange={(e) => setDescription(e.target.value)}
placeholder="Optional description"
/>
</div>
<div className="grid gap-2">
<Label>Credential Type</Label>
<Select
value={credentialType}
onValueChange={(v) => {
setCredentialType(v as WebhookCredentialType);
setCredentialData({});
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="bearer_token">Bearer Token</SelectItem>
<SelectItem value="api_key">API Key</SelectItem>
<SelectItem value="basic_auth">Basic Auth</SelectItem>
<SelectItem value="custom_header">Custom Header</SelectItem>
</SelectContent>
</Select>
</div>
{fields.map((field) => (
<div key={field.key} className="grid gap-2">
<Label htmlFor={`cred-${field.key}`}>{field.label}</Label>
<Input
id={`cred-${field.key}`}
type={field.isSecret ? "password" : "text"}
value={credentialData[field.key] || ""}
onChange={(e) =>
setCredentialData((prev) => ({
...prev,
[field.key]: e.target.value,
}))
}
placeholder={field.placeholder}
/>
</div>
))}
</div>
<DialogFooter>
<Button
variant="outline"
onClick={handleClose}
disabled={isCreating}
>
Cancel
</Button>
<Button
onClick={handleCreate}
disabled={!name.trim() || isCreating}
>
{isCreating ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
);
}

View file

@ -0,0 +1,140 @@
"use client";
import { Loader2, PlusIcon } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { listCredentialsApiV1CredentialsGet } from "@/client";
import { CredentialResponse } from "@/client/types.gen";
import { CreateCredentialDialog } from "@/components/http/create-credential-dialog";
import { Button } from "@/components/ui/button";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useAuth } from "@/lib/auth";
interface CredentialSelectorProps {
value: string;
onChange: (uuid: string) => void;
disabled?: boolean;
placeholder?: string;
label?: string;
description?: string;
showLabel?: boolean;
}
export function CredentialSelector({
value,
onChange,
disabled = false,
placeholder = "No authentication",
label = "Credential",
description = "Select a credential for authentication, or leave empty for no auth.",
showLabel = true,
}: CredentialSelectorProps) {
const { getAccessToken } = useAuth();
const [credentials, setCredentials] = useState<CredentialResponse[]>([]);
const [loading, setLoading] = useState(false);
const [isAddDialogOpen, setIsAddDialogOpen] = useState(false);
const fetchCredentials = useCallback(async () => {
setLoading(true);
try {
const accessToken = await getAccessToken();
const response = await listCredentialsApiV1CredentialsGet({
headers: { Authorization: `Bearer ${accessToken}` },
});
if (response.error) {
console.error("Failed to fetch credentials:", response.error);
setCredentials([]);
return;
}
if (response.data) {
setCredentials(response.data);
}
} catch (error) {
console.error("Failed to fetch credentials:", error);
setCredentials([]);
} finally {
setLoading(false);
}
}, [getAccessToken]);
useEffect(() => {
fetchCredentials();
}, [fetchCredentials]);
const handleCredentialCreated = async (credential: CredentialResponse) => {
await fetchCredentials();
onChange(credential.uuid);
};
return (
<div className="grid gap-2">
{showLabel && (
<>
<Label>{label}</Label>
{description && (
<Label className="text-xs text-muted-foreground">
{description}
</Label>
)}
</>
)}
<div className="flex gap-2">
<Select
value={value || "none"}
onValueChange={(v) => onChange(v === "none" ? "" : v)}
disabled={disabled || loading}
>
<SelectTrigger className="flex-1">
{loading ? (
<div className="flex items-center gap-2">
<Loader2 className="h-4 w-4 animate-spin" />
<span>Loading...</span>
</div>
) : (
<SelectValue placeholder={placeholder} />
)}
</SelectTrigger>
<SelectContent>
<SelectItem value="none">{placeholder}</SelectItem>
{credentials.map((cred) => (
<SelectItem key={cred.uuid} value={cred.uuid}>
{cred.name} ({cred.credential_type})
</SelectItem>
))}
</SelectContent>
</Select>
<Button
variant="outline"
size="icon"
onClick={() => setIsAddDialogOpen(true)}
title="Add new credential"
disabled={disabled}
>
<PlusIcon className="h-4 w-4" />
</Button>
</div>
{credentials.length === 0 && !loading && (
<div className="p-3 border rounded-md bg-muted/20">
<p className="text-sm text-muted-foreground">
No credentials found. Click the + button to create one.
</p>
</div>
)}
<CreateCredentialDialog
open={isAddDialogOpen}
onOpenChange={setIsAddDialogOpen}
onCreated={handleCredentialCreated}
/>
</div>
);
}

View file

@ -0,0 +1,44 @@
"use client";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
export type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
interface HttpMethodSelectorProps {
value: HttpMethod;
onChange: (method: HttpMethod) => void;
disabled?: boolean;
}
const HTTP_METHODS: HttpMethod[] = ["GET", "POST", "PUT", "PATCH", "DELETE"];
export function HttpMethodSelector({
value,
onChange,
disabled = false,
}: HttpMethodSelectorProps) {
return (
<Select
value={value}
onValueChange={(v) => onChange(v as HttpMethod)}
disabled={disabled}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{HTTP_METHODS.map((method) => (
<SelectItem key={method} value={method}>
{method}
</SelectItem>
))}
</SelectContent>
</Select>
);
}

View file

@ -0,0 +1,5 @@
export { CreateCredentialDialog } from "./create-credential-dialog";
export { CredentialSelector } from "./credential-selector";
export { type HttpMethod, HttpMethodSelector } from "./http-method-selector";
export { KeyValueEditor, type KeyValueItem } from "./key-value-editor";
export { ParameterEditor, type ParameterType,type ToolParameter } from "./parameter-editor";

View file

@ -0,0 +1,85 @@
"use client";
import { PlusIcon, Trash2Icon } from "lucide-react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
export interface KeyValueItem {
key: string;
value: string;
}
interface KeyValueEditorProps {
items: KeyValueItem[];
onChange: (items: KeyValueItem[]) => void;
keyPlaceholder?: string;
valuePlaceholder?: string;
addButtonText?: string;
emptyMessage?: string;
disabled?: boolean;
}
export function KeyValueEditor({
items,
onChange,
keyPlaceholder = "Key",
valuePlaceholder = "Value",
addButtonText = "Add",
disabled = false,
}: KeyValueEditorProps) {
const addItem = () => {
onChange([...items, { key: "", value: "" }]);
};
const updateItem = (index: number, field: "key" | "value", value: string) => {
const newItems = [...items];
newItems[index] = { ...newItems[index], [field]: value };
onChange(newItems);
};
const removeItem = (index: number) => {
onChange(items.filter((_, i) => i !== index));
};
return (
<div className="space-y-2">
{items.map((item, index) => (
<div key={index} className="flex items-center gap-2">
<Input
placeholder={keyPlaceholder}
value={item.key}
onChange={(e) => updateItem(index, "key", e.target.value)}
className="flex-1"
disabled={disabled}
/>
<Input
placeholder={valuePlaceholder}
value={item.value}
onChange={(e) => updateItem(index, "value", e.target.value)}
className="flex-1"
disabled={disabled}
/>
<Button
variant="outline"
size="icon"
onClick={() => removeItem(index)}
disabled={disabled}
>
<Trash2Icon className="h-4 w-4" />
</Button>
</div>
))}
<Button
variant="outline"
size="sm"
onClick={addItem}
className="w-fit"
disabled={disabled}
>
<PlusIcon className="h-4 w-4 mr-1" /> {addButtonText}
</Button>
</div>
);
}

View file

@ -0,0 +1,167 @@
"use client";
import { PlusIcon, Trash2Icon } from "lucide-react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Switch } from "@/components/ui/switch";
export type ParameterType = "string" | "number" | "boolean";
export interface ToolParameter {
name: string;
type: ParameterType;
description: string;
required: boolean;
}
interface ParameterEditorProps {
parameters: ToolParameter[];
onChange: (parameters: ToolParameter[]) => void;
disabled?: boolean;
}
export function ParameterEditor({
parameters,
onChange,
disabled = false,
}: ParameterEditorProps) {
const addParameter = () => {
onChange([
...parameters,
{ name: "", type: "string", description: "", required: true },
]);
};
const updateParameter = (
index: number,
field: keyof ToolParameter,
value: string | boolean
) => {
const newParams = [...parameters];
newParams[index] = { ...newParams[index], [field]: value };
onChange(newParams);
};
const removeParameter = (index: number) => {
onChange(parameters.filter((_, i) => i !== index));
};
return (
<div className="space-y-4">
{parameters.length === 0 && (
<div className="text-sm text-muted-foreground py-4 text-center border border-dashed rounded-md">
No parameters defined. Add a parameter to specify what data this tool needs.
</div>
)}
{parameters.map((param, index) => (
<div
key={index}
className="border rounded-lg p-4 space-y-3 bg-muted/20"
>
<div className="flex items-center justify-between">
<span className="text-sm font-medium text-muted-foreground">
Parameter {index + 1}
</span>
<Button
variant="ghost"
size="icon"
onClick={() => removeParameter(index)}
disabled={disabled}
className="h-8 w-8"
>
<Trash2Icon className="h-4 w-4 text-muted-foreground hover:text-destructive" />
</Button>
</div>
<div className="grid grid-cols-2 gap-3">
<div className="space-y-1.5">
<Label className="text-xs">Name</Label>
<Label className="text-xs text-muted-foreground">
Name of the parameter, like &quot;order_id&quot; or &quot;customer_name&quot;
</Label>
<Input
placeholder="e.g., customer_name"
value={param.name}
onChange={(e) =>
updateParameter(index, "name", e.target.value)
}
disabled={disabled}
/>
</div>
<div className="space-y-1.5">
<Label className="text-xs">Type</Label>
<Label className="text-xs text-muted-foreground">
Type of the parameter, like &quot;string&quot; or &quot;number&quot; or &quot;boolean&quot;
</Label>
<Select
value={param.type}
onValueChange={(value: ParameterType) =>
updateParameter(index, "type", value)
}
disabled={disabled}
>
<SelectTrigger>
<SelectValue placeholder="Select type" />
</SelectTrigger>
<SelectContent>
<SelectItem value="string">String</SelectItem>
<SelectItem value="number">Number</SelectItem>
<SelectItem value="boolean">Boolean</SelectItem>
</SelectContent>
</Select>
</div>
</div>
<div className="space-y-1.5">
<Label className="text-xs">Description</Label>
<Label className="text-xs text-muted-foreground">
Description of the parameter, which makes it easy for LLM to understand, like &quot;The ID of the Customer to fetch Order Details&quot;
</Label>
<Input
placeholder="Describe what this parameter is for..."
value={param.description}
onChange={(e) =>
updateParameter(index, "description", e.target.value)
}
disabled={disabled}
/>
</div>
<div className="flex items-center gap-2">
<Switch
id={`required-${index}`}
checked={param.required}
onCheckedChange={(checked) =>
updateParameter(index, "required", checked)
}
disabled={disabled}
/>
<Label htmlFor={`required-${index}`} className="text-sm">
Required
</Label>
</div>
</div>
))}
<Button
variant="outline"
size="sm"
onClick={addParameter}
className="w-fit"
disabled={disabled}
>
<PlusIcon className="h-4 w-4 mr-1" /> Add Parameter
</Button>
</div>
);
}

View file

@ -16,6 +16,7 @@ import {
Star,
TrendingUp,
Workflow,
Wrench,
Zap,
} from "lucide-react";
import Link from "next/link";
@ -108,6 +109,11 @@ export function AppSidebar() {
url: "/telephony-configurations",
icon: Phone,
},
{
title: "Tools",
url: "/tools",
icon: Wrench,
},
// {
// title: "Integrations",
// url: "/integrations",