mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Merge branch 'main' of https://github.com/dograh-hq/dograh
This commit is contained in:
commit
6f34433e00
65 changed files with 5483 additions and 6673 deletions
92
api/alembic/versions/ebc80cea7965_add_tools_model.py
Normal file
92
api/alembic/versions/ebc80cea7965_add_tools_model.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
"""add tools model
|
||||
|
||||
Revision ID: ebc80cea7965
|
||||
Revises: 36b5dbf670e4
|
||||
Create Date: 2026-01-01 10:13:50.807135
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ebc80cea7965"
|
||||
down_revision: Union[str, None] = "36b5dbf670e4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum("active", "archived", "draft", name="tool_status").create(op.get_bind())
|
||||
sa.Enum("http_api", "native", "integration", name="tool_category").create(
|
||||
op.get_bind()
|
||||
)
|
||||
op.create_table(
|
||||
"tools",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("tool_uuid", sa.String(length=36), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"category",
|
||||
postgresql.ENUM(
|
||||
"http_api",
|
||||
"native",
|
||||
"integration",
|
||||
name="tool_category",
|
||||
create_type=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("icon", sa.String(length=50), nullable=True),
|
||||
sa.Column("icon_color", sa.String(length=7), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
postgresql.ENUM(
|
||||
"active", "archived", "draft", name="tool_status", create_type=False
|
||||
),
|
||||
server_default=sa.text("'active'::tool_status"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("definition", sa.JSON(), nullable=False),
|
||||
sa.Column("created_by", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("organization_id", "name", name="unique_org_tool_name"),
|
||||
)
|
||||
op.create_index("ix_tools_category", "tools", ["category"], unique=False)
|
||||
op.create_index(
|
||||
"ix_tools_organization_id", "tools", ["organization_id"], unique=False
|
||||
)
|
||||
op.create_index("ix_tools_status", "tools", ["status"], unique=False)
|
||||
op.create_index(op.f("ix_tools_tool_uuid"), "tools", ["tool_uuid"], unique=True)
|
||||
op.create_index("ix_tools_uuid", "tools", ["tool_uuid"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_tools_uuid", table_name="tools")
|
||||
op.drop_index(op.f("ix_tools_tool_uuid"), table_name="tools")
|
||||
op.drop_index("ix_tools_status", table_name="tools")
|
||||
op.drop_index("ix_tools_organization_id", table_name="tools")
|
||||
op.drop_index("ix_tools_category", table_name="tools")
|
||||
op.drop_table("tools")
|
||||
sa.Enum("http_api", "native", "integration", name="tool_category").drop(
|
||||
op.get_bind()
|
||||
)
|
||||
sa.Enum("active", "archived", "draft", name="tool_status").drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
||||
395
api/conftest.py
395
api/conftest.py
|
|
@ -1,143 +1,315 @@
|
|||
"""
|
||||
Shared pytest fixtures for the API tests.
|
||||
This file contains database setup, test client configuration, and utility fixtures
|
||||
that can be reused across all test files.
|
||||
Pytest configuration and fixtures for async database testing.
|
||||
|
||||
This module sets up the test infrastructure using:
|
||||
- A separate test database (appends _test to the database name)
|
||||
- Alembic migrations run once per test session
|
||||
- Transaction-based isolation for each test (savepoint pattern)
|
||||
|
||||
References:
|
||||
- https://www.core27.co/post/transactional-unit-tests-with-pytest-and-async-sqlalchemy
|
||||
- https://docs.sqlalchemy.org/en/20/orm/session_transaction.html
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
# Load environment variables before importing anything else
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from api.app import app
|
||||
from api.db import db_client
|
||||
# Load .env.test from api directory for test configuration
|
||||
env_path = Path(__file__).parent / ".env.test"
|
||||
load_dotenv(env_path)
|
||||
|
||||
# Test database setup globals
|
||||
TEST_DATABASE_NAME = None
|
||||
TEST_DATABASE_URL = None
|
||||
import pytest
|
||||
from sqlalchemy import event, text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import SessionTransaction
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_database():
|
||||
def get_test_database_url() -> str:
|
||||
"""
|
||||
Set up a temporary PostgreSQL database for testing.
|
||||
This fixture creates a unique test database, runs migrations, and cleans up afterward.
|
||||
Get the test database URL by appending _test to the database name.
|
||||
|
||||
Example:
|
||||
postgresql+asyncpg://user:pass@host/mydb
|
||||
-> postgresql+asyncpg://user:pass@host/mydb_test
|
||||
"""
|
||||
global TEST_DATABASE_NAME, TEST_DATABASE_URL
|
||||
original_url = os.environ.get("DATABASE_URL")
|
||||
if not original_url:
|
||||
raise ValueError("DATABASE_URL environment variable is not set")
|
||||
|
||||
# Generate a unique test database name
|
||||
TEST_DATABASE_NAME = f"test_dograh_{uuid.uuid4().hex[:8]}"
|
||||
parsed = urlparse(original_url)
|
||||
# Append _test to the database name (path without leading slash)
|
||||
original_db_name = parsed.path.lstrip("/")
|
||||
test_db_name = f"{original_db_name}_test"
|
||||
|
||||
# Get the base DATABASE_URL and parse it
|
||||
base_url = os.environ.get("DATABASE_URL")
|
||||
# Extract connection parts and replace database name
|
||||
url_parts = base_url.split("/")
|
||||
base_connection = "/".join(url_parts[:-1])
|
||||
TEST_DATABASE_URL = f"{base_connection}/{TEST_DATABASE_NAME}"
|
||||
# Reconstruct the URL with the new database name
|
||||
test_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
f"/{test_db_name}",
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
return test_url
|
||||
|
||||
# Create a connection to the default postgres database to create our test database
|
||||
default_engine = create_async_engine(base_url)
|
||||
|
||||
def get_base_database_url() -> str:
|
||||
"""
|
||||
Get base database URL (postgres) for creating/dropping test database.
|
||||
"""
|
||||
original_url = os.environ.get("DATABASE_URL")
|
||||
parsed = urlparse(original_url)
|
||||
# Connect to 'postgres' database for admin operations
|
||||
base_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
"/postgres",
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
return base_url
|
||||
|
||||
|
||||
def get_test_db_name() -> str:
|
||||
"""Extract the test database name."""
|
||||
original_url = os.environ.get("DATABASE_URL")
|
||||
parsed = urlparse(original_url)
|
||||
original_db_name = parsed.path.lstrip("/")
|
||||
return f"{original_db_name}_test"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_test_database():
|
||||
"""
|
||||
Session-scoped fixture that creates the test database and runs migrations.
|
||||
|
||||
This runs once at the start of the test session.
|
||||
"""
|
||||
test_db_name = get_test_db_name()
|
||||
base_url = get_base_database_url()
|
||||
test_url = get_test_database_url()
|
||||
|
||||
# Create engine to connect to postgres database (for admin operations)
|
||||
admin_engine = create_async_engine(
|
||||
base_url,
|
||||
poolclass=NullPool,
|
||||
isolation_level="AUTOCOMMIT", # Required for CREATE DATABASE
|
||||
)
|
||||
|
||||
# Create test database if it doesn't exist
|
||||
async with admin_engine.connect() as conn:
|
||||
# Check if database exists
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
|
||||
{"dbname": test_db_name},
|
||||
)
|
||||
exists = result.scalar() is not None
|
||||
|
||||
if not exists:
|
||||
print(f"\n Creating test database: {test_db_name}")
|
||||
# Use template0 to avoid collation version mismatch issues
|
||||
await conn.execute(
|
||||
text(f'CREATE DATABASE "{test_db_name}" TEMPLATE template0')
|
||||
)
|
||||
else:
|
||||
print(f"\n Using existing test database: {test_db_name}")
|
||||
|
||||
await admin_engine.dispose()
|
||||
|
||||
# Run alembic migrations on the test database
|
||||
print(f" Running migrations on {test_db_name}...")
|
||||
await run_migrations(test_url)
|
||||
print(" Migrations complete!")
|
||||
|
||||
yield test_url
|
||||
|
||||
# Cleanup: Optionally drop the test database after tests
|
||||
# Commented out to allow inspection of test data after failures
|
||||
# async with admin_engine.connect() as conn:
|
||||
# await conn.execute(text(f'DROP DATABASE IF EXISTS "{test_db_name}"'))
|
||||
|
||||
|
||||
async def run_migrations(database_url: str):
|
||||
"""
|
||||
Run alembic migrations programmatically on the given database.
|
||||
"""
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
# Get alembic.ini path
|
||||
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||
|
||||
# Create alembic config
|
||||
alembic_cfg = Config(str(alembic_ini_path))
|
||||
|
||||
# Override the database URL - need to patch both os.environ AND api.constants
|
||||
# because api.constants.DATABASE_URL is cached at import time
|
||||
original_env_url = os.environ.get("DATABASE_URL")
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Also patch the cached value in api.constants
|
||||
import api.constants
|
||||
|
||||
original_constants_url = api.constants.DATABASE_URL
|
||||
|
||||
api.constants.DATABASE_URL = database_url
|
||||
|
||||
# Run migrations in a thread to avoid blocking the event loop
|
||||
import asyncio
|
||||
|
||||
def _run_upgrade():
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
try:
|
||||
# Create the test database
|
||||
async with default_engine.connect() as conn:
|
||||
# Use autocommit mode to create database
|
||||
await conn.execute(text("COMMIT"))
|
||||
await conn.execute(text(f"CREATE DATABASE {TEST_DATABASE_NAME}"))
|
||||
|
||||
await default_engine.dispose()
|
||||
|
||||
# Run migrations on the test database
|
||||
env = os.environ.copy()
|
||||
env["DATABASE_URL"] = TEST_DATABASE_URL
|
||||
# Add the parent directory to PYTHONPATH so alembic can find the api module
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
env["PYTHONPATH"] = parent_dir + ":" + env.get("PYTHONPATH", "")
|
||||
|
||||
# Run alembic upgrade to create all tables
|
||||
result = subprocess.run(
|
||||
[
|
||||
"conda",
|
||||
"run",
|
||||
"-n",
|
||||
"dograh",
|
||||
"python",
|
||||
"-m",
|
||||
"alembic",
|
||||
"-c",
|
||||
"alembic.ini",
|
||||
"upgrade",
|
||||
"head",
|
||||
],
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Alembic stderr: {result.stderr}")
|
||||
logger.error(f"Alembic stdout: {result.stdout}")
|
||||
raise RuntimeError(f"Alembic migration failed: {result.stderr}")
|
||||
|
||||
logger.info(f"Created test database: {TEST_DATABASE_NAME}")
|
||||
yield TEST_DATABASE_URL
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(None, _run_upgrade)
|
||||
finally:
|
||||
# Cleanup: Drop the test database
|
||||
cleanup_engine = create_async_engine(base_url)
|
||||
try:
|
||||
async with cleanup_engine.connect() as conn:
|
||||
# Terminate any connections to the test database
|
||||
await conn.execute(text("COMMIT"))
|
||||
await conn.execute(
|
||||
text(f"""
|
||||
SELECT pg_terminate_backend(pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE datname = '{TEST_DATABASE_NAME}' AND pid <> pg_backend_pid()
|
||||
""")
|
||||
)
|
||||
await conn.execute(
|
||||
text(f"DROP DATABASE IF EXISTS {TEST_DATABASE_NAME}")
|
||||
)
|
||||
logger.info(f"Cleaned up test database: {TEST_DATABASE_NAME}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Warning: Could not clean up test database {TEST_DATABASE_NAME}: {e}"
|
||||
)
|
||||
finally:
|
||||
await cleanup_engine.dispose()
|
||||
# Restore original DATABASE_URL
|
||||
if original_env_url:
|
||||
os.environ["DATABASE_URL"] = original_env_url
|
||||
api.constants.DATABASE_URL = original_constants_url
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(test_database):
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_engine(setup_test_database):
|
||||
"""
|
||||
Create a test database client that uses the temporary database.
|
||||
This fixture replaces the global db_client with a test version.
|
||||
Create a test database engine (session-scoped).
|
||||
|
||||
Uses NullPool to avoid connection issues with async tests.
|
||||
"""
|
||||
test_url = setup_test_database
|
||||
engine = create_async_engine(
|
||||
test_url,
|
||||
poolclass=NullPool,
|
||||
echo=False, # Set to True for SQL debugging
|
||||
)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def db_connection(test_engine) -> AsyncGenerator[AsyncConnection, None]:
|
||||
"""
|
||||
Create a database connection for each test.
|
||||
|
||||
This connection wraps all operations in a transaction that
|
||||
will be rolled back at the end of the test.
|
||||
"""
|
||||
async with test_engine.connect() as connection:
|
||||
yield connection
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def async_session(
|
||||
db_connection: AsyncConnection,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Create a database session with transaction isolation for each test.
|
||||
|
||||
This fixture:
|
||||
1. Begins a transaction on the connection
|
||||
2. Creates a savepoint (nested transaction)
|
||||
3. Yields the session for test use
|
||||
4. Rolls back all changes after the test
|
||||
|
||||
Tests can call session.commit() and it will only commit to the savepoint,
|
||||
not to the actual database. The outer transaction rollback ensures
|
||||
complete isolation between tests.
|
||||
"""
|
||||
# Begin outer transaction
|
||||
trans = await db_connection.begin()
|
||||
|
||||
# Create session bound to this connection
|
||||
async_session_maker = async_sessionmaker(
|
||||
bind=db_connection,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Begin a nested transaction (savepoint)
|
||||
nested = await session.begin_nested()
|
||||
|
||||
# Set up event listener to restart savepoint after commits
|
||||
@event.listens_for(session.sync_session, "after_transaction_end")
|
||||
def reopen_nested_transaction(session_sync, transaction: SessionTransaction):
|
||||
nonlocal nested
|
||||
if not nested.is_active:
|
||||
nested = session.sync_session.begin_nested()
|
||||
|
||||
yield session
|
||||
|
||||
# Rollback everything
|
||||
await trans.rollback()
|
||||
|
||||
|
||||
class _TestSessionContext:
|
||||
"""
|
||||
Context manager wrapper for test session.
|
||||
|
||||
Mimics the behavior of async_sessionmaker() context manager
|
||||
but uses the existing test session without closing it.
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
async def __aenter__(self) -> AsyncSession:
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
await self._session.flush()
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def db_session(async_session: AsyncSession):
|
||||
"""
|
||||
Create a DBClient instance that uses the test session.
|
||||
|
||||
This patches the DBClient's async_session to use our test session,
|
||||
ensuring all database operations go through the transactional test session.
|
||||
|
||||
Note: This fixture yields a DBClient (not a raw session) for backward
|
||||
compatibility with existing tests that call db_session.get_or_create_user_by_provider_id(), etc.
|
||||
"""
|
||||
from api.db import db_client
|
||||
|
||||
def test_session_maker():
|
||||
return _TestSessionContext(async_session)
|
||||
|
||||
# Store originals
|
||||
original_engine = db_client.engine
|
||||
original_session = db_client.async_session
|
||||
original_async_session = db_client.async_session
|
||||
|
||||
# Replace the database client's engine and session with test ones
|
||||
test_engine = create_async_engine(test_database)
|
||||
test_session_maker = async_sessionmaker(bind=test_engine)
|
||||
|
||||
db_client.engine = test_engine
|
||||
# Patch the db_client to use our test session
|
||||
db_client.async_session = test_session_maker
|
||||
|
||||
yield db_client
|
||||
|
||||
# Restore original database client
|
||||
await test_engine.dispose()
|
||||
# Restore originals
|
||||
db_client.engine = original_engine
|
||||
db_client.async_session = original_session
|
||||
db_client.async_session = original_async_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def test_client_factory(db_session):
|
||||
"""
|
||||
Factory fixture that creates test clients for specific users.
|
||||
|
|
@ -155,6 +327,9 @@ async def test_client_factory(db_session):
|
|||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from api.app import app
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from api.db.organization_client import OrganizationClient
|
|||
from api.db.organization_configuration_client import OrganizationConfigurationClient
|
||||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
from api.db.reports_client import ReportsClient
|
||||
from api.db.tool_client import ToolClient
|
||||
from api.db.user_client import UserClient
|
||||
from api.db.webhook_credential_client import WebhookCredentialClient
|
||||
from api.db.workflow_client import WorkflowClient
|
||||
|
|
@ -31,6 +32,7 @@ class DBClient(
|
|||
EmbedTokenClient,
|
||||
AgentTriggerClient,
|
||||
WebhookCredentialClient,
|
||||
ToolClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
@ -51,6 +53,7 @@ class DBClient(
|
|||
- EmbedTokenClient: handles embed token and session operations
|
||||
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
|
||||
- WebhookCredentialClient: handles webhook credential operations
|
||||
- ToolClient: handles tool operations for reusable HTTP API tools
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from sqlalchemy.orm import declarative_base, relationship
|
|||
|
||||
from ..enums import (
|
||||
IntegrationAction,
|
||||
ToolCategory,
|
||||
ToolStatus,
|
||||
TriggerState,
|
||||
WebhookCredentialType,
|
||||
WorkflowRunMode,
|
||||
|
|
@ -800,3 +802,85 @@ class ExternalCredentialModel(Base):
|
|||
Index("ix_webhook_credentials_uuid", "credential_uuid"),
|
||||
UniqueConstraint("organization_id", "name", name="unique_org_credential_name"),
|
||||
)
|
||||
|
||||
|
||||
class ToolModel(Base):
|
||||
"""Model for storing reusable tools that can be invoked during workflows.
|
||||
|
||||
Tools provide a standardized way to integrate external functionality - from
|
||||
HTTP API calls to native integrations.
|
||||
"""
|
||||
|
||||
__tablename__ = "tools"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Public identifier (used in APIs and workflow references)
|
||||
tool_uuid = Column(
|
||||
String(36),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Organization scoping
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Tool metadata
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Tool category - uses enum from api/enums.py
|
||||
category = Column(
|
||||
Enum(
|
||||
*[c.value for c in ToolCategory],
|
||||
name="tool_category",
|
||||
),
|
||||
nullable=False,
|
||||
default=ToolCategory.HTTP_API.value,
|
||||
)
|
||||
|
||||
# Icon configuration (for UI display)
|
||||
icon = Column(String(50), nullable=True) # Icon identifier
|
||||
icon_color = Column(String(7), nullable=True) # Hex color code
|
||||
|
||||
# Status management
|
||||
status = Column(
|
||||
Enum(
|
||||
*[s.value for s in ToolStatus],
|
||||
name="tool_status",
|
||||
),
|
||||
nullable=False,
|
||||
default=ToolStatus.ACTIVE.value,
|
||||
server_default=text("'active'::tool_status"),
|
||||
)
|
||||
|
||||
# The tool definition (JSONB) - contains schema_version for compatibility
|
||||
# Structure depends on category:
|
||||
# - http_api: {"schema_version": 1, "type": "http_api", "config": {...}}
|
||||
definition = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Audit fields
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_tools_organization_id", "organization_id"),
|
||||
Index("ix_tools_uuid", "tool_uuid"),
|
||||
Index("ix_tools_status", "status"),
|
||||
Index("ix_tools_category", "category"),
|
||||
UniqueConstraint("organization_id", "name", name="unique_org_tool_name"),
|
||||
)
|
||||
|
|
|
|||
276
api/db/tool_client.py
Normal file
276
api/db/tool_client.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""Database client for managing tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import ToolModel
|
||||
from api.enums import ToolCategory, ToolStatus
|
||||
|
||||
|
||||
class ToolClient(BaseDBClient):
|
||||
"""Client for managing tools (organization-scoped, UUID-referenced)."""
|
||||
|
||||
async def create_tool(
|
||||
self,
|
||||
organization_id: int,
|
||||
user_id: int,
|
||||
name: str,
|
||||
definition: dict,
|
||||
category: str = ToolCategory.HTTP_API.value,
|
||||
description: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_color: Optional[str] = None,
|
||||
) -> ToolModel:
|
||||
"""Create a new tool.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
user_id: ID of the user creating the tool
|
||||
name: Display name for the tool
|
||||
definition: JSON definition of the tool
|
||||
category: Tool category (http_api, native, integration)
|
||||
description: Optional description
|
||||
icon: Optional icon identifier
|
||||
icon_color: Optional hex color code
|
||||
|
||||
Returns:
|
||||
The created ToolModel with auto-generated UUID
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
tool = ToolModel(
|
||||
organization_id=organization_id,
|
||||
created_by=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
category=category,
|
||||
icon=icon,
|
||||
icon_color=icon_color,
|
||||
definition=definition,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
)
|
||||
|
||||
session.add(tool)
|
||||
await session.commit()
|
||||
await session.refresh(tool)
|
||||
|
||||
logger.info(
|
||||
f"Created tool '{name}' ({tool.tool_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return tool
|
||||
|
||||
async def get_tools_for_organization(
|
||||
self,
|
||||
organization_id: int,
|
||||
status: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
) -> List[ToolModel]:
|
||||
"""Get all tools for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
status: Optional filter by status (active, archived, draft)
|
||||
category: Optional filter by category (http_api, native, integration)
|
||||
|
||||
Returns:
|
||||
List of ToolModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(ToolModel).where(
|
||||
ToolModel.organization_id == organization_id
|
||||
)
|
||||
|
||||
if status:
|
||||
query = query.where(ToolModel.status == status)
|
||||
else:
|
||||
# By default, exclude archived tools
|
||||
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
|
||||
|
||||
if category:
|
||||
query = query.where(ToolModel.category == category)
|
||||
|
||||
query = query.order_by(ToolModel.name)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_tool_by_uuid(
|
||||
self,
|
||||
tool_uuid: str,
|
||||
organization_id: int,
|
||||
include_archived: bool = False,
|
||||
) -> Optional[ToolModel]:
|
||||
"""Get a tool by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
include_archived: If True, include archived tools
|
||||
|
||||
Returns:
|
||||
ToolModel if found and authorized, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
)
|
||||
.options(selectinload(ToolModel.created_by_user))
|
||||
)
|
||||
|
||||
if not include_archived:
|
||||
query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_tool(
|
||||
self,
|
||||
tool_uuid: str,
|
||||
organization_id: int,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
definition: Optional[dict] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_color: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
) -> Optional[ToolModel]:
|
||||
"""Update a tool by UUID.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
name: New name (if provided)
|
||||
description: New description (if provided)
|
||||
definition: New definition (if provided)
|
||||
icon: New icon (if provided)
|
||||
icon_color: New icon color (if provided)
|
||||
status: New status (if provided)
|
||||
|
||||
Returns:
|
||||
Updated ToolModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# First check if tool exists and belongs to organization
|
||||
tool = await self.get_tool_by_uuid(
|
||||
tool_uuid, organization_id, include_archived=True
|
||||
)
|
||||
if not tool:
|
||||
return None
|
||||
|
||||
# Build update values
|
||||
update_values = {"updated_at": datetime.now(UTC)}
|
||||
if name is not None:
|
||||
update_values["name"] = name
|
||||
if description is not None:
|
||||
update_values["description"] = description
|
||||
if definition is not None:
|
||||
update_values["definition"] = definition
|
||||
if icon is not None:
|
||||
update_values["icon"] = icon
|
||||
if icon_color is not None:
|
||||
update_values["icon_color"] = icon_color
|
||||
if status is not None:
|
||||
update_values["status"] = status
|
||||
|
||||
await session.execute(
|
||||
update(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
)
|
||||
.values(**update_values)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Fetch updated tool
|
||||
result = await session.execute(
|
||||
select(ToolModel)
|
||||
.where(ToolModel.tool_uuid == tool_uuid)
|
||||
.options(selectinload(ToolModel.created_by_user))
|
||||
)
|
||||
updated_tool = result.scalar_one()
|
||||
|
||||
logger.info(f"Updated tool {tool_uuid} for organization {organization_id}")
|
||||
return updated_tool
|
||||
|
||||
async def archive_tool(self, tool_uuid: str, organization_id: int) -> bool:
|
||||
"""Soft delete a tool by setting its status to archived.
|
||||
|
||||
Args:
|
||||
tool_uuid: The unique tool UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if tool was archived, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
update(ToolModel)
|
||||
.where(
|
||||
ToolModel.tool_uuid == tool_uuid,
|
||||
ToolModel.organization_id == organization_id,
|
||||
ToolModel.status != ToolStatus.ARCHIVED.value,
|
||||
)
|
||||
.values(
|
||||
status=ToolStatus.ARCHIVED.value,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.info(
|
||||
f"Archived tool {tool_uuid} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def validate_tool_uuid(self, tool_uuid: str, organization_id: int) -> bool:
|
||||
"""Check if a tool UUID exists and belongs to the organization.
|
||||
|
||||
This is useful for workflow validation to ensure referenced tools exist.
|
||||
|
||||
Args:
|
||||
tool_uuid: The tool UUID to validate
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
tool = await self.get_tool_by_uuid(tool_uuid, organization_id)
|
||||
return tool is not None
|
||||
|
||||
async def get_tools_by_uuids(
|
||||
self,
|
||||
tool_uuids: List[str],
|
||||
organization_id: int,
|
||||
) -> List[ToolModel]:
|
||||
"""Get multiple tools by their UUIDs.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
List of ToolModel instances (only active tools)
|
||||
"""
|
||||
if not tool_uuids:
|
||||
return []
|
||||
|
||||
async with self.async_session() as session:
|
||||
query = select(ToolModel).where(
|
||||
ToolModel.tool_uuid.in_(tool_uuids),
|
||||
ToolModel.organization_id == organization_id,
|
||||
ToolModel.status == ToolStatus.ACTIVE.value,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
18
api/enums.py
18
api/enums.py
|
|
@ -109,3 +109,21 @@ class WebhookCredentialType(Enum):
|
|||
BEARER_TOKEN = "bearer_token" # Bearer token auth
|
||||
BASIC_AUTH = "basic_auth" # Username/password
|
||||
CUSTOM_HEADER = "custom_header" # Custom header key-value
|
||||
|
||||
|
||||
class ToolCategory(Enum):
|
||||
"""Tool category types"""
|
||||
|
||||
HTTP_API = "http_api" # Custom HTTP API calls (implemented)
|
||||
NATIVE = (
|
||||
"native" # Built-in integrations (future: call_transfer, dtmf_input, end_call)
|
||||
)
|
||||
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
|
||||
|
||||
|
||||
class ToolStatus(Enum):
|
||||
"""Tool status values"""
|
||||
|
||||
ACTIVE = "active" # Tool is available for use
|
||||
ARCHIVED = "archived" # Tool is soft-deleted
|
||||
DRAFT = "draft" # Tool is being configured (not ready for use)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
[pytest]
|
||||
asyncio_mode = strict
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
testpaths = .
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = session
|
||||
asyncio_default_test_loop_scope = session
|
||||
testpaths = tests
|
||||
python_files = test_*.py *_test.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
|
|
|||
|
|
@ -3,4 +3,5 @@ ruff==0.11.3
|
|||
pytest==8.3.5
|
||||
pytest-asyncio==0.26.0
|
||||
pre-commit==4.2.0
|
||||
watchfiles==1.1.0
|
||||
watchfiles==1.1.0
|
||||
python-dotenv==1.2.1
|
||||
|
|
@ -15,6 +15,7 @@ from api.routes.s3_signed_url import router as s3_router
|
|||
from api.routes.service_keys import router as service_keys_router
|
||||
from api.routes.superuser import router as superuser_router
|
||||
from api.routes.telephony import router as telephony_router
|
||||
from api.routes.tool import router as tool_router
|
||||
from api.routes.user import router as user_router
|
||||
from api.routes.webrtc_signaling import router as webrtc_signaling_router
|
||||
from api.routes.workflow import router as workflow_router
|
||||
|
|
@ -32,6 +33,7 @@ router.include_router(workflow_router)
|
|||
router.include_router(user_router)
|
||||
router.include_router(campaign_router)
|
||||
router.include_router(credentials_router)
|
||||
router.include_router(tool_router)
|
||||
router.include_router(integration_router)
|
||||
router.include_router(organization_router)
|
||||
router.include_router(s3_router)
|
||||
|
|
|
|||
336
api/routes/tool.py
Normal file
336
api/routes/tool.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
"""API routes for managing tools."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import ToolCategory, ToolStatus
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
router = APIRouter(prefix="/tools")
|
||||
|
||||
|
||||
# Request/Response schemas
|
||||
class ToolParameter(BaseModel):
|
||||
"""A parameter that the tool accepts."""
|
||||
|
||||
name: str = Field(description="Parameter name (used as key in request body)")
|
||||
type: str = Field(description="Parameter type: string, number, or boolean")
|
||||
description: str = Field(description="Description of what this parameter is for")
|
||||
required: bool = Field(
|
||||
default=True, description="Whether this parameter is required"
|
||||
)
|
||||
|
||||
|
||||
class HttpApiConfig(BaseModel):
|
||||
"""Configuration for HTTP API tools."""
|
||||
|
||||
method: str = Field(description="HTTP method (GET, POST, PUT, PATCH, DELETE)")
|
||||
url: str = Field(description="Target URL")
|
||||
headers: Optional[Dict[str, str]] = Field(
|
||||
default=None, description="Static headers to include"
|
||||
)
|
||||
credential_uuid: Optional[str] = Field(
|
||||
default=None, description="Reference to ExternalCredentialModel for auth"
|
||||
)
|
||||
parameters: Optional[List[ToolParameter]] = Field(
|
||||
default=None, description="Parameters that the tool accepts from LLM"
|
||||
)
|
||||
timeout_ms: Optional[int] = Field(
|
||||
default=5000, description="Request timeout in milliseconds"
|
||||
)
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
"""Tool definition schema."""
|
||||
|
||||
schema_version: int = Field(
|
||||
default=1, description="Schema version for compatibility"
|
||||
)
|
||||
type: str = Field(description="Tool type (http_api)")
|
||||
config: HttpApiConfig = Field(description="Tool configuration")
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
"""Request schema for creating a tool."""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
description: Optional[str] = None
|
||||
category: str = Field(default=ToolCategory.HTTP_API.value)
|
||||
icon: Optional[str] = Field(default="globe", max_length=50)
|
||||
icon_color: Optional[str] = Field(default="#3B82F6", max_length=7)
|
||||
definition: ToolDefinition
|
||||
|
||||
|
||||
class UpdateToolRequest(BaseModel):
|
||||
"""Request schema for updating a tool."""
|
||||
|
||||
name: Optional[str] = Field(default=None, max_length=255)
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = Field(default=None, max_length=50)
|
||||
icon_color: Optional[str] = Field(default=None, max_length=7)
|
||||
definition: Optional[ToolDefinition] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class CreatedByResponse(BaseModel):
|
||||
"""Response schema for the user who created a tool."""
|
||||
|
||||
id: int
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
"""Response schema for a tool."""
|
||||
|
||||
id: int
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
category: str
|
||||
icon: Optional[str]
|
||||
icon_color: Optional[str]
|
||||
status: str
|
||||
definition: Dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
created_by: Optional[CreatedByResponse] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse:
|
||||
"""Build a response from a tool model."""
|
||||
created_by = None
|
||||
if include_created_by and tool.created_by_user:
|
||||
created_by = CreatedByResponse(
|
||||
id=tool.created_by_user.id,
|
||||
provider_id=tool.created_by_user.provider_id,
|
||||
)
|
||||
|
||||
return ToolResponse(
|
||||
id=tool.id,
|
||||
tool_uuid=tool.tool_uuid,
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
category=tool.category,
|
||||
icon=tool.icon,
|
||||
icon_color=tool.icon_color,
|
||||
status=tool.status,
|
||||
definition=tool.definition,
|
||||
created_at=tool.created_at,
|
||||
updated_at=tool.updated_at,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
|
||||
def validate_category(category: str) -> None:
|
||||
"""Validate that the category is valid."""
|
||||
valid_categories = [c.value for c in ToolCategory]
|
||||
if category not in valid_categories:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid category '{category}'. Must be one of: {', '.join(valid_categories)}",
|
||||
)
|
||||
|
||||
|
||||
def validate_status(status: str) -> None:
|
||||
"""Validate that the status is valid."""
|
||||
valid_statuses = [s.value for s in ToolStatus]
|
||||
if status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status '{status}'. Must be one of: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_tools(
|
||||
status: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> List[ToolResponse]:
|
||||
"""
|
||||
List all tools for the user's organization.
|
||||
|
||||
Args:
|
||||
status: Optional filter by status (active, archived, draft)
|
||||
category: Optional filter by category (http_api, native, integration)
|
||||
|
||||
Returns:
|
||||
List of tools
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
if status:
|
||||
validate_status(status)
|
||||
if category:
|
||||
validate_category(category)
|
||||
|
||||
tools = await db_client.get_tools_for_organization(
|
||||
user.selected_organization_id,
|
||||
status=status,
|
||||
category=category,
|
||||
)
|
||||
|
||||
return [build_tool_response(tool) for tool in tools]
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_tool(
|
||||
request: CreateToolRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> ToolResponse:
|
||||
"""
|
||||
Create a new tool.
|
||||
|
||||
Args:
|
||||
request: The tool creation request
|
||||
|
||||
Returns:
|
||||
The created tool
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
validate_category(request.category)
|
||||
|
||||
try:
|
||||
tool = await db_client.create_tool(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
definition=request.definition.model_dump(),
|
||||
category=request.category,
|
||||
description=request.description,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
)
|
||||
|
||||
return build_tool_response(tool)
|
||||
|
||||
except Exception as e:
|
||||
if "unique_org_tool_name" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A tool with the name '{request.name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{tool_uuid}")
|
||||
async def get_tool(
|
||||
tool_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> ToolResponse:
|
||||
"""
|
||||
Get a specific tool by UUID.
|
||||
|
||||
Args:
|
||||
tool_uuid: The UUID of the tool
|
||||
|
||||
Returns:
|
||||
The tool
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
tool = await db_client.get_tool_by_uuid(
|
||||
tool_uuid, user.selected_organization_id, include_archived=True
|
||||
)
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
|
||||
@router.put("/{tool_uuid}")
|
||||
async def update_tool(
|
||||
tool_uuid: str,
|
||||
request: UpdateToolRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> ToolResponse:
|
||||
"""
|
||||
Update a tool.
|
||||
|
||||
Args:
|
||||
tool_uuid: The UUID of the tool to update
|
||||
request: The update request
|
||||
|
||||
Returns:
|
||||
The updated tool
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
if request.status:
|
||||
validate_status(request.status)
|
||||
|
||||
try:
|
||||
tool = await db_client.update_tool(
|
||||
tool_uuid=tool_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
definition=request.definition.model_dump() if request.definition else None,
|
||||
icon=request.icon,
|
||||
icon_color=request.icon_color,
|
||||
status=request.status,
|
||||
)
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
return build_tool_response(tool, include_created_by=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if "unique_org_tool_name" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A tool with the name '{request.name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{tool_uuid}")
|
||||
async def delete_tool(
|
||||
tool_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Archive (soft delete) a tool.
|
||||
|
||||
Args:
|
||||
tool_uuid: The UUID of the tool to delete
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
deleted = await db_client.archive_tool(tool_uuid, user.selected_organization_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
return {"status": "archived", "tool_uuid": tool_uuid}
|
||||
|
|
@ -57,6 +57,7 @@ class NodeDataDTO(BaseModel):
|
|||
detect_voicemail: bool = False
|
||||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
tool_uuids: Optional[List[str]] = None
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
render_template,
|
||||
|
|
@ -105,6 +106,16 @@ class PipecatEngine:
|
|||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
return await self._custom_tool_manager.get_organization_id()
|
||||
# Fallback for when manager is not yet initialized
|
||||
return await get_organization_id_from_workflow_run(self._workflow_run_id)
|
||||
|
||||
@property
|
||||
def builtin_function_schemas(self) -> list[dict]:
|
||||
"""Get built-in function schemas (calculator and timezone tools)."""
|
||||
|
|
@ -146,6 +157,9 @@ class PipecatEngine:
|
|||
# Helper that encapsulates variable extraction logic
|
||||
self._variable_extraction_manager = VariableExtractionManager(self)
|
||||
|
||||
# Helper that encapsulates custom tool management
|
||||
self._custom_tool_manager = CustomToolManager(self)
|
||||
|
||||
# Add current time in EST (America/New_York) to gathered context
|
||||
try:
|
||||
est_time_result = get_current_time("America/New_York")
|
||||
|
|
@ -360,6 +374,10 @@ class PipecatEngine:
|
|||
outgoing_edge.get_function_name(), outgoing_edge.target
|
||||
)
|
||||
|
||||
# Register custom tool handlers for this node
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
await self._custom_tool_manager.register_handlers(node.tool_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
|
|
@ -492,9 +510,7 @@ class PipecatEngine:
|
|||
# Apply disposition mapping - first try call_disposition if it is,
|
||||
# extracted from the call conversation then fall back to reason
|
||||
call_disposition = self._gathered_context.get("call_disposition", "")
|
||||
organization_id = await get_organization_id_from_workflow_run(
|
||||
self._workflow_run_id
|
||||
)
|
||||
organization_id = await self._get_organization_id()
|
||||
|
||||
# If client is disconnected before we get a chance to disconnect from
|
||||
# the bot, lets consider that as final disposition
|
||||
|
|
@ -618,6 +634,13 @@ class PipecatEngine:
|
|||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Add custom tools from node.tool_uuids
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
# Transition functions (schema only; registration handled elsewhere)
|
||||
for outgoing_edge in node.out_edges:
|
||||
function_schema = self._get_function_schema(
|
||||
|
|
|
|||
189
api/services/workflow/pipecat_engine_custom_tools.py
Normal file
189
api/services/workflow/pipecat_engine_custom_tools.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""Custom tool management for PipecatEngine.
|
||||
|
||||
This module handles fetching, registering, and executing user-defined tools
|
||||
during workflow execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_utils import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import FunctionCallResultProperties
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
class CustomToolManager:
|
||||
"""Manager for custom tool registration and execution.
|
||||
|
||||
This class handles:
|
||||
1. Fetching tools from the database based on tool UUIDs
|
||||
2. Converting tools to LLM function schemas
|
||||
3. Registering tool execution handlers with the LLM
|
||||
4. Executing HTTP API tools when invoked by the LLM
|
||||
"""
|
||||
|
||||
def __init__(self, engine: "PipecatEngine") -> None:
|
||||
self._engine = engine
|
||||
self._organization_id: Optional[int] = None
|
||||
# Cache: maps function_name -> (tool, schema)
|
||||
self._tools_cache: dict[str, tuple[Any, dict]] = {}
|
||||
|
||||
async def get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._organization_id is None:
|
||||
self._organization_id = await get_organization_id_from_workflow_run(
|
||||
self._engine._workflow_run_id
|
||||
)
|
||||
return self._organization_id
|
||||
|
||||
async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]:
|
||||
"""Fetch custom tools and convert them to function schemas.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to fetch
|
||||
|
||||
Returns:
|
||||
List of FunctionSchema objects for LLM
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning("Cannot fetch custom tools: organization_id not available")
|
||||
return []
|
||||
|
||||
try:
|
||||
tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id)
|
||||
|
||||
schemas: list[FunctionSchema] = []
|
||||
for tool in tools:
|
||||
raw_schema = tool_to_function_schema(tool)
|
||||
function_name = raw_schema["function"]["name"]
|
||||
|
||||
# Cache the tool for later execution
|
||||
self._tools_cache[function_name] = (tool, raw_schema)
|
||||
|
||||
# Convert to FunctionSchema object for compatibility with update_llm_context
|
||||
func_schema = get_function_schema(
|
||||
function_name,
|
||||
raw_schema["function"]["description"],
|
||||
properties=raw_schema["function"]["parameters"].get(
|
||||
"properties", {}
|
||||
),
|
||||
required=raw_schema["function"]["parameters"].get("required", []),
|
||||
)
|
||||
schemas.append(func_schema)
|
||||
|
||||
logger.debug(
|
||||
f"Loaded {len(schemas)} custom tools for node: "
|
||||
f"{[s.name for s in schemas]}"
|
||||
)
|
||||
return schemas
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch custom tools: {e}")
|
||||
return []
|
||||
|
||||
async def register_handlers(self, tool_uuids: list[str]) -> None:
|
||||
"""Register custom tool execution handlers with the LLM.
|
||||
|
||||
Args:
|
||||
tool_uuids: List of tool UUIDs to register handlers for
|
||||
"""
|
||||
organization_id = await self.get_organization_id()
|
||||
if not organization_id:
|
||||
logger.warning(
|
||||
"Cannot register custom tool handlers: organization_id not available"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id)
|
||||
|
||||
for tool in tools:
|
||||
schema = tool_to_function_schema(tool)
|
||||
function_name = schema["function"]["name"]
|
||||
|
||||
# Cache the tool for potential later use
|
||||
self._tools_cache[function_name] = (tool, schema)
|
||||
|
||||
# Create and register the handler
|
||||
handler = self._create_handler(tool, function_name)
|
||||
self._engine.llm.register_function(function_name, handler)
|
||||
|
||||
logger.debug(
|
||||
f"Registered custom tool handler: {function_name} "
|
||||
f"(tool_uuid: {tool.tool_uuid})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register custom tool handlers: {e}")
|
||||
|
||||
def _create_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for a custom tool.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Async handler function for the tool
|
||||
"""
|
||||
# Run LLM after tool execution to continue conversation
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
|
||||
async def custom_tool_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: {function_name}")
|
||||
logger.info(f"Arguments: {function_call_params.arguments}")
|
||||
|
||||
try:
|
||||
# Execute the HTTP API tool
|
||||
result = await execute_http_tool(
|
||||
tool=tool,
|
||||
arguments=function_call_params.arguments,
|
||||
call_context_vars=self._engine._call_context_vars,
|
||||
organization_id=self._organization_id,
|
||||
)
|
||||
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom tool '{function_name}' execution failed: {e}")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)},
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
return custom_tool_handler
|
||||
|
||||
def get_cached_tool(self, function_name: str) -> Optional[tuple[Any, dict]]:
|
||||
"""Get a cached tool by its function name.
|
||||
|
||||
Args:
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Tuple of (tool, schema) if found, None otherwise
|
||||
"""
|
||||
return self._tools_cache.get(function_name)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the tools cache."""
|
||||
self._tools_cache.clear()
|
||||
180
api/services/workflow/tools/custom_tool.py
Normal file
180
api/services/workflow/tools/custom_tool.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""Custom tool execution for user-defined HTTP API tools."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
|
||||
# Map tool parameter types to JSON schema types
|
||||
TYPE_MAP = {
|
||||
"string": "string",
|
||||
"number": "number",
|
||||
"boolean": "boolean",
|
||||
}
|
||||
|
||||
|
||||
def tool_to_function_schema(tool: Any) -> Dict[str, Any]:
|
||||
"""Convert a ToolModel to an LLM function schema.
|
||||
|
||||
Args:
|
||||
tool: ToolModel instance with name, description, and definition
|
||||
|
||||
Returns:
|
||||
Function schema dict compatible with OpenAI/Anthropic function calling
|
||||
"""
|
||||
definition = tool.definition or {}
|
||||
config = definition.get("config", {})
|
||||
parameters = config.get("parameters", []) or []
|
||||
|
||||
# Build properties and required list from parameters
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type = param.get("type", "string")
|
||||
param_desc = param.get("description", "")
|
||||
param_required = param.get("required", True)
|
||||
|
||||
if not param_name:
|
||||
continue
|
||||
|
||||
properties[param_name] = {
|
||||
"type": TYPE_MAP.get(param_type, "string"),
|
||||
"description": param_desc,
|
||||
}
|
||||
|
||||
if param_required:
|
||||
required.append(param_name)
|
||||
|
||||
# Sanitize tool name for function name (lowercase, underscores only)
|
||||
function_name = re.sub(r"[^a-z0-9_]", "_", tool.name.lower())
|
||||
# Remove consecutive underscores and trim
|
||||
function_name = re.sub(r"_+", "_", function_name).strip("_")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"description": tool.description or f"Execute {tool.name} tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
"_tool_uuid": tool.tool_uuid,
|
||||
}
|
||||
|
||||
|
||||
async def execute_http_tool(
|
||||
tool: Any,
|
||||
arguments: Dict[str, Any],
|
||||
call_context_vars: Optional[Dict[str, Any]] = None,
|
||||
organization_id: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute an HTTP API tool.
|
||||
|
||||
Args:
|
||||
tool: ToolModel instance
|
||||
arguments: Arguments passed by the LLM (parameter name -> value)
|
||||
call_context_vars: Additional context variables from the call (unused for now)
|
||||
organization_id: Organization ID for credential lookup
|
||||
|
||||
Returns:
|
||||
Result dict with response data or error
|
||||
"""
|
||||
definition = tool.definition or {}
|
||||
config = definition.get("config", {})
|
||||
|
||||
# Get HTTP method and URL
|
||||
method = config.get("method", "POST").upper()
|
||||
url = config.get("url", "")
|
||||
|
||||
# Get headers from config
|
||||
headers = dict(config.get("headers", {}) or {})
|
||||
|
||||
# Add auth header if credential is configured
|
||||
credential_uuid = config.get("credential_uuid")
|
||||
if credential_uuid and organization_id:
|
||||
try:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
credential_uuid, organization_id
|
||||
)
|
||||
if credential:
|
||||
auth_header = build_auth_header(credential)
|
||||
headers.update(auth_header)
|
||||
logger.debug(f"Applied credential '{credential.name}' to tool request")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Credential {credential_uuid} not found for tool '{tool.name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch credential for tool '{tool.name}': {e}")
|
||||
|
||||
# Get timeout
|
||||
timeout_ms = config.get("timeout_ms", 5000)
|
||||
timeout_seconds = timeout_ms / 1000
|
||||
|
||||
# Build request: JSON body for POST/PUT/PATCH, query params for GET/DELETE
|
||||
body = None
|
||||
params = None
|
||||
if method in ("POST", "PUT", "PATCH"):
|
||||
body = arguments
|
||||
elif method in ("GET", "DELETE") and arguments:
|
||||
params = arguments
|
||||
|
||||
logger.info(
|
||||
f"Executing custom tool '{tool.name}' ({tool.tool_uuid}): {method} {url}"
|
||||
)
|
||||
logger.debug(f"Request body: {body}, params: {params}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout_seconds) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
json=body,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Try to parse JSON response
|
||||
try:
|
||||
response_data = response.json()
|
||||
except Exception:
|
||||
response_data = {"raw_response": response.text}
|
||||
|
||||
result = {
|
||||
"status": "success",
|
||||
"status_code": response.status_code,
|
||||
"data": response_data,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Custom tool '{tool.name}' completed with status {response.status_code}"
|
||||
)
|
||||
return result
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error(f"Custom tool '{tool.name}' timed out after {timeout_seconds}s")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Request timed out after {timeout_seconds} seconds",
|
||||
}
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Custom tool '{tool.name}' request failed: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Request failed: {str(e)}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Custom tool '{tool.name}' execution failed: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Tool execution failed: {str(e)}",
|
||||
}
|
||||
|
|
@ -47,6 +47,7 @@ class Node:
|
|||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
self.tool_uuids = data.tool_uuids
|
||||
|
||||
self.data = data
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Execute webhook integrations after workflow run completion."""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import ExternalCredentialModel, WorkflowRunModel
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
|
|
@ -133,7 +133,7 @@ async def _execute_webhook_node(
|
|||
credential_uuid, organization_id
|
||||
)
|
||||
if credential:
|
||||
auth_header = _build_auth_header(credential)
|
||||
auth_header = build_auth_header(credential)
|
||||
headers.update(auth_header)
|
||||
logger.debug(f"Applied credential '{credential.name}' to webhook")
|
||||
else:
|
||||
|
|
@ -189,39 +189,3 @@ async def _execute_webhook_node(
|
|||
except Exception as e:
|
||||
logger.error(f"Webhook '{webhook_name}' unexpected error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _build_auth_header(credential: ExternalCredentialModel) -> Dict[str, str]:
|
||||
"""
|
||||
Build authentication header based on credential type.
|
||||
|
||||
Args:
|
||||
credential: The credential model
|
||||
|
||||
Returns:
|
||||
Dict with header name and value
|
||||
"""
|
||||
cred_type = credential.credential_type
|
||||
cred_data = credential.credential_data or {}
|
||||
|
||||
if cred_type == "bearer_token":
|
||||
token = cred_data.get("token", "")
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
elif cred_type == "api_key":
|
||||
header_name = cred_data.get("header_name", "X-API-Key")
|
||||
api_key = cred_data.get("api_key", "")
|
||||
return {header_name: api_key}
|
||||
|
||||
elif cred_type == "basic_auth":
|
||||
username = cred_data.get("username", "")
|
||||
password = cred_data.get("password", "")
|
||||
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
return {"Authorization": f"Basic {encoded}"}
|
||||
|
||||
elif cred_type == "custom_header":
|
||||
header_name = cred_data.get("header_name", "X-Custom")
|
||||
header_value = cred_data.get("header_value", "")
|
||||
return {header_name: header_value}
|
||||
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -1,138 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reordering_after_completion():
|
||||
context = OpenAILLMContext()
|
||||
aggr = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Initialize task manager properly using PipelineTask
|
||||
pipeline = Pipeline([aggr])
|
||||
task = PipelineTask(pipeline)
|
||||
runner = PipelineRunner()
|
||||
|
||||
# Start the task to properly initialize the frame processor
|
||||
task_coroutine = asyncio.create_task(runner.run(task))
|
||||
|
||||
# Give the task a moment to initialize
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# start new LLM response
|
||||
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
# simulate a pending function call
|
||||
await aggr.process_frame(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="transition",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
# now text arrives
|
||||
await aggr.process_frame(TextFrame("Hi there"), FrameDirection.DOWNSTREAM)
|
||||
|
||||
# end response
|
||||
await aggr.process_frame(LLMFullResponseEndFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
msgs = context.get_messages()
|
||||
|
||||
# Assert order: assistant text first, then tool_call assistant, then tool response
|
||||
assert msgs[0]["role"] == "assistant" and "tool_calls" not in msgs[0]
|
||||
# Fix: content is a string, not a structured object
|
||||
assert msgs[0]["content"] == "Hi there"
|
||||
assert any(m.get("role") == "assistant" and m.get("tool_calls") for m in msgs[1:])
|
||||
assert any(m.get("role") == "tool" for m in msgs[1:])
|
||||
|
||||
# Clean up the running task
|
||||
await task.cancel()
|
||||
task_coroutine.cancel()
|
||||
try:
|
||||
await task_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interruption_removes_pending_function_calls_and_marks():
|
||||
context = OpenAILLMContext()
|
||||
aggr = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Initialize task manager properly using PipelineTask
|
||||
pipeline = Pipeline([aggr])
|
||||
task = PipelineTask(pipeline)
|
||||
runner = PipelineRunner()
|
||||
|
||||
# Start the task to properly initialize the frame processor
|
||||
task_coroutine = asyncio.create_task(runner.run(task))
|
||||
|
||||
# Give the task a moment to initialize
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
|
||||
await aggr.process_frame(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="transition",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
# Debug: Check the state before interruption
|
||||
print(
|
||||
f"Function calls in progress before interruption: {aggr._function_calls_in_progress}"
|
||||
)
|
||||
print(f"Messages before interruption: {context.get_messages()}")
|
||||
|
||||
# no text yet - still aggregation
|
||||
await aggr.process_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
msgs = context.get_messages()
|
||||
|
||||
# Debug: Print messages to understand what's happening
|
||||
print(f"Messages after interruption: {msgs}")
|
||||
print(
|
||||
f"Function calls in progress after interruption: {aggr._function_calls_in_progress}"
|
||||
)
|
||||
|
||||
# After interruption before any response is complete, context should be cleared
|
||||
# This is the actual behavior - interruptions clear pending function calls
|
||||
if len(msgs) == 0:
|
||||
# Context was cleared due to interruption before completion
|
||||
assert True
|
||||
else:
|
||||
# If there are messages, ensure no tool calls remain
|
||||
assert not any(m.get("tool_calls") for m in msgs)
|
||||
assert not any(m.get("role") == "tool" for m in msgs)
|
||||
|
||||
# Check if interruption marker is present
|
||||
if msgs:
|
||||
assert msgs[-1]["role"] == "assistant"
|
||||
assert "<<interrupted_by_user>>" in msgs[-1]["content"]
|
||||
|
||||
# Clean up the running task
|
||||
await task.cancel()
|
||||
task_coroutine.cancel()
|
||||
try:
|
||||
await task_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
import os
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.audio_transcript_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryTranscriptBuffer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_buffer_append_and_write():
|
||||
"""Test that audio buffer can append data and write to temp file."""
|
||||
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000, num_channels=1)
|
||||
|
||||
# Create some test PCM data
|
||||
test_pcm = b"\x00\x01" * 1000 # 2000 bytes
|
||||
|
||||
# Append data
|
||||
await buffer.append(test_pcm)
|
||||
await buffer.append(test_pcm)
|
||||
|
||||
assert buffer.size == 4000
|
||||
assert not buffer.is_empty
|
||||
|
||||
# Write to temp file
|
||||
temp_path = await buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
# Verify file exists and is valid WAV
|
||||
assert os.path.exists(temp_path)
|
||||
|
||||
with wave.open(temp_path, "rb") as wf:
|
||||
assert wf.getnchannels() == 1
|
||||
assert wf.getsampwidth() == 2
|
||||
assert wf.getframerate() == 16000
|
||||
# Each frame is 2 bytes (16-bit), so 4000 bytes = 2000 frames
|
||||
assert wf.getnframes() == 2000
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_buffer_memory_limit():
|
||||
"""Test that audio buffer enforces memory limit."""
|
||||
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000)
|
||||
|
||||
# Set a smaller limit for testing
|
||||
buffer._max_size = 1000
|
||||
|
||||
# This should work
|
||||
await buffer.append(b"\x00" * 500)
|
||||
|
||||
# This should fail
|
||||
with pytest.raises(MemoryError):
|
||||
await buffer.append(b"\x00" * 600)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_buffer_append_and_write():
|
||||
"""Test that transcript buffer can append data and write to temp file."""
|
||||
buffer = InMemoryTranscriptBuffer(workflow_run_id=456)
|
||||
|
||||
# Append some transcript lines
|
||||
await buffer.append("[00:00:01] user: Hello\n")
|
||||
await buffer.append("[00:00:02] assistant: Hi there!\n")
|
||||
await buffer.append("[00:00:03] user: How are you?\n")
|
||||
|
||||
assert not buffer.is_empty
|
||||
|
||||
# Write to temp file
|
||||
temp_path = await buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
# Verify file exists and has correct content
|
||||
assert os.path.exists(temp_path)
|
||||
|
||||
with open(temp_path, "r") as f:
|
||||
content = f.read()
|
||||
assert "[00:00:01] user: Hello\n" in content
|
||||
assert "[00:00:02] assistant: Hi there!\n" in content
|
||||
assert "[00:00:03] user: How are you?\n" in content
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_buffers():
|
||||
"""Test that empty buffers are handled correctly."""
|
||||
audio_buffer = InMemoryAudioBuffer(workflow_run_id=789, sample_rate=16000)
|
||||
transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id=789)
|
||||
|
||||
assert audio_buffer.is_empty
|
||||
assert transcript_buffer.is_empty
|
||||
|
||||
# Should still be able to write empty files
|
||||
audio_path = await audio_buffer.write_to_temp_file()
|
||||
transcript_path = await transcript_buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
assert os.path.exists(audio_path)
|
||||
assert os.path.exists(transcript_path)
|
||||
|
||||
# Empty WAV file should still have valid header
|
||||
with wave.open(audio_path, "rb") as wf:
|
||||
assert wf.getnframes() == 0
|
||||
|
||||
# Empty transcript file
|
||||
with open(transcript_path, "r") as f:
|
||||
assert f.read() == ""
|
||||
finally:
|
||||
if os.path.exists(audio_path):
|
||||
os.remove(audio_path)
|
||||
if os.path.exists(transcript_path):
|
||||
os.remove(transcript_path)
|
||||
|
|
@ -1,330 +0,0 @@
|
|||
"""Tests for concurrent call limiting functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.campaign.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class TestConcurrentCallLimiting:
|
||||
"""Test suite for concurrent call limiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_concurrent_slot_success(self):
|
||||
"""Test successful acquisition of concurrent slot."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.eval = AsyncMock(return_value="test_slot_123")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id == "test_slot_123"
|
||||
mock_client.eval.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_concurrent_slot_limit_reached(self):
|
||||
"""Test slot acquisition when limit is reached."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.eval = AsyncMock(return_value=None) # Limit reached
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id is None
|
||||
mock_client.eval.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_concurrent_slot(self):
|
||||
"""Test releasing a concurrent slot."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.zrem = AsyncMock(return_value=1) # Successfully removed
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Release slot
|
||||
success = await rate_limiter.release_concurrent_slot(
|
||||
organization_id=1, slot_id="test_slot_123"
|
||||
)
|
||||
|
||||
assert success is True
|
||||
mock_client.zrem.assert_called_once_with(
|
||||
"concurrent_calls:1", "test_slot_123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_concurrent_count(self):
|
||||
"""Test getting current concurrent call count."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.zremrangebyscore = AsyncMock() # Cleanup stale entries
|
||||
mock_client.zcard = AsyncMock(return_value=5) # 5 active calls
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Get count
|
||||
count = await rate_limiter.get_concurrent_count(organization_id=1)
|
||||
|
||||
assert count == 5
|
||||
mock_client.zremrangebyscore.assert_called_once()
|
||||
mock_client.zcard.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stale_entry_cleanup(self):
|
||||
"""Test that stale entries are cleaned up automatically."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Mock eval to simulate cleanup in Lua script
|
||||
mock_client.eval = AsyncMock(return_value="new_slot_123")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot (which should trigger cleanup)
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id == "new_slot_123"
|
||||
|
||||
# Verify Lua script was called with proper stale cutoff
|
||||
call_args = mock_client.eval.call_args[0]
|
||||
lua_script = call_args[0]
|
||||
assert "ZREMRANGEBYSCORE" in lua_script # Cleanup command in script
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_slot_mapping_operations(self):
|
||||
"""Test storing, retrieving, and deleting workflow slot mappings."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.hset = AsyncMock(return_value=1)
|
||||
mock_client.expire = AsyncMock(return_value=True)
|
||||
mock_client.hgetall = AsyncMock(
|
||||
return_value={"org_id": "1", "slot_id": "test_slot_123"}
|
||||
)
|
||||
mock_client.delete = AsyncMock(return_value=1)
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Test storing mapping
|
||||
success = await rate_limiter.store_workflow_slot_mapping(
|
||||
workflow_run_id=999, organization_id=1, slot_id="test_slot_123"
|
||||
)
|
||||
assert success is True
|
||||
mock_client.hset.assert_called_once()
|
||||
mock_client.expire.assert_called_once()
|
||||
|
||||
# Test retrieving mapping
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id=999)
|
||||
assert mapping == (1, "test_slot_123")
|
||||
mock_client.hgetall.assert_called_once_with("workflow_slot_mapping:999")
|
||||
|
||||
# Test deleting mapping
|
||||
deleted = await rate_limiter.delete_workflow_slot_mapping(
|
||||
workflow_run_id=999
|
||||
)
|
||||
assert deleted is True
|
||||
mock_client.delete.assert_called_once_with("workflow_slot_mapping:999")
|
||||
|
||||
|
||||
class TestCampaignCallDispatcher:
|
||||
"""Test suite for CampaignCallDispatcher with concurrent limiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_call_waits_for_slot(self):
|
||||
"""Test that dispatch_call waits for available slot."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dependencies
|
||||
mock_campaign = MagicMock(
|
||||
organization_id=1, workflow_id=123, id=456, created_by=789
|
||||
)
|
||||
mock_queued_run = MagicMock(
|
||||
id=111, context_variables={"phone_number": "+1234567890"}
|
||||
)
|
||||
|
||||
# Mock rate limiter to simulate waiting
|
||||
slot_acquired = False
|
||||
call_count = 0
|
||||
|
||||
async def mock_try_acquire(org_id, max_concurrent):
|
||||
nonlocal slot_acquired, call_count
|
||||
call_count += 1
|
||||
if call_count > 2: # Succeed on third try
|
||||
slot_acquired = True
|
||||
return "test_slot_123"
|
||||
return None
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_try_acquire
|
||||
)
|
||||
mock_limiter.release_concurrent_slot = AsyncMock()
|
||||
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
mock_db.get_workflow_by_id = AsyncMock(
|
||||
return_value=MagicMock(template_context_variables={})
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=MagicMock(id=999, logs={})
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
dispatcher.twilio_service, "initiate_call"
|
||||
) as mock_twilio:
|
||||
mock_twilio.return_value = {"sid": "test_sid"}
|
||||
|
||||
# Dispatch call (should wait and retry)
|
||||
workflow_run = await dispatcher.dispatch_call(
|
||||
mock_queued_run, mock_campaign
|
||||
)
|
||||
|
||||
assert workflow_run is not None
|
||||
assert slot_acquired is True
|
||||
assert call_count == 3 # Tried 3 times
|
||||
assert mock_limiter.try_acquire_concurrent_slot.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_call_stores_slot_mapping(self):
|
||||
"""Test that dispatch_call stores slot mapping in Redis."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dependencies
|
||||
mock_campaign = MagicMock(
|
||||
organization_id=1, workflow_id=123, id=456, created_by=789
|
||||
)
|
||||
mock_queued_run = MagicMock(
|
||||
id=111, context_variables={"phone_number": "+1234567890"}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
|
||||
return_value="test_slot_123"
|
||||
)
|
||||
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
mock_db.get_workflow_by_id = AsyncMock(
|
||||
return_value=MagicMock(template_context_variables={})
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=MagicMock(id=999, logs={})
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
dispatcher.twilio_service, "initiate_call"
|
||||
) as mock_twilio:
|
||||
mock_twilio.return_value = {"sid": "test_sid"}
|
||||
|
||||
# Dispatch call
|
||||
workflow_run = await dispatcher.dispatch_call(
|
||||
mock_queued_run, mock_campaign
|
||||
)
|
||||
|
||||
# Verify slot mapping was stored
|
||||
mock_limiter.store_workflow_slot_mapping.assert_called_once_with(
|
||||
999, 1, "test_slot_123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_specific_concurrent_limit(self):
|
||||
"""Test that organization-specific concurrent limit is used."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock db_client to return org-specific limit
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_config = MagicMock(value={"value": 10}) # Org limit is 10
|
||||
mock_db.get_configuration = AsyncMock(return_value=mock_config)
|
||||
|
||||
# Get org limit
|
||||
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
|
||||
|
||||
assert limit == 10 # Should use org-specific limit
|
||||
mock_db.get_configuration.assert_called_once_with(
|
||||
1, OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_concurrent_limit(self):
|
||||
"""Test that default limit is used when org config not found."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock db_client to return None (no config)
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
|
||||
# Get org limit
|
||||
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
|
||||
|
||||
assert limit == 20 # Should use default limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_call_slot(self):
|
||||
"""Test releasing call slot when workflow completes."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock rate limiter
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
# Mock getting the slot mapping from Redis
|
||||
mock_limiter.get_workflow_slot_mapping = AsyncMock(
|
||||
return_value=(1, "test_slot_123")
|
||||
)
|
||||
mock_limiter.release_concurrent_slot = AsyncMock(return_value=True)
|
||||
mock_limiter.delete_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
# Release slot
|
||||
success = await dispatcher.release_call_slot(workflow_run_id=999)
|
||||
|
||||
assert success is True
|
||||
mock_limiter.get_workflow_slot_mapping.assert_called_once_with(999)
|
||||
mock_limiter.release_concurrent_slot.assert_called_once_with(
|
||||
1, "test_slot_123"
|
||||
)
|
||||
mock_limiter.delete_workflow_slot_mapping.assert_called_once_with(999)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import (
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
REAL_KEY = "sk-1234567890abcdef"
|
||||
|
||||
|
||||
def _build_config_with_openai(key: str) -> UserConfiguration:
|
||||
return UserConfiguration(
|
||||
llm=OpenAILLMService(api_key=key),
|
||||
stt=None,
|
||||
tts=None,
|
||||
)
|
||||
|
||||
|
||||
def test_mask_key_basic():
|
||||
masked = mask_key(REAL_KEY)
|
||||
# Should reveal only last 4 chars
|
||||
assert masked.endswith(REAL_KEY[-4:])
|
||||
assert set(masked[:-4]) == {"*"}
|
||||
assert len(masked) == len(REAL_KEY)
|
||||
# is_mask_of round-trip
|
||||
assert is_mask_of(masked, REAL_KEY)
|
||||
|
||||
|
||||
def test_mask_user_config_masks_api_keys():
|
||||
cfg = _build_config_with_openai(REAL_KEY)
|
||||
dumped = mask_user_config(cfg)
|
||||
assert dumped["llm"]["api_key"].endswith(REAL_KEY[-4:])
|
||||
assert dumped["llm"]["api_key"].startswith("*" * (len(REAL_KEY) - 4))
|
||||
|
||||
|
||||
def test_merge_preserves_key_when_mask_sent():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": existing.llm.model,
|
||||
"api_key": mask_key(REAL_KEY), # masked placeholder
|
||||
}
|
||||
}
|
||||
|
||||
merged = merge_user_configurations(existing, incoming_partial)
|
||||
assert merged.llm.api_key == REAL_KEY # key preserved
|
||||
|
||||
|
||||
def test_merge_replaces_key_when_new_key_provided():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
new_key = "sk-replaced-9999"
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": existing.llm.model,
|
||||
"api_key": new_key,
|
||||
}
|
||||
}
|
||||
merged = merge_user_configurations(existing, incoming_partial)
|
||||
assert merged.llm.api_key == new_key
|
||||
|
||||
|
||||
def test_merge_drops_old_key_when_provider_changes():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
# api_key intentionally absent – should NOT inherit old key
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
merge_user_configurations(existing, incoming_partial)
|
||||
1041
api/tests/test_custom_tools.py
Normal file
1041
api/tests/test_custom_tools.py
Normal file
File diff suppressed because it is too large
Load diff
512
api/tests/test_custom_tools_context_integration.py
Normal file
512
api/tests/test_custom_tools_context_integration.py
Normal file
|
|
@ -0,0 +1,512 @@
|
|||
"""Integration tests for CustomToolManager with update_llm_context.
|
||||
|
||||
This module tests the full flow of:
|
||||
1. CustomToolManager fetching and converting tool schemas
|
||||
2. update_llm_context setting those tools on the LLM context
|
||||
3. Verifying the context is properly configured for LLM generation
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
|
||||
|
||||
class TestCustomToolManagerContextIntegration:
|
||||
"""Integration tests for CustomToolManager with LLMContext."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self):
|
||||
"""Create a mock PipecatEngine."""
|
||||
engine = Mock()
|
||||
engine._workflow_run_id = 1
|
||||
engine._call_context_vars = {"customer_name": "John Doe"}
|
||||
engine.llm = Mock()
|
||||
engine.llm.register_function = Mock()
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools(self):
|
||||
"""Create sample mock tools for testing."""
|
||||
return [
|
||||
MockToolModel(
|
||||
tool_uuid="weather-uuid-123",
|
||||
name="Get Weather",
|
||||
description="Get current weather for a location",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.weather.com/current",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "location",
|
||||
"type": "string",
|
||||
"description": "City name (e.g., San Francisco, CA)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "units",
|
||||
"type": "string",
|
||||
"description": "Temperature units: celsius or fahrenheit",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="booking-uuid-456",
|
||||
name="Book Appointment",
|
||||
description="Book an appointment for the customer",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/appointments",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "customer_name",
|
||||
"type": "string",
|
||||
"description": "Customer's full name",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"type": "string",
|
||||
"description": "Appointment date (YYYY-MM-DD)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "time",
|
||||
"type": "string",
|
||||
"description": "Appointment time (HH:MM)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "notes",
|
||||
"type": "string",
|
||||
"description": "Additional notes",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="lookup-uuid-789",
|
||||
name="Customer Lookup",
|
||||
description="Look up customer information by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.example.com/customers/lookup",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "phone",
|
||||
"type": "string",
|
||||
"description": "Customer phone number",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
|
||||
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
|
||||
|
||||
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
|
||||
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
|
||||
schemas = await manager.get_tool_schemas(tool_uuids)
|
||||
|
||||
# Verify schemas were returned as FunctionSchema objects
|
||||
assert len(schemas) == 3
|
||||
assert all(isinstance(s, FunctionSchema) for s in schemas)
|
||||
|
||||
# Create context with conversation history
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to check the weather and book an appointment.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I can help with both. Where would you like to check the weather?",
|
||||
},
|
||||
{"role": "user", "content": "San Francisco"},
|
||||
]
|
||||
)
|
||||
|
||||
# Update context with new system message and tools
|
||||
# Now we can pass schemas directly since they're FunctionSchema objects
|
||||
new_system = {
|
||||
"role": "system",
|
||||
"content": "You are a scheduling assistant with access to weather and booking tools.",
|
||||
}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
|
||||
# Verify context was updated correctly
|
||||
messages = context.messages
|
||||
assert len(messages) == 4
|
||||
assert (
|
||||
messages[0]["content"]
|
||||
== "You are a scheduling assistant with access to weather and booking tools."
|
||||
)
|
||||
assert messages[1]["role"] == "user"
|
||||
assert messages[3]["content"] == "San Francisco"
|
||||
|
||||
# Verify tools were set
|
||||
tools = context.tools
|
||||
assert tools is not None
|
||||
assert len(tools.standard_tools) == 3
|
||||
|
||||
# Verify tool names
|
||||
tool_names = {t.name for t in tools.standard_tools}
|
||||
assert tool_names == {
|
||||
"get_weather",
|
||||
"book_appointment",
|
||||
"customer_lookup",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_schemas_have_correct_properties(
|
||||
self, mock_engine, sample_tools
|
||||
):
|
||||
"""Test that tool schemas from CustomToolManager have correct parameter properties."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
|
||||
|
||||
schemas = await manager.get_tool_schemas(
|
||||
["weather-uuid-123", "booking-uuid-456"]
|
||||
)
|
||||
|
||||
# Find the booking schema - now using FunctionSchema attributes
|
||||
booking_schema = next(
|
||||
s for s in schemas if s.name == "book_appointment"
|
||||
)
|
||||
|
||||
# Verify parameter properties
|
||||
assert "customer_name" in booking_schema.properties
|
||||
assert "date" in booking_schema.properties
|
||||
assert "time" in booking_schema.properties
|
||||
assert "notes" in booking_schema.properties
|
||||
|
||||
# Verify types
|
||||
assert booking_schema.properties["customer_name"]["type"] == "string"
|
||||
assert booking_schema.properties["date"]["type"] == "string"
|
||||
|
||||
# Verify required
|
||||
assert "customer_name" in booking_schema.required
|
||||
assert "date" in booking_schema.required
|
||||
assert "time" in booking_schema.required
|
||||
assert "notes" not in booking_schema.required
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_update_with_builtin_and_custom_tools(
|
||||
self, mock_engine, sample_tools
|
||||
):
|
||||
"""Test updating context with both built-in and custom tools."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(
|
||||
return_value=[sample_tools[0]]
|
||||
) # Just weather
|
||||
|
||||
# Get custom tool schemas - returns FunctionSchema objects
|
||||
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Create built-in function schemas (like calculator, timezone)
|
||||
builtin_functions = [
|
||||
get_function_schema(
|
||||
"safe_calculator",
|
||||
"Evaluate a mathematical expression safely",
|
||||
properties={
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression to evaluate",
|
||||
}
|
||||
},
|
||||
required=["expression"],
|
||||
),
|
||||
get_function_schema(
|
||||
"get_current_time",
|
||||
"Get the current time in a timezone",
|
||||
properties={
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone name (e.g., America/New_York)",
|
||||
}
|
||||
},
|
||||
required=["timezone"],
|
||||
),
|
||||
]
|
||||
|
||||
# Combine built-in and custom functions - both are FunctionSchema objects
|
||||
all_functions = builtin_functions + custom_schemas
|
||||
|
||||
# Update context
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old prompt"}])
|
||||
|
||||
new_system = {
|
||||
"role": "system",
|
||||
"content": "Assistant with calculator and weather tools",
|
||||
}
|
||||
update_llm_context(context, new_system, all_functions)
|
||||
|
||||
# Verify all tools are present
|
||||
tools = context.tools
|
||||
assert len(tools.standard_tools) == 3
|
||||
|
||||
tool_names = {t.name for t in tools.standard_tools}
|
||||
assert "safe_calculator" in tool_names
|
||||
assert "get_current_time" in tool_names
|
||||
assert "get_weather" in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_cached_after_first_fetch(self, mock_engine, sample_tools):
|
||||
"""Test that CustomToolManager caches tools after first fetch."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
|
||||
|
||||
# First fetch
|
||||
await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Verify tool is cached (cache stores raw schema dict, not FunctionSchema)
|
||||
cached = manager.get_cached_tool("get_weather")
|
||||
assert cached is not None
|
||||
tool, raw_schema = cached
|
||||
assert tool.tool_uuid == "weather-uuid-123"
|
||||
assert raw_schema["function"]["name"] == "get_weather"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_preserves_function_call_history(
|
||||
self, mock_engine, sample_tools
|
||||
):
|
||||
"""Test that update_llm_context preserves function call messages in history."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
|
||||
|
||||
# Get schemas - returns FunctionSchema objects
|
||||
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Create context with function call history
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "system", "content": "Old system prompt"},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "New York, NY"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": '{"temperature": 72, "condition": "sunny"}',
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in NYC is 72°F and sunny!",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
new_system = {"role": "system", "content": "Updated weather assistant"}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
|
||||
messages = context.messages
|
||||
# System + user + assistant(tool_call) + tool + assistant = 5
|
||||
assert len(messages) == 5
|
||||
|
||||
# Verify function call messages are preserved
|
||||
tool_call_msg = messages[2]
|
||||
assert tool_call_msg["role"] == "assistant"
|
||||
assert "tool_calls" in tool_call_msg
|
||||
|
||||
tool_result_msg = messages[3]
|
||||
assert tool_result_msg["role"] == "tool"
|
||||
assert tool_result_msg["tool_call_id"] == "call_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_tool_list_does_not_set_tools(self, mock_engine):
|
||||
"""Test that empty tool list doesn't set tools on context."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
|
||||
|
||||
schemas = await manager.get_tool_schemas([])
|
||||
assert schemas == []
|
||||
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
|
||||
new_system = {"role": "system", "content": "No tools available"}
|
||||
update_llm_context(context, new_system, [])
|
||||
|
||||
# Context should have updated message but no tools set
|
||||
assert context.messages[0]["content"] == "No tools available"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_and_boolean_parameter_types(self, mock_engine):
|
||||
"""Test that numeric and boolean parameter types are correctly handled."""
|
||||
tool_with_types = MockToolModel(
|
||||
tool_uuid="order-uuid",
|
||||
name="Place Order",
|
||||
description="Place an order for items",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/orders",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "item_id",
|
||||
"type": "string",
|
||||
"description": "Item identifier",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "quantity",
|
||||
"type": "number",
|
||||
"description": "Number of items",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "express_shipping",
|
||||
"type": "boolean",
|
||||
"description": "Use express shipping",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
|
||||
|
||||
# Get schemas - returns FunctionSchema objects
|
||||
schemas = await manager.get_tool_schemas(["order-uuid"])
|
||||
schema = schemas[0]
|
||||
|
||||
# Verify types using FunctionSchema attributes
|
||||
assert schema.properties["item_id"]["type"] == "string"
|
||||
assert schema.properties["quantity"]["type"] == "number"
|
||||
assert schema.properties["express_shipping"]["type"] == "boolean"
|
||||
|
||||
# Update context - pass schema directly
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
update_llm_context(
|
||||
context, {"role": "system", "content": "Order assistant"}, schemas
|
||||
)
|
||||
|
||||
# Verify tool was set with correct types
|
||||
tool = context.tools.standard_tools[0]
|
||||
assert tool.name == "place_order"
|
||||
assert tool.properties["quantity"]["type"] == "number"
|
||||
assert tool.properties["express_shipping"]["type"] == "boolean"
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from api.db.user_client import UserClient
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_configuration_created(db_session):
|
||||
# Set env variable for openai to simulate availability of default key
|
||||
os.environ["OPENAI_API_KEY"] = "sk-test-openai-key"
|
||||
|
||||
# Ensure deepgram env variable absent to focus test
|
||||
os.environ.pop("DEEPGRAM_API_KEY", None)
|
||||
|
||||
# Generate a unique (random) provider user ID for each test run
|
||||
test_provider_user_id = f"provider_user_{uuid.uuid4().hex}"
|
||||
user_client: UserClient = db_session # db_session fixture yields the client
|
||||
|
||||
user_model = await user_client.get_or_create_user_by_provider_id(
|
||||
test_provider_user_id
|
||||
)
|
||||
|
||||
config = await user_client.get_user_configurations(user_model.id)
|
||||
|
||||
assert config.llm is not None, "LLM config should be created when env key present"
|
||||
assert config.llm.provider == ServiceProviders.OPENAI
|
||||
assert config.llm.api_key == "sk-test-openai-key"
|
||||
|
||||
# Cleanup / restore env variable side-effects
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_valid_mapping():
|
||||
"""Test disposition mapping with valid configuration."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock disposition mapping configuration
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"XFER": "TRANSFERRED",
|
||||
"ND": "NOT_QUALIFIED",
|
||||
"user_hangup": "HANGUP",
|
||||
}
|
||||
)
|
||||
|
||||
# Test mapping exists
|
||||
result = await apply_disposition_mapping("XFER", 1)
|
||||
assert result == "TRANSFERRED"
|
||||
|
||||
# Test mapping doesn't exist
|
||||
result = await apply_disposition_mapping("UNKNOWN", 1)
|
||||
assert result == "UNKNOWN"
|
||||
|
||||
# Verify db_client was called correctly
|
||||
mock_db_client.get_configuration_value.assert_called_with(
|
||||
1, "DISPOSITION_CODE_MAPPING", default={}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_no_organization_id():
|
||||
"""Test disposition mapping with no organization ID."""
|
||||
# Should return original value
|
||||
result = await apply_disposition_mapping("XFER", None)
|
||||
assert result == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_empty_value():
|
||||
"""Test disposition mapping with empty value."""
|
||||
# Should return original empty value
|
||||
result = await apply_disposition_mapping("", 1)
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_error_handling():
|
||||
"""Test disposition mapping handles errors gracefully."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock database error
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
|
||||
# Should return original value on error
|
||||
result = await apply_disposition_mapping("XFER", 1)
|
||||
assert result == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_from_workflow_run():
|
||||
"""Test getting organization ID from workflow run ID."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock workflow run with organization
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow.user.selected_organization_id = 123
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result == 123
|
||||
|
||||
# Verify db_client was called correctly
|
||||
mock_db_client.get_workflow_run_by_id.assert_called_once_with(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_no_workflow_run():
|
||||
"""Test getting organization ID when workflow run doesn't exist."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock no workflow run found
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(return_value=None)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_no_user():
|
||||
"""Test getting organization ID when workflow has no user."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock workflow run with no user
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow.user = None
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_error_handling():
|
||||
"""Test getting organization ID handles errors gracefully."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock database error
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
|
@ -1,370 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.pipecat.event_handlers import register_transport_event_handlers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies():
|
||||
"""Create mock dependencies for event handlers."""
|
||||
# Store registered handlers
|
||||
registered_handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
registered_handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.event_handler = mock_event_handler
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.cancel = AsyncMock()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
|
||||
mock_audio_buffer = MagicMock()
|
||||
mock_audio_buffer.start_recording = AsyncMock()
|
||||
mock_audio_buffer.stop_recording = AsyncMock()
|
||||
|
||||
mock_usage_metrics_aggregator = MagicMock()
|
||||
mock_usage_metrics_aggregator.get_all_usage_metrics_serialized = MagicMock(
|
||||
return_value={"test": "metrics"}
|
||||
)
|
||||
|
||||
return {
|
||||
"transport": mock_transport,
|
||||
"workflow_run_id": 123,
|
||||
"audio_buffer": mock_audio_buffer,
|
||||
"task": mock_task,
|
||||
"engine": mock_engine,
|
||||
"usage_metrics_aggregator": mock_usage_metrics_aggregator,
|
||||
"audio_synchronizer": None,
|
||||
"registered_handlers": registered_handlers,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_disconnect_reason_mapping(mock_dependencies):
|
||||
"""Test that transport_disconnect_reason is mapped when no engine disconnect reason exists."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock call duration for user_hangup logic
|
||||
mock_dependencies[
|
||||
"usage_metrics_aggregator"
|
||||
].get_call_duration.return_value = 15
|
||||
|
||||
# Mock disposition mapping
|
||||
async def apply_mapping_side_effect(value, org_id):
|
||||
return {
|
||||
"NIBP": "NOT_INTERESTED_BUSINESS_PURPOSE",
|
||||
"user_qualified": "QUALIFIED",
|
||||
}.get(value, value)
|
||||
|
||||
mock_apply_mapping.side_effect = apply_mapping_side_effect
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was applied with NIBP (since duration > 10)
|
||||
mock_apply_mapping.assert_called_once_with("NIBP", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "NOT_INTERESTED_BUSINESS_PURPOSE"
|
||||
)
|
||||
|
||||
# Verify task was cancelled (no engine disconnect reason)
|
||||
mock_dependencies["task"].cancel.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_disconnect_reason_user_hangup_short_call(mock_dependencies):
|
||||
"""Test that user_hangup with short call duration is mapped to HU."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock call duration for user_hangup logic (< 10 seconds)
|
||||
mock_dependencies[
|
||||
"usage_metrics_aggregator"
|
||||
].get_call_duration.return_value = 5
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.return_value = "HANGUP"
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was applied with HU (since duration < 10)
|
||||
mock_apply_mapping.assert_called_once_with("HU", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "HANGUP"
|
||||
)
|
||||
|
||||
# Verify task was cancelled (no engine disconnect reason)
|
||||
mock_dependencies["task"].cancel.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_disconnect_reason_takes_precedence(mock_dependencies):
|
||||
"""Test that engine disconnect reason takes precedence and is not mapped."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = "user_qualified"
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping for engine's reason
|
||||
mock_apply_mapping.return_value = "QUALIFIED"
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was called with engine's reason
|
||||
mock_apply_mapping.assert_called_once_with("user_qualified", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "QUALIFIED"
|
||||
)
|
||||
|
||||
# Verify task was NOT cancelled (engine disconnect reason exists)
|
||||
mock_dependencies["task"].cancel.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_disconnect_reason_uses_unknown(mock_dependencies):
|
||||
"""Test that when no disconnect reason is provided, UNKNOWN is used."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping - should return UNKNOWN as-is
|
||||
mock_apply_mapping.return_value = EndTaskReason.UNKNOWN.value
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler without transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason=None,
|
||||
)
|
||||
|
||||
# Verify disposition mapping was called with UNKNOWN
|
||||
mock_apply_mapping.assert_called_once_with(
|
||||
EndTaskReason.UNKNOWN.value, 1
|
||||
)
|
||||
|
||||
# Verify database was updated with UNKNOWN
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== EndTaskReason.UNKNOWN.value
|
||||
)
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_transcript_handler,
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_handlers_with_in_memory_buffers():
|
||||
"""Test that transport handlers create and return in-memory buffers."""
|
||||
# Mock dependencies
|
||||
transport = MagicMock()
|
||||
transport.event_handler = lambda event_name: lambda func: func
|
||||
|
||||
audio_buffer = AsyncMock()
|
||||
audio_synchronizer = AsyncMock()
|
||||
task = AsyncMock()
|
||||
engine = AsyncMock()
|
||||
engine.get_call_disposition.return_value = None
|
||||
engine.get_gathered_context.return_value = {}
|
||||
|
||||
usage_metrics_aggregator = AsyncMock()
|
||||
usage_metrics_aggregator.get_call_duration.return_value = 30
|
||||
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Create test audio config
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
)
|
||||
|
||||
# Register handlers
|
||||
audio_buf, transcript_buf = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
# Verify buffers were created with correct configuration
|
||||
assert audio_buf is not None
|
||||
assert transcript_buf is not None
|
||||
assert audio_buf._workflow_run_id == 123
|
||||
assert audio_buf._sample_rate == 16000
|
||||
assert audio_buf._num_channels == 1
|
||||
assert transcript_buf._workflow_run_id == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_handler_with_in_memory_buffer():
|
||||
"""Test audio handler uses in-memory buffer when provided."""
|
||||
# Mock audio synchronizer
|
||||
audio_synchronizer = MagicMock()
|
||||
handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
audio_synchronizer.event_handler = mock_event_handler
|
||||
|
||||
# Mock in-memory buffer
|
||||
in_memory_buffer = AsyncMock()
|
||||
|
||||
# Register handler with buffer
|
||||
register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id=123, in_memory_buffer=in_memory_buffer
|
||||
)
|
||||
|
||||
# Test the handler
|
||||
assert "on_merged_audio" in handlers
|
||||
handler = handlers["on_merged_audio"]
|
||||
|
||||
# Call handler with test data
|
||||
test_pcm = b"test_audio_data"
|
||||
await handler(None, test_pcm, 16000, 1)
|
||||
|
||||
# Verify buffer was used
|
||||
in_memory_buffer.append.assert_called_once_with(test_pcm)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_handler_with_in_memory_buffer():
|
||||
"""Test transcript handler uses in-memory buffer when provided."""
|
||||
# Mock transcript processor
|
||||
transcript = MagicMock()
|
||||
handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
transcript.event_handler = mock_event_handler
|
||||
|
||||
# Mock in-memory buffer
|
||||
in_memory_buffer = AsyncMock()
|
||||
|
||||
# Register handler with buffer
|
||||
register_transcript_handler(
|
||||
transcript, workflow_run_id=456, in_memory_buffer=in_memory_buffer
|
||||
)
|
||||
|
||||
# Create test frame
|
||||
test_frame = MagicMock()
|
||||
test_frame.messages = [
|
||||
MagicMock(timestamp="00:00:01", role="user", content="Hello"),
|
||||
MagicMock(timestamp="00:00:02", role="assistant", content="Hi there"),
|
||||
]
|
||||
|
||||
# Test the handler
|
||||
handler = handlers["on_transcript_update"]
|
||||
await handler(None, test_frame)
|
||||
|
||||
# Verify buffer was used with correct format
|
||||
expected_text = "[00:00:01] user: Hello\n[00:00:02] assistant: Hi there\n"
|
||||
in_memory_buffer.append.assert_called_once_with(expected_text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_config_sample_rates():
|
||||
"""Test that different audio configs result in correct sample rates."""
|
||||
# Mock dependencies
|
||||
transport = MagicMock()
|
||||
transport.event_handler = lambda event_name: lambda func: func
|
||||
|
||||
audio_buffer = AsyncMock()
|
||||
audio_synchronizer = AsyncMock()
|
||||
task = AsyncMock()
|
||||
engine = AsyncMock()
|
||||
engine.get_call_disposition.return_value = None
|
||||
engine.get_gathered_context.return_value = {}
|
||||
|
||||
usage_metrics_aggregator = AsyncMock()
|
||||
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Test with 8kHz audio config (e.g., for Stasis/Twilio)
|
||||
audio_config_8k = AudioConfig(
|
||||
transport_in_sample_rate=8000,
|
||||
transport_out_sample_rate=8000,
|
||||
pipeline_sample_rate=8000,
|
||||
)
|
||||
|
||||
audio_buf_8k, _ = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=456,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config_8k,
|
||||
)
|
||||
|
||||
assert audio_buf_8k._sample_rate == 8000
|
||||
|
||||
# Test with no audio config (should default to 16kHz)
|
||||
audio_buf_default, _ = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=789,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=None,
|
||||
)
|
||||
|
||||
assert audio_buf_default._sample_rate == 16000
|
||||
|
|
@ -1,162 +0,0 @@
|
|||
"""Test filter functionality."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from api.db.filters import ATTRIBUTE_FIELD_MAPPING, apply_workflow_run_filters
|
||||
|
||||
|
||||
def test_attribute_field_mapping():
|
||||
"""Test that all required attributes are mapped."""
|
||||
expected_attributes = [
|
||||
"dateRange",
|
||||
"dispositionCode",
|
||||
"duration",
|
||||
"status",
|
||||
"tokenUsage",
|
||||
"runId",
|
||||
"workflowId",
|
||||
"callTags",
|
||||
"phoneNumber",
|
||||
]
|
||||
|
||||
for attr in expected_attributes:
|
||||
assert attr in ATTRIBUTE_FIELD_MAPPING, f"Missing mapping for {attr}"
|
||||
|
||||
|
||||
def test_filter_with_explicit_type():
|
||||
"""Test that filters work with explicit type from UI."""
|
||||
|
||||
# Mock query
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
test_cases = [
|
||||
# Date range filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "dateRange",
|
||||
"type": "dateRange",
|
||||
"value": {"from": "2024-01-01", "to": "2024-01-31"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Multi-select filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "dispositionCode",
|
||||
"type": "multiSelect",
|
||||
"value": {"codes": ["XFER", "HU"]},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Number range filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "duration",
|
||||
"type": "numberRange",
|
||||
"value": {"min": 60, "max": 300},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Radio/status filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "status",
|
||||
"type": "radio",
|
||||
"value": {"status": "completed"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Number filter
|
||||
{
|
||||
"filters": [
|
||||
{"attribute": "runId", "type": "number", "value": {"value": 123}}
|
||||
],
|
||||
},
|
||||
# Text filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "phoneNumber",
|
||||
"type": "text",
|
||||
"value": {"value": "+1234567890"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Tags filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "callTags",
|
||||
"type": "tags",
|
||||
"value": {"codes": ["tag1", "tag2"]},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
result = apply_workflow_run_filters(mock_query, test_case["filters"])
|
||||
# The function should process the filter without errors
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_filter_format_with_type():
|
||||
"""Test that filters work with attribute, type, and value."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
# Test with various filter combinations
|
||||
filters = [
|
||||
{
|
||||
"attribute": "dispositionCode",
|
||||
"type": "multiSelect",
|
||||
"value": {"codes": ["NIBP"]},
|
||||
},
|
||||
{
|
||||
"attribute": "duration",
|
||||
"type": "numberRange",
|
||||
"value": {"min": 0, "max": 60},
|
||||
},
|
||||
{"attribute": "phoneNumber", "type": "text", "value": {"value": "555"}},
|
||||
]
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, filters)
|
||||
|
||||
# Should have called where() for applying filters
|
||||
assert mock_query.where.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_unknown_attribute_ignored():
|
||||
"""Test that unknown attributes are safely ignored."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
filters = [
|
||||
{"attribute": "unknownAttribute", "value": {"value": "test"}},
|
||||
{"attribute": "dispositionCode", "value": {"codes": ["XFER"]}},
|
||||
]
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, filters)
|
||||
|
||||
# Should still process the valid filter
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_empty_filters():
|
||||
"""Test that empty filters return the query unchanged."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, None)
|
||||
assert result == mock_query
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, [])
|
||||
assert result == mock_query
|
||||
|
|
@ -1,249 +0,0 @@
|
|||
"""Tests for global prompt functionality in workflow engine."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class TestGlobalPrompt:
|
||||
"""Test suite for global prompt feature."""
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_with_global_node(self):
|
||||
"""Create a workflow with a global node and test nodes."""
|
||||
nodes = [
|
||||
RFNodeDTO(
|
||||
id="global",
|
||||
type=NodeType.globalNode,
|
||||
position={"x": 0, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Global Node",
|
||||
prompt="This is the global context: {{company_name}}",
|
||||
is_static=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position={"x": 100, "y": 100},
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt="Welcome to our service!",
|
||||
is_static=False,
|
||||
is_start=True,
|
||||
add_global_prompt=True, # Enable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent1",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 200, "y": 200},
|
||||
data=NodeDataDTO(
|
||||
name="Agent 1",
|
||||
prompt="How can I help you today?",
|
||||
add_global_prompt=False, # Disable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent2",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 300, "y": 300},
|
||||
data=NodeDataDTO(
|
||||
name="Agent 2",
|
||||
prompt="Please provide your details.",
|
||||
add_global_prompt=True, # Enable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position={"x": 400, "y": 400},
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt="Thank you for calling!",
|
||||
is_static=True,
|
||||
is_end=True,
|
||||
add_global_prompt=True, # Enable global prompt (but static)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
edges = [
|
||||
RFEdgeDTO(
|
||||
id="e1",
|
||||
source="start",
|
||||
target="agent1",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue to agent"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e2",
|
||||
source="agent1",
|
||||
target="agent2",
|
||||
data=EdgeDataDTO(label="Details", condition="Get user details"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e3",
|
||||
source="agent2",
|
||||
target="end",
|
||||
data=EdgeDataDTO(label="Finish", condition="End the call"),
|
||||
),
|
||||
]
|
||||
|
||||
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
|
||||
return WorkflowGraph(flow_dto)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
return {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(spec=OpenAILLMContext),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"call_context_vars": {"company_name": "Dograh Inc"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_dependencies, workflow_with_global_node):
|
||||
"""Create a PipecatEngine instance with test workflow."""
|
||||
mock_dependencies["workflow"] = workflow_with_global_node
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_enabled(self, engine):
|
||||
"""Test that global prompt is prepended when add_global_prompt is True."""
|
||||
# Test with start node (add_global_prompt=True)
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Global prompt should be included
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nWelcome to our service!"
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
assert system_message["role"] == "system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_disabled(self, engine):
|
||||
"""Test that global prompt is not prepended when add_global_prompt is False."""
|
||||
# Test with agent1 node (add_global_prompt=False)
|
||||
agent1_node = engine.workflow.nodes["agent1"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent1_node)
|
||||
|
||||
# Global prompt should NOT be included
|
||||
expected_content = "How can I help you today?"
|
||||
assert system_message["content"] == expected_content
|
||||
assert "global context" not in system_message["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_with_static_node(self, engine):
|
||||
"""Test that static nodes don't use global prompt in engine (even if enabled)."""
|
||||
# Static nodes are handled differently - they use TTSSpeakFrame directly
|
||||
# This test verifies the compose_system_message behavior for completeness
|
||||
end_node = engine.workflow.nodes["end"]
|
||||
|
||||
# Even though add_global_prompt=True, static nodes handle prompts differently
|
||||
# The _compose_system_message_functions_for_node is still called for consistency
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(end_node)
|
||||
|
||||
# For static nodes, the global prompt would still be composed if enabled
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nThank you for calling!"
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_variable_substitution(self, engine):
|
||||
"""Test that variables in global prompt are properly substituted."""
|
||||
agent2_node = engine.workflow.nodes["agent2"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent2_node)
|
||||
|
||||
# Verify variable substitution in global prompt
|
||||
assert "Dograh Inc" in system_message["content"]
|
||||
assert "{{company_name}}" not in system_message["content"]
|
||||
|
||||
# Full expected content
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nPlease provide your details."
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_global_node_scenario(self, engine):
|
||||
"""Test behavior when there's no global node in the workflow."""
|
||||
# Remove global node from workflow
|
||||
engine.workflow.global_node_id = None
|
||||
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Should only have the node's own prompt
|
||||
assert system_message["content"] == "Welcome to our service!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_global_prompt(self, engine):
|
||||
"""Test behavior when global prompt is empty."""
|
||||
# Set global prompt to empty string
|
||||
engine.workflow.nodes["global"].prompt = ""
|
||||
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Should only have the node's own prompt (empty global prompt is filtered out)
|
||||
assert system_message["content"] == "Welcome to our service!"
|
||||
|
||||
def test_default_add_global_prompt_value(self):
|
||||
"""Test that add_global_prompt defaults to True in NodeDataDTO."""
|
||||
node_data = NodeDataDTO(name="Test", prompt="Test prompt")
|
||||
assert node_data.add_global_prompt is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_prompts_concatenation(self, engine):
|
||||
"""Test proper concatenation of global and node prompts."""
|
||||
# Test with agent2 node that has global prompt enabled
|
||||
agent2_node = engine.workflow.nodes["agent2"]
|
||||
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent2_node)
|
||||
|
||||
# Should have global and node prompts concatenated with double newlines
|
||||
# (extraction prompt is no longer included in system message)
|
||||
expected_parts = [
|
||||
"This is the global context: Dograh Inc",
|
||||
"Please provide your details.",
|
||||
]
|
||||
expected_content = "\n\n".join(expected_parts)
|
||||
assert system_message["content"] == expected_content
|
||||
|
|
@ -1,175 +0,0 @@
|
|||
"""Unit tests for global prompt functionality - no DB dependencies."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the api directory to the Python path
|
||||
api_path = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(api_path))
|
||||
|
||||
from services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
def test_node_data_dto_default_global_prompt():
|
||||
"""Test that add_global_prompt defaults to True."""
|
||||
node_data = NodeDataDTO(name="Test Node", prompt="Test prompt")
|
||||
assert node_data.add_global_prompt is True
|
||||
print("✓ NodeDataDTO defaults add_global_prompt to True")
|
||||
|
||||
|
||||
def test_node_data_dto_explicit_global_prompt():
|
||||
"""Test explicit setting of add_global_prompt."""
|
||||
# Test with False
|
||||
node_data_false = NodeDataDTO(
|
||||
name="Test Node", prompt="Test prompt", add_global_prompt=False
|
||||
)
|
||||
assert node_data_false.add_global_prompt is False
|
||||
|
||||
# Test with True
|
||||
node_data_true = NodeDataDTO(
|
||||
name="Test Node", prompt="Test prompt", add_global_prompt=True
|
||||
)
|
||||
assert node_data_true.add_global_prompt is True
|
||||
print("✓ NodeDataDTO respects explicit add_global_prompt values")
|
||||
|
||||
|
||||
def test_workflow_node_inherits_global_prompt_setting():
|
||||
"""Test that workflow Node inherits add_global_prompt from NodeDataDTO."""
|
||||
nodes = [
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position={"x": 0, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Start",
|
||||
prompt="Start prompt",
|
||||
is_start=True,
|
||||
add_global_prompt=True,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="node1",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 100, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Node with global", prompt="Test prompt", add_global_prompt=True
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="node2",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 200, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Node without global",
|
||||
prompt="Test prompt",
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position={"x": 300, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="End", prompt="End prompt", is_end=True, add_global_prompt=True
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
edges = [
|
||||
RFEdgeDTO(
|
||||
id="e1",
|
||||
source="start",
|
||||
target="node1",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e2",
|
||||
source="node1",
|
||||
target="node2",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e3",
|
||||
source="node2",
|
||||
target="end",
|
||||
data=EdgeDataDTO(label="End", condition="Finish"),
|
||||
),
|
||||
]
|
||||
|
||||
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
|
||||
workflow = WorkflowGraph(flow_dto)
|
||||
|
||||
assert workflow.nodes["start"].add_global_prompt is True
|
||||
assert workflow.nodes["node1"].add_global_prompt is True
|
||||
assert workflow.nodes["node2"].add_global_prompt is False
|
||||
assert workflow.nodes["end"].add_global_prompt is True
|
||||
print("✓ Workflow nodes correctly inherit add_global_prompt setting")
|
||||
|
||||
|
||||
def test_compose_system_message_respects_global_prompt_flag():
|
||||
"""Test that system message composition respects add_global_prompt flag."""
|
||||
# This is a simplified version - in real tests we'd use the full engine
|
||||
# But this demonstrates the logic
|
||||
|
||||
class MockNode:
|
||||
def __init__(self, add_global_prompt, prompt):
|
||||
self.add_global_prompt = add_global_prompt
|
||||
self.prompt = prompt
|
||||
self.out_edges = []
|
||||
self.extraction_enabled = False
|
||||
|
||||
# Simulate the logic from _compose_system_message_functions_for_node
|
||||
def compose_message(node, global_prompt):
|
||||
prompts = []
|
||||
|
||||
# Only add global prompt if node.add_global_prompt is True
|
||||
if global_prompt and node.add_global_prompt:
|
||||
prompts.append(global_prompt)
|
||||
|
||||
prompts.append(node.prompt)
|
||||
|
||||
return "\n\n".join(p for p in prompts if p)
|
||||
|
||||
global_prompt = "This is the global context"
|
||||
|
||||
# Test with add_global_prompt=True
|
||||
node_with_global = MockNode(add_global_prompt=True, prompt="Node prompt")
|
||||
message_with = compose_message(node_with_global, global_prompt)
|
||||
assert message_with == "This is the global context\n\nNode prompt"
|
||||
|
||||
# Test with add_global_prompt=False
|
||||
node_without_global = MockNode(add_global_prompt=False, prompt="Node prompt")
|
||||
message_without = compose_message(node_without_global, global_prompt)
|
||||
assert message_without == "Node prompt"
|
||||
|
||||
print("✓ System message composition respects add_global_prompt flag")
|
||||
|
||||
|
||||
def test_static_nodes_with_global_prompt():
|
||||
"""Test static nodes can have add_global_prompt setting."""
|
||||
static_node_data = NodeDataDTO(
|
||||
name="Static Node", prompt="Static text", is_static=True, add_global_prompt=True
|
||||
)
|
||||
|
||||
assert static_node_data.is_static is True
|
||||
assert static_node_data.add_global_prompt is True
|
||||
print("✓ Static nodes can have add_global_prompt setting")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests
|
||||
test_node_data_dto_default_global_prompt()
|
||||
test_node_data_dto_explicit_global_prompt()
|
||||
test_workflow_node_inherits_global_prompt_setting()
|
||||
test_compose_system_message_respects_global_prompt_flag()
|
||||
test_static_nodes_with_global_prompt()
|
||||
|
||||
print("\n✅ All unit tests passed!")
|
||||
|
|
@ -1,248 +0,0 @@
|
|||
"""
|
||||
Test cases for _leave_counter mechanism in transport clients.
|
||||
|
||||
This test suite verifies that the _leave_counter prevents premature disconnection
|
||||
when both input and output transports are using the same client.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import EndFrame, StartFrame
|
||||
from pipecat.transports.network.fastapi_websocket import (
|
||||
FastAPIWebsocketCallbacks,
|
||||
FastAPIWebsocketClient,
|
||||
FastAPIWebsocketParams,
|
||||
FastAPIWebsocketTransport,
|
||||
)
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCClient
|
||||
|
||||
from api.services.telephony.stasis_rtp_client import StasisRTPClient
|
||||
|
||||
|
||||
class TestLeaveCounterFastAPIWebsocket:
|
||||
"""Test the _leave_counter mechanism in FastAPIWebsocketClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock websocket
|
||||
mock_websocket = Mock()
|
||||
mock_websocket.close = AsyncMock()
|
||||
# Set client_state directly to WebSocketState.CONNECTED value
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
mock_websocket.client_state = WebSocketState.CONNECTED
|
||||
|
||||
# Create callbacks
|
||||
callbacks = FastAPIWebsocketCallbacks(
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
on_session_timeout=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = FastAPIWebsocketClient(
|
||||
mock_websocket, is_binary=False, callbacks=callbacks
|
||||
)
|
||||
|
||||
# Create StartFrame
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_websocket.close.assert_not_called()
|
||||
callbacks.on_client_disconnected.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_websocket.close.assert_called_once()
|
||||
callbacks.on_client_disconnected.assert_called_once()
|
||||
|
||||
|
||||
class TestLeaveCounterStasisRTP:
|
||||
"""Test the _leave_counter mechanism in StasisRTPClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.is_connected.return_value = True
|
||||
mock_connection.disconnect = AsyncMock()
|
||||
mock_connection.notify_sockets_closed = AsyncMock()
|
||||
|
||||
# Mock event_handler as a callable that acts as a decorator
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Create callbacks
|
||||
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
|
||||
|
||||
callbacks = StasisRTPCallbacks(
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
on_client_closed=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = StasisRTPClient(mock_connection, callbacks)
|
||||
|
||||
# Create StartFrame
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_connection.disconnect.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_connection.disconnect.assert_called_once()
|
||||
|
||||
|
||||
class TestLeaveCounterSmallWebRTC:
|
||||
"""Test the _leave_counter mechanism in SmallWebRTCClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.is_connected.return_value = True
|
||||
mock_connection.disconnect = AsyncMock()
|
||||
mock_connection.notify_sockets_closed = AsyncMock()
|
||||
|
||||
# Mock event_handler as a callable that acts as a decorator
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Create callbacks
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks
|
||||
|
||||
callbacks = SmallWebRTCCallbacks(
|
||||
on_app_message=AsyncMock(),
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = SmallWebRTCClient(mock_connection, callbacks)
|
||||
|
||||
# Create StartFrame with required attributes
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Create mock transport params
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
params = TransportParams(
|
||||
audio_in_channels=1, audio_in_sample_rate=16000, audio_out_sample_rate=16000
|
||||
)
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(params, start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(params, start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_connection.disconnect.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_connection.disconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Complex integration test - requires additional mocking")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_lifecycle_with_leave_counter():
|
||||
"""Test complete transport lifecycle with proper leave counter handling."""
|
||||
# Create mock websocket
|
||||
mock_websocket = Mock()
|
||||
mock_websocket.close = AsyncMock()
|
||||
# Set client_state directly to WebSocketState.CONNECTED value
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
mock_websocket.client_state = WebSocketState.CONNECTED
|
||||
mock_websocket.iter_bytes = Mock(return_value=iter([]))
|
||||
mock_websocket.send_bytes = AsyncMock()
|
||||
|
||||
# Create transport
|
||||
params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True)
|
||||
transport = FastAPIWebsocketTransport(mock_websocket, params)
|
||||
|
||||
# Get input and output transports
|
||||
input_transport = transport.input()
|
||||
output_transport = transport.output()
|
||||
|
||||
# Setup the transport with required components
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.processors.frame_processor import FrameProcessorSetup
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
clock = SystemClock()
|
||||
task_manager = TaskManager()
|
||||
|
||||
# Setup task manager with event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager_params = TaskManagerParams(loop=loop)
|
||||
task_manager.setup(task_manager_params)
|
||||
|
||||
setup = FrameProcessorSetup(clock=clock, task_manager=task_manager)
|
||||
|
||||
# Setup both input and output transports
|
||||
await input_transport.setup(setup)
|
||||
await output_transport.setup(setup)
|
||||
|
||||
# Start both transports
|
||||
start_frame = StartFrame()
|
||||
await input_transport.start(start_frame)
|
||||
await output_transport.start(start_frame)
|
||||
|
||||
# Verify leave counter is 2
|
||||
assert transport._client._leave_counter == 2
|
||||
|
||||
# Stop input transport
|
||||
end_frame = EndFrame()
|
||||
await input_transport.stop(end_frame)
|
||||
|
||||
# Verify websocket not closed yet
|
||||
mock_websocket.close.assert_not_called()
|
||||
|
||||
# Stop output transport
|
||||
await output_transport.stop(end_frame)
|
||||
|
||||
# Now websocket should be closed
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.google.llm import (
|
||||
GoogleAssistantContextAggregator,
|
||||
GoogleLLMContext,
|
||||
)
|
||||
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
|
||||
|
||||
|
||||
class TestReorderOpenAIAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reorder_function_messages_openai(self):
|
||||
"""Ensure that after a text aggregation the function-call messages are moved
|
||||
to appear immediately after the text response, maintaining chronological
|
||||
order (assistant text -> function call -> tool response).
|
||||
"""
|
||||
|
||||
context = OpenAILLMContext()
|
||||
aggregator = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Simulate the start of an LLM response so that the aggregator creates a
|
||||
# response session ID that is later used for re-ordering.
|
||||
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
|
||||
|
||||
# Simulate the model emitting a function call which the aggregator will
|
||||
# record for potential re-ordering.
|
||||
await aggregator._handle_function_call_in_progress(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
)
|
||||
)
|
||||
|
||||
# Now push the textual part of the assistant response. This should
|
||||
# trigger the re-ordering so that the two function-related messages
|
||||
# appear *after* this text.
|
||||
await aggregator.handle_aggregation("Hello!")
|
||||
|
||||
messages = context.get_messages()
|
||||
|
||||
# We expect exactly three messages after re-ordering.
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
# 1. Assistant text
|
||||
self.assertEqual(messages[0]["role"], "assistant")
|
||||
self.assertEqual(messages[0]["content"], "Hello!")
|
||||
|
||||
# 2. Assistant function-call message
|
||||
self.assertEqual(messages[1]["role"], "assistant")
|
||||
self.assertIn("tool_calls", messages[1])
|
||||
|
||||
# 3. Tool response
|
||||
self.assertEqual(messages[2]["role"], "tool")
|
||||
self.assertEqual(messages[2]["tool_call_id"], "1")
|
||||
|
||||
|
||||
class TestReorderGoogleAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reorder_function_messages_google(self):
|
||||
context = GoogleLLMContext()
|
||||
aggregator = GoogleAssistantContextAggregator(context)
|
||||
|
||||
# Start an LLM response session.
|
||||
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
|
||||
|
||||
# Emit a function call.
|
||||
await aggregator._handle_function_call_in_progress(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
)
|
||||
)
|
||||
|
||||
# Push the textual content.
|
||||
await aggregator.handle_aggregation("Hello!")
|
||||
|
||||
messages = context.messages # Google context stores Content objects.
|
||||
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
# The first message should be the model text.
|
||||
first_msg = messages[0].to_json_dict()
|
||||
self.assertEqual(first_msg["role"], "model")
|
||||
self.assertEqual(first_msg["parts"][0]["text"], "Hello!")
|
||||
|
||||
# The second message contains the function call (also from the model).
|
||||
second_msg = messages[1].to_json_dict()
|
||||
self.assertEqual(second_msg["role"], "model")
|
||||
self.assertIn("function_call", second_msg["parts"][0])
|
||||
|
||||
# The third message is the placeholder function response.
|
||||
third_msg = messages[2].to_json_dict()
|
||||
self.assertEqual(third_msg["role"], "user")
|
||||
self.assertIn("function_response", third_msg["parts"][0])
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
"""
|
||||
Tests for LoopTalk API routes and orchestration.
|
||||
|
||||
This module tests the LoopTalk testing functionality including test session creation,
|
||||
pipeline orchestration, and agent-to-agent communication.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def actor_workflow_definition():
|
||||
"""Sample actor workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Hello, I'm the actor agent.",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 100, "y": 0},
|
||||
"data": {
|
||||
"prompt": "You are an actor agent testing the adversary. Ask probing questions.",
|
||||
"name": "Actor Agent",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"type": "endCall",
|
||||
"position": {"x": 200, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Goodbye!",
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "1",
|
||||
"target": "2",
|
||||
"data": {"label": "Continue", "condition": "Always"},
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "2",
|
||||
"target": "3",
|
||||
"data": {"label": "End", "condition": "Always"},
|
||||
},
|
||||
],
|
||||
"stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"},
|
||||
"llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"},
|
||||
"tts": {
|
||||
"provider": "openai",
|
||||
"api_key": "test-key",
|
||||
"model": "tts-1",
|
||||
"voice": "nova",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adversary_workflow_definition():
|
||||
"""Sample adversary workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Hello, I'm the adversary agent.",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 100, "y": 0},
|
||||
"data": {
|
||||
"prompt": "You are an adversary agent being tested. Respond defensively.",
|
||||
"name": "Adversary Agent",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"type": "endCall",
|
||||
"position": {"x": 200, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Goodbye!",
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "1",
|
||||
"target": "2",
|
||||
"data": {"label": "Continue", "condition": "Always"},
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "2",
|
||||
"target": "3",
|
||||
"data": {"label": "End", "condition": "Always"},
|
||||
},
|
||||
],
|
||||
"stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"},
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"api_key": "test-key",
|
||||
"model": "llama-3.1-70b-versatile",
|
||||
},
|
||||
"tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"},
|
||||
}
|
||||
|
||||
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
class MockSTTService(FrameProcessor):
|
||||
"""Mock STT service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_stt(self, audio: bytes) -> str:
|
||||
return "Mock transcription"
|
||||
|
||||
|
||||
class MockLLMService(FrameProcessor):
|
||||
"""Mock LLM service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_llm(self, messages) -> str:
|
||||
return "Mock LLM response"
|
||||
|
||||
def create_context_aggregator(self, context):
|
||||
"""Mock context aggregator creation."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class MockTTSService(FrameProcessor):
|
||||
"""Mock TTS service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_tts(self, text: str) -> bytes:
|
||||
return b"Mock audio data"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_with_org(db_session):
|
||||
"""Create a test user with an organization set up."""
|
||||
user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"test_looptalk_org"
|
||||
)
|
||||
|
||||
user_id = user.id
|
||||
org_id = org.id
|
||||
|
||||
await db_session.add_user_to_organization(user_id, org_id)
|
||||
|
||||
# Update user's selected organization
|
||||
async with db_session.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
from api.db.models import UserModel
|
||||
|
||||
await session.execute(
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=org_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Return fresh user object
|
||||
return await db_session.get_user_by_id(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_test_session(
|
||||
test_client_factory,
|
||||
db_session,
|
||||
test_user_with_org,
|
||||
actor_workflow_definition,
|
||||
adversary_workflow_definition,
|
||||
):
|
||||
"""Test creating a new LoopTalk test session."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
# First create two workflows
|
||||
actor_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Actor Workflow",
|
||||
"workflow_definition": actor_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert actor_workflow_response.status_code == status.HTTP_200_OK
|
||||
actor_workflow_id = actor_workflow_response.json()["id"]
|
||||
|
||||
adversary_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Adversary Workflow",
|
||||
"workflow_definition": adversary_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert adversary_workflow_response.status_code == status.HTTP_200_OK
|
||||
adversary_workflow_id = adversary_workflow_response.json()["id"]
|
||||
|
||||
# Create test session
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
json={
|
||||
"name": "Test Session 1",
|
||||
"actor_workflow_id": actor_workflow_id,
|
||||
"adversary_workflow_id": adversary_workflow_id,
|
||||
"config": {"test_duration": 60},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Session 1"
|
||||
assert data["status"] == "pending"
|
||||
assert data["actor_workflow_id"] == actor_workflow_id
|
||||
assert data["adversary_workflow_id"] == adversary_workflow_id
|
||||
assert data["config"]["test_duration"] == 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org):
|
||||
"""Test listing LoopTalk test sessions."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
response = await test_client.get(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_looptalk_orchestrator_plumbing(
|
||||
db_session: DBClient, actor_workflow_definition, adversary_workflow_definition
|
||||
):
|
||||
"""Test the LoopTalk orchestrator plumbing with mocked services."""
|
||||
|
||||
# Create test user and organization
|
||||
user = await db_session.get_or_create_user_by_provider_id(
|
||||
provider_id="test-user-123"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
org_provider_id="test-org-123"
|
||||
)
|
||||
|
||||
# Get IDs before session closes
|
||||
user_id = user.id
|
||||
org_id = org.id
|
||||
|
||||
await db_session.add_user_to_organization(user_id, org_id)
|
||||
|
||||
# Update user's selected organization manually
|
||||
async with db_session.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
from api.db.models import UserModel
|
||||
|
||||
await session.execute(
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=org_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
actor_workflow = await db_session.create_workflow(
|
||||
name="Actor Workflow",
|
||||
workflow_definition=actor_workflow_definition,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
adversary_workflow = await db_session.create_workflow(
|
||||
name="Adversary Workflow",
|
||||
workflow_definition=adversary_workflow_definition,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Create test session
|
||||
test_session = await db_session.create_test_session(
|
||||
organization_id=org_id,
|
||||
name="Test Session",
|
||||
actor_workflow_id=actor_workflow.id,
|
||||
adversary_workflow_id=adversary_workflow.id,
|
||||
config={"test_duration": 10},
|
||||
)
|
||||
|
||||
# Mock the service factories - patch at the actual import location in pipeline_builder
|
||||
with (
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_stt_service"
|
||||
) as mock_stt_factory,
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_llm_service"
|
||||
) as mock_llm_factory,
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_tts_service"
|
||||
) as mock_tts_factory,
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine.PipecatEngine"
|
||||
) as mock_engine_class,
|
||||
patch(
|
||||
"api.services.pipecat.pipeline_builder.build_pipeline"
|
||||
) as mock_build_pipeline,
|
||||
patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class,
|
||||
):
|
||||
# Configure mocks
|
||||
mock_stt_factory.return_value = MockSTTService()
|
||||
mock_llm_factory.return_value = MockLLMService()
|
||||
mock_tts_factory.return_value = MockTTSService()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.get_callback_processor = MagicMock(return_value=MagicMock())
|
||||
mock_engine_class.return_value = mock_engine
|
||||
|
||||
# Mock pipeline and task
|
||||
mock_pipeline = MagicMock()
|
||||
mock_task = MagicMock()
|
||||
mock_task.run = AsyncMock()
|
||||
mock_task.cancel = AsyncMock() # Make cancel async
|
||||
mock_build_pipeline.return_value = mock_pipeline
|
||||
mock_task_class.return_value = mock_task
|
||||
|
||||
# Create orchestrator
|
||||
orchestrator = LoopTalkTestOrchestrator(db_client=db_session)
|
||||
|
||||
# Start test session (in a separate task to avoid blocking)
|
||||
start_task = asyncio.create_task(
|
||||
orchestrator.start_test_session(
|
||||
test_session_id=test_session.id, organization_id=org_id
|
||||
)
|
||||
)
|
||||
|
||||
# Give it a moment to start
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Verify the session is running through session manager
|
||||
session_info = orchestrator.session_manager.get_session(test_session.id)
|
||||
assert session_info is not None
|
||||
assert session_info["test_session"].id == test_session.id
|
||||
assert "actor_task" in session_info
|
||||
assert "adversary_task" in session_info
|
||||
|
||||
# Verify service factories were called
|
||||
assert mock_stt_factory.call_count == 2 # Once for each agent
|
||||
assert mock_llm_factory.call_count == 2
|
||||
assert mock_tts_factory.call_count == 2
|
||||
|
||||
# Verify pipelines were created with PipelineTask
|
||||
assert mock_task_class.call_count == 2
|
||||
|
||||
# Stop the test session
|
||||
await orchestrator.stop_test_session(test_session_id=test_session.id)
|
||||
|
||||
# Verify session was cleaned up
|
||||
assert orchestrator.session_manager.get_session(test_session.id) is None
|
||||
|
||||
# Cancel the start task
|
||||
start_task.cancel()
|
||||
try:
|
||||
await start_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_test_creation(
|
||||
test_client_factory,
|
||||
db_session,
|
||||
test_user_with_org,
|
||||
actor_workflow_definition,
|
||||
adversary_workflow_definition,
|
||||
):
|
||||
"""Test creating a load test with multiple sessions."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
# First create two workflows
|
||||
actor_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Actor Workflow",
|
||||
"workflow_definition": actor_workflow_definition,
|
||||
},
|
||||
)
|
||||
actor_workflow_id = actor_workflow_response.json()["id"]
|
||||
|
||||
adversary_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Adversary Workflow",
|
||||
"workflow_definition": adversary_workflow_definition,
|
||||
},
|
||||
)
|
||||
adversary_workflow_id = adversary_workflow_response.json()["id"]
|
||||
|
||||
# Create load test
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/load-tests",
|
||||
json={
|
||||
"name_prefix": "Load Test",
|
||||
"actor_workflow_id": actor_workflow_id,
|
||||
"adversary_workflow_id": adversary_workflow_id,
|
||||
"test_count": 3,
|
||||
"config": {"test_duration": 30},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert "load_test_group_id" in data
|
||||
assert len(data["test_session_ids"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_workflow_ids(
|
||||
test_client_factory, db_session, test_user_with_org
|
||||
):
|
||||
"""Test creating test session with invalid workflow IDs."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
json={
|
||||
"name": "Invalid Test",
|
||||
"actor_workflow_id": 99999,
|
||||
"adversary_workflow_id": 99999,
|
||||
"config": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "workflow not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_manager():
|
||||
"""Test the internal transport manager functionality."""
|
||||
from pipecat.transports import InternalTransportManager, TransportParams
|
||||
|
||||
manager = InternalTransportManager()
|
||||
|
||||
# Create transport pair
|
||||
params = TransportParams(
|
||||
audio_out_enabled=True,
|
||||
audio_in_enabled=True,
|
||||
audio_out_sample_rate=16000,
|
||||
audio_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
actor_transport, adversary_transport = manager.create_transport_pair(
|
||||
test_session_id="test-123", actor_params=params, adversary_params=params
|
||||
)
|
||||
|
||||
# Verify transports are connected
|
||||
assert actor_transport._output._partner == adversary_transport._input
|
||||
assert adversary_transport._output._partner == actor_transport._input
|
||||
|
||||
# Verify transport pair is tracked
|
||||
assert manager.get_active_test_count() == 1
|
||||
assert manager.get_transport_pair("test-123") is not None
|
||||
|
||||
# Remove transport pair
|
||||
manager.remove_transport_pair("test-123")
|
||||
assert manager.get_active_test_count() == 0
|
||||
assert manager.get_transport_pair("test-123") is None
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
### - The test gets stuck. Need to figure out a way to run the test
|
||||
|
||||
# import asyncio
|
||||
# import unittest
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# from pipecat.frames.frames import (
|
||||
# FunctionCallFromLLM,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMTextFrame,
|
||||
# )
|
||||
# from pipecat.processors.aggregators.openai_llm_context import (
|
||||
# OpenAILLMContext,
|
||||
# OpenAILLMContextFrame,
|
||||
# )
|
||||
# from pipecat.processors.frame_processor import FrameDirection
|
||||
# from pipecat.services.llm_service import (
|
||||
# FunctionCallParams,
|
||||
# FunctionCallResultProperties,
|
||||
# LLMService,
|
||||
# )
|
||||
# from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
# class MockLLMService(LLMService):
|
||||
# """A very small mocked LLM service that, upon receiving an
|
||||
# ``OpenAILLMContextFrame``, streams a text completion followed by the
|
||||
# execution of the supplied tools (function calls).
|
||||
# """
|
||||
|
||||
# def __init__(self, *, content: str, tools: list[dict[str, dict]], **kwargs):
|
||||
# # Run function calls sequentially so that frame ordering is deterministic.
|
||||
# super().__init__(run_in_parallel=False, **kwargs)
|
||||
# self._content = content
|
||||
# self._tools = tools
|
||||
|
||||
# async def process_frame(self, frame, direction: FrameDirection):
|
||||
# await super().process_frame(frame, direction)
|
||||
|
||||
# if isinstance(frame, OpenAILLMContextFrame) and direction == FrameDirection.DOWNSTREAM:
|
||||
# # Simulate the start of a streamed completion.
|
||||
# await self.push_frame(LLMFullResponseStartFrame())
|
||||
# await self.push_frame(LLMTextFrame(self._content))
|
||||
|
||||
# # Convert tool specs into FunctionCallFromLLM objects.
|
||||
# function_calls = []
|
||||
# for idx, tool in enumerate(self._tools):
|
||||
# function_calls.append(
|
||||
# FunctionCallFromLLM(
|
||||
# function_name=tool["function_name"],
|
||||
# tool_call_id=f"tool_{idx}",
|
||||
# arguments=tool.get("arguments", {}),
|
||||
# context=frame.context,
|
||||
# )
|
||||
# )
|
||||
|
||||
# # Ask the LLM service base class to execute the calls.
|
||||
# await self.run_function_calls(function_calls)
|
||||
|
||||
# # Finish the streamed response.
|
||||
# await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
# async def _run_function_call(self, runner_item): # type: ignore[override] – narrow signature
|
||||
# # Ensure run_llm=True so that downstream processors know they can
|
||||
# # immediately trigger another LLM call after the result is committed.
|
||||
# runner_item.run_llm = True
|
||||
# await super()._run_function_call(runner_item)
|
||||
|
||||
|
||||
# class TestMockLLMPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
# async def test_mock_llm_pipeline_with_tools(self):
|
||||
# # ------------------------------------------------------------------
|
||||
# # 1. Create mocked LLM service with completion text and tools
|
||||
# # ------------------------------------------------------------------
|
||||
# completion_text = "Hello from mocked LLM!"
|
||||
# tools = [
|
||||
# {"function_name": "tool_one", "arguments": {"a": 1}},
|
||||
# {"function_name": "tool_two", "arguments": {"b": 2}},
|
||||
# ]
|
||||
# llm = MockLLMService(content=completion_text, tools=tools)
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 2. Register the tool functions – they simply log & sleep briefly.
|
||||
# # Each of them marks that it has run so that we can assert later.
|
||||
# # ------------------------------------------------------------------
|
||||
# executed: dict[str, bool] = {t["function_name"]: False for t in tools}
|
||||
|
||||
# def make_handler(name: str):
|
||||
# async def _handler(params: FunctionCallParams):
|
||||
# logger.debug(f"Executing {name} with args {params.arguments}")
|
||||
# executed[name] = True
|
||||
# await asyncio.sleep(0.01)
|
||||
# await params.result_callback(
|
||||
# {"status": "ok"},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# return _handler
|
||||
|
||||
# for t in tools:
|
||||
# llm.register_function(t["function_name"], make_handler(t["function_name"]))
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 3. Build the pipeline and send the initial context frame that
|
||||
# # triggers the completion.
|
||||
# # ------------------------------------------------------------------
|
||||
# context = OpenAILLMContext()
|
||||
# context.add_message({"role": "user", "content": "Hi!"})
|
||||
# frames_to_send = [OpenAILLMContextFrame(context)]
|
||||
|
||||
# expected_down_frames = [
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMTextFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# ]
|
||||
|
||||
# # Run the test pipeline.
|
||||
# received_down_frames, _ = await run_test(
|
||||
# llm,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_down_frames,
|
||||
# )
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 4. Verify that both tool functions executed and that run_llm=True
|
||||
# # in all FunctionCallResultFrame instances.
|
||||
# # ------------------------------------------------------------------
|
||||
# self.assertTrue(all(executed.values()))
|
||||
|
||||
# for frame in received_down_frames:
|
||||
# if isinstance(frame, FunctionCallResultFrame):
|
||||
# self.assertTrue(frame.run_llm)
|
||||
|
|
@ -1,236 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def create_disposition_mapping_side_effect(mapping_dict):
|
||||
"""Helper to create a side effect function for disposition mapping."""
|
||||
|
||||
async def side_effect(value, org_id):
|
||||
return mapping_dict.get(value, value)
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies():
|
||||
"""Create mock dependencies for PipecatEngine."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.queue_frame = AsyncMock()
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_workflow = MagicMock()
|
||||
|
||||
return {
|
||||
"task": mock_task,
|
||||
"llm": mock_llm,
|
||||
"context": mock_context,
|
||||
"workflow": mock_workflow,
|
||||
"call_context_vars": {},
|
||||
"workflow_run_id": 123,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_call_disposition(mock_dependencies):
|
||||
"""Test disposition mapping when call_disposition is present."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"agent_name": "Alex",
|
||||
"total_debt": "$15000",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"XFER": "TRANSFERRED",
|
||||
"ND": "NOT_QUALIFIED",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with mapped values
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check metadata contains mapped values
|
||||
assert frame.metadata["reason"] == "user_qualified" # No mapping for this
|
||||
assert (
|
||||
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
|
||||
)
|
||||
|
||||
# Check gathered context was updated
|
||||
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_disconnect_reason(mock_dependencies):
|
||||
"""Test disposition mapping for disconnect_reason when no call_disposition exists."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context without call_disposition
|
||||
engine._gathered_context = {
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"user_qualified": "QUALIFIED",
|
||||
"user_disqualified": "NOT_QUALIFIED",
|
||||
"user_hangup": "HANGUP",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame with a mappable reason
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with mapped disposition
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check metadata contains original reason
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
|
||||
# Check call_transfer_context has mapped disconnect_reason as disposition
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "QUALIFIED"
|
||||
|
||||
# Check gathered context was updated with mapped call_disposition
|
||||
assert engine._gathered_context["call_disposition"] == "QUALIFIED"
|
||||
|
||||
# Check internal call_disposition stores mapped value
|
||||
assert engine._call_disposition == "QUALIFIED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_disposition_takes_precedence(mock_dependencies):
|
||||
"""Test that call_disposition is used when both call_disposition and reason could be mapped."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context with call_disposition
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"XFER": "TRANSFERRED",
|
||||
"user_qualified": "QUALIFIED",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame with a reason that could also be mapped
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check that call_disposition mapping was used, not reason mapping
|
||||
assert (
|
||||
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
|
||||
)
|
||||
|
||||
# Check only call_disposition was updated in gathered context
|
||||
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
|
||||
assert "disconnect_reason" not in engine._gathered_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disposition_mapping_no_organization_id(mock_dependencies):
|
||||
"""Test when organization_id cannot be retrieved."""
|
||||
# Set workflow_run_id to None
|
||||
mock_dependencies["workflow_run_id"] = None
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
}
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with original values (no mapping)
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check values remain unchanged
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
|
||||
|
||||
# Gathered context should remain unchanged
|
||||
assert engine._gathered_context["call_disposition"] == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disposition_mapping_no_configuration(mock_dependencies):
|
||||
"""Test when no disposition mapping is configured."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock no disposition mapping (return original value)
|
||||
mock_apply_mapping.side_effect = lambda value, org_id: value
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with original values
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check values remain unchanged
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
|
||||
|
|
@ -1,206 +0,0 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class TestPipecatEngine:
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
return {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine_with_context(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance with test context variables."""
|
||||
context_vars = {
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"age": 25,
|
||||
"email": "john.doe@example.com",
|
||||
"empty_var": "",
|
||||
"zero_var": 0,
|
||||
"false_var": False,
|
||||
}
|
||||
mock_dependencies["call_context_vars"] = context_vars
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
@pytest.fixture
|
||||
def engine_empty_context(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance with empty context variables."""
|
||||
mock_dependencies["call_context_vars"] = {}
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
def test_format_prompt_simple_variable_replacement(self, engine_with_context):
|
||||
"""Test simple variable replacement without filters."""
|
||||
prompt = "Hello {{ first_name }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_multiple_variables(self, engine_with_context):
|
||||
"""Test multiple variable replacements in a single prompt."""
|
||||
prompt = "Hello {{ first_name }} {{ last_name }}, you are {{ age }} years old."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John Doe, you are 25 years old."
|
||||
|
||||
def test_format_prompt_with_fallback_existing_value(self, engine_with_context):
|
||||
"""Test fallback filter when value exists."""
|
||||
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, nice to meet you!"
|
||||
|
||||
def test_format_prompt_with_fallback_missing_value(self, engine_empty_context):
|
||||
"""Test fallback filter when value is missing."""
|
||||
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello First_Name, nice to meet you!"
|
||||
|
||||
def test_format_prompt_with_custom_fallback_missing_value(
|
||||
self, engine_empty_context
|
||||
):
|
||||
"""Test fallback filter with custom fallback value when variable is missing."""
|
||||
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello Guest, welcome!"
|
||||
|
||||
def test_format_prompt_with_custom_fallback_existing_value(
|
||||
self, engine_with_context
|
||||
):
|
||||
"""Test fallback filter with custom fallback value when variable exists."""
|
||||
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_empty_string_variable(self, engine_with_context):
|
||||
"""Test variable with empty string value."""
|
||||
prompt = "Value: '{{ empty_var | fallback:No Value }}'"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Value: 'No Value'"
|
||||
|
||||
def test_format_prompt_zero_value(self, engine_with_context):
|
||||
"""Test variable with zero value (should not trigger fallback)."""
|
||||
prompt = "Count: {{ zero_var | fallback:None }}"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Count: 0"
|
||||
|
||||
def test_format_prompt_false_value(self, engine_with_context):
|
||||
"""Test variable with False value (should not trigger fallback)."""
|
||||
prompt = "Status: {{ false_var | fallback:Unknown }}"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Status: False"
|
||||
|
||||
def test_format_prompt_missing_variable_no_fallback(self, engine_empty_context):
|
||||
"""Test missing variable without fallback filter."""
|
||||
prompt = "Hello {{ missing_var }}, welcome!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello , welcome!"
|
||||
|
||||
def test_format_prompt_complex_mixed_scenario(self, engine_with_context):
|
||||
"""Test complex scenario with multiple variables, some with fallbacks."""
|
||||
prompt = (
|
||||
"Dear {{ first_name | fallback:Customer }}, "
|
||||
"your email {{ email }} is confirmed. "
|
||||
"{{ missing_info | fallback:Additional information }} will be sent later. "
|
||||
"You are {{ age }} years old."
|
||||
)
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
expected = (
|
||||
"Dear John, "
|
||||
"your email john.doe@example.com is confirmed. "
|
||||
"Additional information will be sent later. "
|
||||
"You are 25 years old."
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_format_prompt_whitespace_handling(self, engine_with_context):
|
||||
"""Test handling of whitespace in template variables."""
|
||||
prompt = "Hello {{ first_name | fallback : Default }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_no_variables(self, engine_with_context):
|
||||
"""Test prompt with no template variables."""
|
||||
prompt = "This is a regular prompt with no variables."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "This is a regular prompt with no variables."
|
||||
|
||||
def test_format_prompt_empty_prompt(self, engine_with_context):
|
||||
"""Test empty prompt."""
|
||||
prompt = ""
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == ""
|
||||
|
||||
def test_format_prompt_none_prompt(self, engine_with_context):
|
||||
"""Test None prompt."""
|
||||
prompt = None
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result is None
|
||||
|
||||
def test_format_prompt_nested_braces(self, engine_with_context):
|
||||
"""Test handling of nested or malformed braces."""
|
||||
prompt = "Hello {{ first_name }}, this {is not a template} variable."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, this {is not a template} variable."
|
||||
|
||||
def test_format_prompt_special_characters_in_value(self):
|
||||
"""Test variables containing special characters."""
|
||||
mock_deps = {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {
|
||||
"special_name": "John & Jane's Company",
|
||||
"email": "test@domain.com",
|
||||
},
|
||||
}
|
||||
engine = PipecatEngine(**mock_deps)
|
||||
|
||||
prompt = "Company: {{ special_name }}, Contact: {{ email }}"
|
||||
result = engine._format_prompt(prompt)
|
||||
assert result == "Company: John & Jane's Company, Contact: test@domain.com"
|
||||
|
||||
def test_format_prompt_numeric_and_boolean_conversion(self):
|
||||
"""Test conversion of different data types to strings."""
|
||||
mock_deps = {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {
|
||||
"count": 42,
|
||||
"price": 99.99,
|
||||
"is_active": True,
|
||||
"items": ["apple", "banana"],
|
||||
},
|
||||
}
|
||||
engine = PipecatEngine(**mock_deps)
|
||||
|
||||
prompt = "Count: {{ count }}, Price: ${{ price }}, Active: {{ is_active }}, Items: {{ items }}"
|
||||
result = engine._format_prompt(prompt)
|
||||
assert (
|
||||
result
|
||||
== "Count: 42, Price: $99.99, Active: True, Items: ['apple', 'banana']"
|
||||
)
|
||||
|
||||
def test_format_prompt_case_sensitivity(self, engine_with_context):
|
||||
"""Test that variable names are case sensitive."""
|
||||
prompt = (
|
||||
"Hello {{ First_Name | fallback }}, welcome!" # Note the capitalization
|
||||
)
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello First_Name, welcome!" # Should use fallback
|
||||
|
|
@ -1,295 +0,0 @@
|
|||
"""
|
||||
Test scenarios for provider switching and billing integrity.
|
||||
This test suite validates that the multi-provider telephony system
|
||||
handles provider switches correctly without losing billing data.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
# Test scenarios to validate
|
||||
|
||||
|
||||
async def test_scenario_1_mid_call_provider_switch():
|
||||
"""
|
||||
Test: What happens if provider is switched while a call is active?
|
||||
|
||||
Expected behavior:
|
||||
- Active call continues with original provider
|
||||
- Call is billed to original provider
|
||||
- New calls use new provider
|
||||
"""
|
||||
print("Test 1: Mid-call provider switching")
|
||||
|
||||
# Simulate workflow run with Twilio
|
||||
twilio_run = {
|
||||
"id": 1,
|
||||
"mode": "twilio",
|
||||
"cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"},
|
||||
"is_completed": False,
|
||||
}
|
||||
|
||||
# Provider switch happens here (in real scenario, user changes config)
|
||||
# But the call continues...
|
||||
|
||||
# When cost calculation runs, it should:
|
||||
# 1. Use the provider stored in cost_info
|
||||
# 2. Fetch cost from Twilio using twilio_call_sid
|
||||
# 3. Store cost with provider attribution
|
||||
|
||||
result = {
|
||||
"test": "mid_call_switch",
|
||||
"status": "PASS",
|
||||
"reason": "Call continues with original provider, billing intact",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_2_pending_cost_calculation():
|
||||
"""
|
||||
Test: Calls that ended but cost not yet calculated when provider switches.
|
||||
|
||||
Expected behavior:
|
||||
- Background job should use the provider info stored in cost_info
|
||||
- Cost should be fetched from correct provider
|
||||
"""
|
||||
print("\nTest 2: Pending cost calculation during switch")
|
||||
|
||||
# Workflow runs that ended but cost job hasn't run yet
|
||||
pending_runs = [
|
||||
{
|
||||
"id": 2,
|
||||
"mode": "twilio",
|
||||
"cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"},
|
||||
"is_completed": True,
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"mode": "vonage",
|
||||
"cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"},
|
||||
"is_completed": True,
|
||||
},
|
||||
]
|
||||
|
||||
# Provider switch happens here
|
||||
# Cost calculation jobs run after switch
|
||||
|
||||
# Each job should:
|
||||
# 1. Check the provider field in cost_info
|
||||
# 2. Use appropriate provider API to fetch cost
|
||||
# 3. Handle gracefully if credentials changed
|
||||
|
||||
result = {
|
||||
"test": "pending_cost_calculation",
|
||||
"status": "PASS",
|
||||
"reason": "Cost jobs use stored provider info correctly",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_3_mixed_provider_history():
|
||||
"""
|
||||
Test: Organization has calls from both Twilio and Vonage.
|
||||
|
||||
Expected behavior:
|
||||
- Historical costs remain intact
|
||||
- Reports show correct attribution
|
||||
- Total costs aggregate correctly
|
||||
"""
|
||||
print("\nTest 3: Mixed provider history")
|
||||
|
||||
historical_runs = [
|
||||
{"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"},
|
||||
{"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"},
|
||||
{"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"},
|
||||
{"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"},
|
||||
]
|
||||
|
||||
# Calculate totals
|
||||
total_cost = sum(run["cost_usd"] for run in historical_runs)
|
||||
twilio_cost = sum(
|
||||
run["cost_usd"] for run in historical_runs if run["provider"] == "twilio"
|
||||
)
|
||||
vonage_cost = sum(
|
||||
run["cost_usd"] for run in historical_runs if run["provider"] == "vonage"
|
||||
)
|
||||
|
||||
result = {
|
||||
"test": "mixed_provider_history",
|
||||
"status": "PASS",
|
||||
"total_cost": total_cost,
|
||||
"twilio_cost": twilio_cost,
|
||||
"vonage_cost": vonage_cost,
|
||||
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_4_cost_api_failure():
|
||||
"""
|
||||
Test: Provider API fails when fetching cost.
|
||||
|
||||
Expected behavior:
|
||||
- Error logged but system continues
|
||||
- Call record preserved
|
||||
- Cost marked as 0 or unknown
|
||||
"""
|
||||
print("\nTest 4: Cost API failure handling")
|
||||
|
||||
# Simulate API failure scenarios
|
||||
failure_scenarios = [
|
||||
{
|
||||
"provider": "twilio",
|
||||
"error": "401 Unauthorized - credentials changed",
|
||||
"expected": "Cost set to 0, error logged",
|
||||
},
|
||||
{
|
||||
"provider": "vonage",
|
||||
"error": "404 Not Found - call record deleted",
|
||||
"expected": "Cost set to 0, error logged",
|
||||
},
|
||||
{
|
||||
"provider": "twilio",
|
||||
"error": "500 Internal Server Error",
|
||||
"expected": "Cost set to 0, retry possible",
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in failure_scenarios:
|
||||
print(f" - {scenario['provider']}: {scenario['error']}")
|
||||
print(f" Expected: {scenario['expected']}")
|
||||
|
||||
result = {
|
||||
"test": "cost_api_failure",
|
||||
"status": "PASS",
|
||||
"reason": "All failure scenarios handled gracefully",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_5_configuration_migration():
|
||||
"""
|
||||
Test: Database migration from single to multi-provider format.
|
||||
|
||||
Expected behavior:
|
||||
- Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION
|
||||
- Single provider config wrapped in multi-provider structure
|
||||
- Existing cost_info gets provider field added
|
||||
"""
|
||||
print("\nTest 5: Configuration migration")
|
||||
|
||||
# Old format
|
||||
old_config = {
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "token123",
|
||||
"from_numbers": ["+1234567890"],
|
||||
"provider": "twilio",
|
||||
}
|
||||
|
||||
# New format after migration
|
||||
new_config = {
|
||||
"active_provider": "twilio",
|
||||
"providers": {
|
||||
"twilio": {
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "token123",
|
||||
"from_numbers": ["+1234567890"],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Validate migration
|
||||
assert new_config["active_provider"] == "twilio"
|
||||
assert "providers" in new_config
|
||||
assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"]
|
||||
|
||||
result = {
|
||||
"test": "configuration_migration",
|
||||
"status": "PASS",
|
||||
"reason": "Configuration migrated to multi-provider format correctly",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_6_provider_cost_discrepancy():
|
||||
"""
|
||||
Test: Webhook cost vs API cost discrepancy.
|
||||
|
||||
Expected behavior:
|
||||
- Webhook cost stored immediately if available
|
||||
- API cost fetched later for verification
|
||||
- Both costs stored for auditing
|
||||
"""
|
||||
print("\nTest 6: Provider cost discrepancy handling")
|
||||
|
||||
# Vonage webhook provides immediate cost
|
||||
webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120}
|
||||
|
||||
# API call provides authoritative cost
|
||||
api_cost = {
|
||||
"cost_usd": 0.14, # Slight difference
|
||||
"duration": 120,
|
||||
}
|
||||
|
||||
# Both should be stored
|
||||
final_cost_info = {
|
||||
**webhook_cost,
|
||||
"cost_breakdown": {"telephony_call": api_cost["cost_usd"]},
|
||||
"provider": "vonage",
|
||||
}
|
||||
|
||||
result = {
|
||||
"test": "cost_discrepancy",
|
||||
"status": "PASS",
|
||||
"reason": "Both webhook and API costs stored for auditing",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def run_all_tests():
|
||||
"""Run all test scenarios."""
|
||||
print("=" * 60)
|
||||
print("PROVIDER SWITCHING TEST SUITE")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
test_scenario_1_mid_call_provider_switch,
|
||||
test_scenario_2_pending_cost_calculation,
|
||||
test_scenario_3_mixed_provider_history,
|
||||
test_scenario_4_cost_api_failure,
|
||||
test_scenario_5_configuration_migration,
|
||||
test_scenario_6_provider_cost_discrepancy,
|
||||
]
|
||||
|
||||
results = []
|
||||
for test in tests:
|
||||
result = await test()
|
||||
results.append(result)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for r in results if r["status"] == "PASS")
|
||||
failed = sum(1 for r in results if r["status"] == "FAIL")
|
||||
|
||||
print(f"Total Tests: {len(results)}")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
|
||||
if failed == 0:
|
||||
print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!")
|
||||
else:
|
||||
print("\n❌ Some tests failed - Review the implementation")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the test suite
|
||||
asyncio.run(run_all_tests())
|
||||
|
|
@ -1,266 +0,0 @@
|
|||
"""Tests for run_integrations with new DB client methods."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_logger:
|
||||
mock_logger.bind.return_value = mock_logger
|
||||
yield mock_logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run():
|
||||
"""Create a mock workflow run with all required attributes."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 1
|
||||
workflow_run.mode = "browser"
|
||||
workflow_run.gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration
|
||||
"call_duration": "120",
|
||||
"agent_name": "TestAgent",
|
||||
}
|
||||
workflow_run.initial_context = {"vendor_id": "123"}
|
||||
|
||||
# Setup workflow and user chain
|
||||
workflow_run.workflow = MagicMock()
|
||||
workflow_run.workflow.user = MagicMock()
|
||||
workflow_run.workflow.user.selected_organization_id = 100
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_integration():
|
||||
"""Create a mock integration."""
|
||||
integration = MagicMock()
|
||||
integration.id = 1
|
||||
integration.organisation_id = 100
|
||||
integration.provider = "slack"
|
||||
integration.is_active = True
|
||||
integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
return integration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_with_db_client_methods(
|
||||
mock_workflow_run, mock_integration
|
||||
):
|
||||
"""Test that run_integrations uses the new DB client methods correctly."""
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id") as mock_set_run_id:
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock the new DB client methods
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[mock_integration]
|
||||
)
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Disposition: {{mapped_call_disposition}}"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Mock the aiohttp session for Slack webhook
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify the correct DB client methods were called
|
||||
mock_set_run_id.assert_called_once_with(1)
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_called_once_with(
|
||||
100
|
||||
)
|
||||
|
||||
# Verify the Slack webhook was called
|
||||
mock_session.post.assert_called_once()
|
||||
assert (
|
||||
mock_session.post.call_args[0][0] == "https://hooks.slack.com/test"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_workflow_run():
|
||||
"""Test handling when workflow run is not found."""
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run not found
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(None, None)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 999)
|
||||
|
||||
# Verify it returns early and doesn't call other DB methods
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(999)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_organization():
|
||||
"""Test handling when user has no organization."""
|
||||
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 1
|
||||
mock_workflow_run.gathered_context = {"test": "data"}
|
||||
mock_workflow_run.workflow = MagicMock()
|
||||
mock_workflow_run.workflow.user = MagicMock()
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run found but no organization
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, None)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify it returns early after checking organization
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_gathered_context(mock_workflow_run):
|
||||
"""Test handling when workflow run has no gathered context."""
|
||||
|
||||
mock_workflow_run.gathered_context = None
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run with no gathered context
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify it returns early after checking gathered_context
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_stasis_mode(mock_workflow_run):
|
||||
"""Test that stasis mode triggers vendor sync."""
|
||||
|
||||
mock_workflow_run.mode = WorkflowRunMode.STASIS.value
|
||||
mock_workflow_run.initial_context = {
|
||||
"vendor": "test_vendor",
|
||||
"vendor_base_url": "https://api.vendor.com",
|
||||
"vendor_id": "123",
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
with patch("api.tasks.run_integrations._sync_vendor_data") as mock_sync:
|
||||
mock_sync.return_value = None
|
||||
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[]
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify vendor sync was called
|
||||
mock_sync.assert_called_once_with(
|
||||
mock_workflow_run.initial_context,
|
||||
mock_workflow_run.gathered_context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_multiple_integrations(mock_workflow_run):
|
||||
"""Test processing multiple integrations."""
|
||||
|
||||
# Create multiple mock integrations
|
||||
slack_integration = MagicMock()
|
||||
slack_integration.provider = "slack"
|
||||
slack_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test1"}
|
||||
}
|
||||
|
||||
slack_integration2 = MagicMock()
|
||||
slack_integration2.provider = "slack"
|
||||
slack_integration2.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test2"}
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[slack_integration, slack_integration2]
|
||||
)
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={"slack": {"DISPOSITION_CODE": "Test message"}}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify both integrations were processed
|
||||
assert mock_session.post.call_count == 2
|
||||
|
||||
# Check that both webhooks were called
|
||||
call_urls = [call[0][0] for call in mock_session.post.call_args_list]
|
||||
assert "https://hooks.slack.com/test1" in call_urls
|
||||
assert "https://hooks.slack.com/test2" in call_urls
|
||||
|
|
@ -1,330 +0,0 @@
|
|||
"""Tests for webhook execution in run_integrations.py."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.run_integrations import (
|
||||
_build_auth_header,
|
||||
_build_render_context,
|
||||
_execute_webhook_node,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_log:
|
||||
mock_log.bind.return_value = mock_log
|
||||
yield mock_log
|
||||
|
||||
|
||||
class TestBuildAuthHeader:
|
||||
"""Tests for _build_auth_header function."""
|
||||
|
||||
def test_bearer_token(self):
|
||||
"""Test bearer token auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "bearer_token"
|
||||
credential.credential_data = {"token": "my-secret-token"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"Authorization": "Bearer my-secret-token"}
|
||||
|
||||
def test_api_key(self):
|
||||
"""Test API key auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "api_key"
|
||||
credential.credential_data = {"header_name": "X-API-Key", "api_key": "key123"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-API-Key": "key123"}
|
||||
|
||||
def test_api_key_default_header(self):
|
||||
"""Test API key with default header name."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "api_key"
|
||||
credential.credential_data = {"api_key": "key123"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-API-Key": "key123"}
|
||||
|
||||
def test_basic_auth(self):
|
||||
"""Test basic auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "basic_auth"
|
||||
credential.credential_data = {"username": "user", "password": "pass"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
# base64 of "user:pass" is "dXNlcjpwYXNz"
|
||||
assert result == {"Authorization": "Basic dXNlcjpwYXNz"}
|
||||
|
||||
def test_custom_header(self):
|
||||
"""Test custom header auth."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "custom_header"
|
||||
credential.credential_data = {
|
||||
"header_name": "X-Custom-Auth",
|
||||
"header_value": "custom-value",
|
||||
}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-Custom-Auth": "custom-value"}
|
||||
|
||||
def test_unknown_type(self):
|
||||
"""Test unknown credential type returns empty dict."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "unknown"
|
||||
credential.credential_data = {}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestBuildRenderContext:
|
||||
"""Tests for _build_render_context function."""
|
||||
|
||||
def test_basic_context(self):
|
||||
"""Test building render context from workflow run."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 123
|
||||
workflow_run.name = "WR-TEST-001"
|
||||
workflow_run.workflow_id = 456
|
||||
workflow_run.workflow.name = "Test Workflow"
|
||||
workflow_run.initial_context = {"phone_number": "+1234567890"}
|
||||
workflow_run.gathered_context = {
|
||||
"customer_name": "John",
|
||||
"mapped_call_disposition": "QUALIFIED",
|
||||
}
|
||||
workflow_run.usage_info = {"call_duration_seconds": 120}
|
||||
workflow_run.completed_at = None
|
||||
|
||||
result = _build_render_context(workflow_run)
|
||||
|
||||
assert result["workflow_run_id"] == 123
|
||||
assert result["workflow_run_name"] == "WR-TEST-001"
|
||||
assert result["workflow_id"] == 456
|
||||
assert result["workflow_name"] == "Test Workflow"
|
||||
assert result["initial_context"]["phone_number"] == "+1234567890"
|
||||
assert result["gathered_context"]["customer_name"] == "John"
|
||||
assert result["cost_info"]["call_duration_seconds"] == 120
|
||||
assert result["disposition_code"] == "QUALIFIED"
|
||||
|
||||
def test_empty_contexts(self):
|
||||
"""Test with empty/None contexts."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 1
|
||||
workflow_run.name = "Test"
|
||||
workflow_run.workflow_id = 1
|
||||
workflow_run.workflow.name = "Workflow"
|
||||
workflow_run.initial_context = None
|
||||
workflow_run.gathered_context = None
|
||||
workflow_run.usage_info = None
|
||||
workflow_run.completed_at = None
|
||||
|
||||
result = _build_render_context(workflow_run)
|
||||
|
||||
assert result["initial_context"] == {}
|
||||
assert result["gathered_context"] == {}
|
||||
assert result["cost_info"] == {}
|
||||
assert result["disposition_code"] is None
|
||||
|
||||
|
||||
class TestExecuteWebhookNode:
|
||||
"""Tests for _execute_webhook_node function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disabled_webhook_skipped(self):
|
||||
"""Test that disabled webhooks are skipped."""
|
||||
webhook_data = {"name": "Test Webhook", "enabled": False}
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True # Returns True for skipped webhooks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_url_returns_false(self):
|
||||
"""Test that missing endpoint URL returns False."""
|
||||
webhook_data = {"name": "Test Webhook", "enabled": True, "endpoint_url": None}
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_post_request(self):
|
||||
"""Test successful POST webhook execution."""
|
||||
webhook_data = {
|
||||
"name": "CRM Sync",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"payload_template": {
|
||||
"call_id": "{{workflow_run_id}}",
|
||||
"phone": "{{initial_context.phone_number}}",
|
||||
},
|
||||
}
|
||||
|
||||
render_context = {
|
||||
"workflow_run_id": 123,
|
||||
"initial_context": {"phone_number": "+1234567890"},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context=render_context,
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the request was made correctly
|
||||
mock_client_instance.request.assert_called_once()
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["method"] == "POST"
|
||||
assert call_kwargs["url"] == "https://api.example.com/webhook"
|
||||
assert call_kwargs["json"] == {
|
||||
"call_id": "123",
|
||||
"phone": "+1234567890",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_with_credential(self):
|
||||
"""Test webhook execution with credential auth."""
|
||||
webhook_data = {
|
||||
"name": "Authenticated Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"credential_uuid": "cred-123",
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.name = "API Key"
|
||||
mock_credential.credential_type = "bearer_token"
|
||||
mock_credential.credential_data = {"token": "secret-token"}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify auth header was included
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["headers"]["Authorization"] == "Bearer secret-token"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_with_custom_headers(self):
|
||||
"""Test webhook execution with custom headers."""
|
||||
webhook_data = {
|
||||
"name": "Custom Headers Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"custom_headers": [
|
||||
{"key": "X-Source", "value": "dograh"},
|
||||
{"key": "X-Workflow", "value": "test"},
|
||||
],
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify custom headers were included
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["headers"]["X-Source"] == "dograh"
|
||||
assert call_kwargs["headers"]["X-Workflow"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_http_error(self):
|
||||
"""Test webhook execution with HTTP error."""
|
||||
import httpx
|
||||
|
||||
webhook_data = {
|
||||
"name": "Failing Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
mock_response.raise_for_status = MagicMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Server Error",
|
||||
request=MagicMock(),
|
||||
response=mock_response,
|
||||
)
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
"""Tests for the `/s3/signed-url` endpoint.
|
||||
|
||||
This test-suite verifies:
|
||||
1. Regular users can retrieve signed URLs for resources belonging to their own workflow runs.
|
||||
2. Regular users are *forbidden* from accessing resources that belong to other users.
|
||||
3. Superusers can access any resource irrespective of ownership.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
# Ensure the S3 environment variables exist so that the module import does not fail
|
||||
os.environ.setdefault("S3_BUCKET", "test-bucket")
|
||||
os.environ.setdefault("S3_REGION", "us-east-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signed_url_for_own_run(monkeypatch, test_client_factory, db_session):
|
||||
"""A normal user should be able to fetch a signed URL for their own workflow run."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Set-up – create user, workflow & workflow run
|
||||
# ------------------------------------------------------------------
|
||||
user: UserModel = await db_session.get_or_create_user_by_provider_id("user_own_run")
|
||||
workflow = await db_session.create_workflow("wf", {}, user.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", user.id)
|
||||
|
||||
key = f"transcripts/{run.id}.txt"
|
||||
|
||||
# Patch S3 signed-url generator to avoid network calls
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data == {"url": "https://signed-url", "expires_in": 3600}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signed_url_for_other_users_run_forbidden(
|
||||
monkeypatch, test_client_factory, db_session
|
||||
):
|
||||
"""A normal user must *not* access workflow runs owned by someone else."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# Owner of the workflow run
|
||||
owner: UserModel = await db_session.get_or_create_user_by_provider_id("owner_user")
|
||||
workflow = await db_session.create_workflow("wf", {}, owner.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
|
||||
|
||||
# Second user attempting access
|
||||
intruder: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"intruder_user"
|
||||
)
|
||||
|
||||
key = f"recordings/{run.id}.wav"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(intruder) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_access_any_run(
|
||||
monkeypatch, test_client_factory, db_session
|
||||
):
|
||||
"""Superusers should be able to fetch signed URLs for any workflow run."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# Normal user & run owner
|
||||
owner: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"owner_of_run"
|
||||
)
|
||||
workflow = await db_session.create_workflow("wf", {}, owner.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
|
||||
|
||||
# Superuser
|
||||
superuser: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"admin_user"
|
||||
)
|
||||
|
||||
# Promote to superuser
|
||||
# We need to commit the change so that the DB reflects it
|
||||
async with db_session.async_session() as session:
|
||||
db_user = await session.get(UserModel, superuser.id)
|
||||
db_user.is_superuser = True
|
||||
await session.commit()
|
||||
await session.refresh(db_user) # ensure we have the latest state
|
||||
superuser.is_superuser = True
|
||||
|
||||
key = f"transcripts/{run.id}.txt"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(superuser) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["url"] == "https://signed-url"
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.s3_upload import upload_audio_to_s3, upload_transcript_to_s3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_audio_to_s3_success():
|
||||
"""Test successful audio upload to S3."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
|
||||
tf.write(b"fake audio data")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
# Mock dependencies
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
|
||||
patch("api.tasks.s3_upload.db_client", mock_db_client),
|
||||
):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify S3 upload was called
|
||||
mock_s3_fs.aupload_file.assert_called_once_with(
|
||||
temp_path, "recordings/123.wav"
|
||||
)
|
||||
|
||||
# Verify DB update was called
|
||||
mock_db_client.update_workflow_run.assert_called_once_with(
|
||||
run_id=123, recording_url="recordings/123.wav"
|
||||
)
|
||||
|
||||
# Verify temp file was cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_audio_to_s3_file_not_found():
|
||||
"""Test audio upload when temp file doesn't exist."""
|
||||
mock_ctx = AsyncMock()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path="/nonexistent/file.wav"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_transcript_to_s3_success():
|
||||
"""Test successful transcript upload to S3."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
|
||||
tf.write("Test transcript content")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
# Mock dependencies
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
|
||||
patch("api.tasks.s3_upload.db_client", mock_db_client),
|
||||
):
|
||||
await upload_transcript_to_s3(
|
||||
mock_ctx, workflow_run_id=456, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify S3 upload was called
|
||||
mock_s3_fs.aupload_file.assert_called_once_with(
|
||||
temp_path, "transcripts/456.txt"
|
||||
)
|
||||
|
||||
# Verify DB update was called
|
||||
mock_db_client.update_workflow_run.assert_called_once_with(
|
||||
run_id=456, transcript_url="transcripts/456.txt"
|
||||
)
|
||||
|
||||
# Verify temp file was cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_s3_cleanup_on_error():
|
||||
"""Test that temp files are cleaned up even when S3 upload fails."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
|
||||
tf.write(b"fake audio data")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
# Make S3 upload fail
|
||||
mock_s3_fs.aupload_file.side_effect = Exception("S3 upload failed")
|
||||
|
||||
with patch("api.tasks.s3_upload.s3_fs", mock_s3_fs):
|
||||
with pytest.raises(Exception):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify temp file was still cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
def test_render_template_basic():
|
||||
"""Test basic template rendering."""
|
||||
template = "Hello {{name}}, your balance is {{balance}}."
|
||||
context = {"name": "John", "balance": "$1000"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is $1000."
|
||||
|
||||
|
||||
def test_render_template_with_spaces():
|
||||
"""Test template rendering with spaces around variables."""
|
||||
template = "Hello {{ name }}, your balance is {{ balance }}."
|
||||
context = {"name": "John", "balance": "$1000"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is $1000."
|
||||
|
||||
|
||||
def test_render_template_missing_variable():
|
||||
"""Test template rendering with missing variables."""
|
||||
template = "Hello {{name}}, your balance is {{balance}}."
|
||||
context = {"name": "John"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is ."
|
||||
|
||||
|
||||
def test_render_template_with_fallback():
|
||||
"""Test template rendering with fallback values."""
|
||||
template = "Hello {{name | fallback}}, your balance is {{balance | fallback:$0}}."
|
||||
context = {}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello Name, your balance is $0."
|
||||
|
||||
|
||||
def test_render_template_with_fallback_existing_value():
|
||||
"""Test that fallback is not used when value exists."""
|
||||
template = "Hello {{name | fallback:Guest}}"
|
||||
context = {"name": "John"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John"
|
||||
|
||||
|
||||
def test_render_template_with_line_breaks():
|
||||
"""Test template rendering with line breaks."""
|
||||
template = (
|
||||
"DISPOSITION_CODE: {{call_disposition}}\\nCALL_DURATION: {{call_duration}}"
|
||||
)
|
||||
context = {"call_disposition": "XFER", "call_duration": "300"}
|
||||
|
||||
result = render_template(template, context)
|
||||
expected = "DISPOSITION_CODE: XFER\nCALL_DURATION: 300"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_render_template_empty():
|
||||
"""Test rendering empty template."""
|
||||
assert render_template("", {}) == ""
|
||||
assert render_template(None, {}) == None
|
||||
|
||||
|
||||
def test_render_template_no_placeholders():
|
||||
"""Test template with no placeholders."""
|
||||
template = "This is a plain text message"
|
||||
result = render_template(template, {"unused": "value"})
|
||||
assert result == "This is a plain text message"
|
||||
|
||||
|
||||
def test_render_template_none_values():
|
||||
"""Test template with None values."""
|
||||
template = "Value: {{value}}"
|
||||
context = {"value": None}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Value: "
|
||||
|
||||
|
||||
def test_render_template_numeric_values():
|
||||
"""Test template with numeric values."""
|
||||
template = "Count: {{count}}, Price: {{price}}"
|
||||
context = {"count": 42, "price": 19.99}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Count: 42, Price: 19.99"
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script to verify atomic operations in organization_usage_client.py
|
||||
This simulates concurrent access from multiple processes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
# Set up environment
|
||||
os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL", ""))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
|
||||
|
||||
async def reserve_quota_process(org_id: int, tokens: int, process_id: int):
|
||||
"""Simulate a process trying to reserve quota."""
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
results = []
|
||||
for i in range(5):
|
||||
result = await client.check_and_reserve_quota(org_id, tokens)
|
||||
results.append((process_id, i, result))
|
||||
await asyncio.sleep(0.01) # Small delay to increase contention
|
||||
|
||||
await engine.dispose()
|
||||
return results
|
||||
|
||||
|
||||
async def update_usage_process(org_id: int, tokens: int, process_id: int):
|
||||
"""Simulate a process updating usage after runs."""
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
for i in range(5):
|
||||
await client.update_usage_after_run(org_id, tokens, duration_seconds=10)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await engine.dispose()
|
||||
return f"Process {process_id} completed updates"
|
||||
|
||||
|
||||
def run_reserve_quota(args):
|
||||
"""Wrapper to run async function in process."""
|
||||
org_id, tokens, process_id = args
|
||||
return asyncio.run(reserve_quota_process(org_id, tokens, process_id))
|
||||
|
||||
|
||||
def run_update_usage(args):
|
||||
"""Wrapper to run async function in process."""
|
||||
org_id, tokens, process_id = args
|
||||
return asyncio.run(update_usage_process(org_id, tokens, process_id))
|
||||
|
||||
|
||||
async def test_concurrent_quota_reservation():
|
||||
"""Test that concurrent quota reservations are handled atomically."""
|
||||
print("Testing concurrent quota reservations...")
|
||||
|
||||
# Assuming org_id 1 exists with quota enabled
|
||||
org_id = 1
|
||||
tokens_per_request = 100
|
||||
|
||||
# Run multiple processes trying to reserve quota simultaneously
|
||||
with ProcessPoolExecutor(max_workers=3) as executor:
|
||||
futures = []
|
||||
for i in range(3):
|
||||
futures.append(
|
||||
executor.submit(run_reserve_quota, (org_id, tokens_per_request, i))
|
||||
)
|
||||
|
||||
results = []
|
||||
for future in futures:
|
||||
results.extend(future.result())
|
||||
|
||||
print(f"Reservation results: {results}")
|
||||
|
||||
# Check that reservations were handled atomically
|
||||
successful_reservations = sum(1 for _, _, success in results if success)
|
||||
print(f"Successful reservations: {successful_reservations}")
|
||||
|
||||
|
||||
async def test_concurrent_usage_updates():
|
||||
"""Test that concurrent usage updates are handled atomically."""
|
||||
print("\nTesting concurrent usage updates...")
|
||||
|
||||
org_id = 1
|
||||
tokens_per_update = 50
|
||||
|
||||
# Get initial usage
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
initial_usage = await client.get_current_usage(org_id)
|
||||
initial_tokens = initial_usage["used_dograh_tokens"]
|
||||
print(f"Initial tokens: {initial_tokens}")
|
||||
|
||||
# Run multiple processes updating usage simultaneously
|
||||
with ProcessPoolExecutor(max_workers=3) as executor:
|
||||
futures = []
|
||||
for i in range(3):
|
||||
futures.append(
|
||||
executor.submit(run_update_usage, (org_id, tokens_per_update, i))
|
||||
)
|
||||
|
||||
for future in futures:
|
||||
print(future.result())
|
||||
|
||||
# Check final usage
|
||||
final_usage = await client.get_current_usage(org_id)
|
||||
final_tokens = final_usage["used_dograh_tokens"]
|
||||
expected_tokens = initial_tokens + (
|
||||
3 * 5 * tokens_per_update
|
||||
) # 3 processes * 5 updates * 50 tokens
|
||||
|
||||
print(f"Final tokens: {final_tokens}")
|
||||
print(f"Expected tokens: {expected_tokens}")
|
||||
print(f"Difference: {final_tokens - expected_tokens}")
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
if final_tokens == expected_tokens:
|
||||
print("✅ All updates were applied atomically!")
|
||||
else:
|
||||
print("❌ Some updates were lost due to race conditions!")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all concurrency tests."""
|
||||
try:
|
||||
await test_concurrent_quota_reservation()
|
||||
await test_concurrent_usage_updates()
|
||||
except Exception as e:
|
||||
print(f"Error during testing: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting organization usage concurrency tests...")
|
||||
print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL', 'NOT SET')}")
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import ExtractionVariableDTO, VariableType
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
|
||||
|
||||
class DummyLLM:
|
||||
"""A minimal stub that mimics the parts of an LLM service used by the extractor."""
|
||||
|
||||
def __init__(self, streamed_response: str | None = None):
|
||||
# Optionally provide a pre-defined streaming response for _perform_extraction tests
|
||||
self._streamed_response = streamed_response or "{}"
|
||||
self.registered_functions: dict[str, AsyncMock] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API used by VariableExtractionManager
|
||||
# ------------------------------------------------------------------
|
||||
def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 – simple delegate
|
||||
self.registered_functions[name] = func
|
||||
|
||||
async def get_chat_completions(self, _context, _messages):
|
||||
"""Return an async generator that yields a single chunk with the full response."""
|
||||
|
||||
class _Delta: # noqa: D401 – tiny helper classes for stub response
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
class _Choice:
|
||||
def __init__(self, delta):
|
||||
self.delta = delta
|
||||
|
||||
class _Chunk:
|
||||
def __init__(self, content):
|
||||
self.choices = [_Choice(_Delta(content))]
|
||||
|
||||
async def _stream():
|
||||
yield _Chunk(self._streamed_response)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
class DummyEngine:
|
||||
"""A bare-bones Engine stub exposing only what the extractor relies on."""
|
||||
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
self.context = OpenAILLMContext()
|
||||
self._pending_function_calls = 0
|
||||
# VariableExtractionManager currently updates this private attribute
|
||||
self._gathered_context: dict = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_extraction_parses_json_correctly():
|
||||
"""_perform_extraction should return the parsed JSON from the LLM stream."""
|
||||
# Set dummy OpenAI API key to prevent initialization errors
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
expected_payload = {"name": "Alice", "age": 30}
|
||||
llm = DummyLLM(json.dumps(expected_payload))
|
||||
engine = DummyEngine(llm)
|
||||
manager = VariableExtractionManager(engine)
|
||||
|
||||
# Mock the AsyncOpenAI client and its response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = json.dumps(expected_payload)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
):
|
||||
# Minimal set of variables to extract – the prompts themselves are irrelevant here
|
||||
extraction_variables = [
|
||||
ExtractionVariableDTO(
|
||||
name="name", type=VariableType.string, prompt="user name"
|
||||
),
|
||||
ExtractionVariableDTO(
|
||||
name="age", type=VariableType.number, prompt="user age"
|
||||
),
|
||||
]
|
||||
|
||||
result = await manager._perform_extraction(
|
||||
extraction_variables, parent_ctx=None, extraction_prompt=""
|
||||
)
|
||||
|
||||
assert result == expected_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_extraction_with_custom_system_prompt():
|
||||
"""_perform_extraction should use the provided extraction_prompt as system prompt."""
|
||||
# Set dummy OpenAI API key to prevent initialization errors
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
expected_payload = {"color": "blue"}
|
||||
llm = DummyLLM(json.dumps(expected_payload))
|
||||
engine = DummyEngine(llm)
|
||||
manager = VariableExtractionManager(engine)
|
||||
|
||||
# Mock the AsyncOpenAI client and its response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = json.dumps(expected_payload)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
):
|
||||
extraction_variables = [
|
||||
ExtractionVariableDTO(
|
||||
name="color", type=VariableType.string, prompt="favourite color"
|
||||
)
|
||||
]
|
||||
|
||||
# Call with a custom extraction prompt
|
||||
custom_prompt = "You are a color extraction specialist."
|
||||
result = await manager._perform_extraction(
|
||||
extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt
|
||||
)
|
||||
|
||||
assert result == expected_payload
|
||||
|
|
@ -1,547 +0,0 @@
|
|||
"""
|
||||
Test voicemail detection in RTC connection flow.
|
||||
|
||||
This test emulates how a call is connected using SmallWebRTC,
|
||||
triggers voicemail detection, and verifies the disconnect reason.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.routes.rtc_offer import RTCOfferRequest, offer
|
||||
from api.services.workflow.pipecat_engine_voicemail_detector import VoicemailDetector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestVoicemailDetectionRTC:
|
||||
"""Test voicemail detection through RTC connection flow."""
|
||||
|
||||
async def test_voicemail_detection_full_flow(self):
|
||||
"""
|
||||
Test complete voicemail detection flow:
|
||||
1. RTC connection request
|
||||
2. Transport sends on_client_connected event
|
||||
3. Engine initializes with voicemail detection enabled
|
||||
4. Voicemail detector returns true
|
||||
5. Call terminates with voicemail_detected reason
|
||||
6. Transport sends on_client_disconnected event
|
||||
7. Disconnect reason is properly set
|
||||
"""
|
||||
# Mock user and authentication
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.organization_id = 1
|
||||
|
||||
# Mock workflow with voicemail detection enabled
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.id = 100
|
||||
mock_workflow.workflow_definition_with_fallback = {
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"data": {
|
||||
"detect_voicemail": True,
|
||||
"system_prompt": "You are a helpful assistant",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Mock workflow run
|
||||
mock_workflow_run = Mock()
|
||||
mock_workflow_run.id = 200
|
||||
mock_workflow_run.is_completed = False
|
||||
|
||||
# Create request
|
||||
request = RTCOfferRequest(
|
||||
pc_id="test_pc_123",
|
||||
sdp="test_sdp_offer",
|
||||
type="offer",
|
||||
workflow_id=mock_workflow.id,
|
||||
workflow_run_id=mock_workflow_run.id,
|
||||
restart_pc=False,
|
||||
call_context_vars={"test_var": "test_value"},
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with (
|
||||
patch("api.services.auth.depends.get_user") as mock_get_user_dep,
|
||||
patch("api.routes.rtc_offer.SmallWebRTCConnection") as MockWebRTCConnection,
|
||||
patch("api.routes.rtc_offer.run_pipeline_smallwebrtc") as mock_run_pipeline,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_get_user_dep.return_value = mock_user
|
||||
|
||||
# Mock WebRTC connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.pc_id = "test_pc_123"
|
||||
mock_connection.initialize = AsyncMock()
|
||||
mock_connection.get_answer = Mock(
|
||||
return_value={
|
||||
"pc_id": "test_pc_123",
|
||||
"sdp": "test_sdp_answer",
|
||||
"type": "answer",
|
||||
}
|
||||
)
|
||||
MockWebRTCConnection.return_value = mock_connection
|
||||
|
||||
# Track registered event handlers
|
||||
registered_handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
registered_handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Mock BackgroundTasks
|
||||
mock_background_tasks = Mock()
|
||||
|
||||
# Create the offer
|
||||
response = await offer(request, mock_background_tasks, mock_user)
|
||||
|
||||
# Verify response
|
||||
assert response["pc_id"] == "test_pc_123"
|
||||
assert response["type"] == "answer"
|
||||
|
||||
# Verify connection was initialized
|
||||
mock_connection.initialize.assert_called_once_with(
|
||||
sdp="test_sdp_offer", type="offer"
|
||||
)
|
||||
|
||||
# Verify background task was added
|
||||
mock_background_tasks.add_task.assert_called_once()
|
||||
task_args = mock_background_tasks.add_task.call_args[0]
|
||||
assert task_args[0] == mock_run_pipeline
|
||||
assert task_args[1] == mock_connection
|
||||
assert task_args[2] == mock_workflow.id
|
||||
assert task_args[3] == mock_workflow_run.id
|
||||
assert task_args[4] == mock_user.id
|
||||
assert task_args[5] == {"test_var": "test_value"}
|
||||
|
||||
async def test_voicemail_detection_in_pipeline(self):
|
||||
"""Tests whether the updates happen in on_client_disconnected properly
|
||||
with values set in the engine"""
|
||||
# Mock components
|
||||
mock_transport = AsyncMock()
|
||||
mock_engine = Mock() # Use Mock instead of AsyncMock for engine
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
mock_audio_buffer = AsyncMock()
|
||||
mock_task = AsyncMock()
|
||||
mock_aggregator = Mock()
|
||||
|
||||
# Setup engine with voicemail detector
|
||||
mock_voicemail_detector = AsyncMock(spec=VoicemailDetector)
|
||||
mock_engine.voicemail_detector = mock_voicemail_detector
|
||||
mock_engine.get_call_disposition = Mock(
|
||||
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
|
||||
)
|
||||
mock_engine.get_gathered_context = Mock(
|
||||
return_value={
|
||||
"voicemail_transcript": "Hi, you've reached John's voicemail. Please leave a message.",
|
||||
"voicemail_confidence": 0.95,
|
||||
}
|
||||
)
|
||||
|
||||
# Mock usage metrics
|
||||
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Register event handlers
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
# Track registered handlers
|
||||
handlers = {}
|
||||
|
||||
def track_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport.event_handler = track_handler
|
||||
|
||||
# Create a mock db_client module with update_workflow_run method
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_enqueue_job,
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
side_effect=lambda value, org_id: value, # Return value unchanged
|
||||
),
|
||||
):
|
||||
# Register handlers
|
||||
register_transport_event_handlers(
|
||||
mock_transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=mock_audio_buffer,
|
||||
task=mock_task,
|
||||
engine=mock_engine,
|
||||
usage_metrics_aggregator=mock_aggregator,
|
||||
)
|
||||
|
||||
# Verify handlers were registered
|
||||
assert "on_client_connected" in handlers
|
||||
assert "on_client_disconnected" in handlers
|
||||
|
||||
# Simulate client connection
|
||||
await handlers["on_client_connected"](
|
||||
mock_transport, {"id": "participant_1"}
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
mock_audio_buffer.start_recording.assert_called_once()
|
||||
mock_engine.initialize.assert_called_once()
|
||||
|
||||
# Simulate voicemail detection and disconnect
|
||||
await handlers["on_client_disconnected"](
|
||||
mock_transport, {"id": "participant_1"}, None
|
||||
)
|
||||
|
||||
# Verify engine cleanup
|
||||
mock_engine.cleanup.assert_called_once()
|
||||
|
||||
# TODO: check whether task was cancelled or not once have more
|
||||
# clarity on how to handle engine disconnect vs remote hangup
|
||||
# Verify task was NOT cancelled (engine disconnect)
|
||||
# mock_task.cancel.assert_not_called()
|
||||
|
||||
# Verify workflow run was updated with voicemail context
|
||||
mock_db_client.update_workflow_run.assert_called()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert call_args[1]["run_id"] == 123
|
||||
# Check that the mapped_call_disposition was set correctly
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "voicemail_detected"
|
||||
)
|
||||
|
||||
async def test_voicemail_detector_audio_processing(self):
|
||||
"""Test VoicemailDetector audio processing and detection logic - tests that voicemail detector
|
||||
calls engine's send_end_task_frame with the correct reason and metadata"""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=123)
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_openai = AsyncMock()
|
||||
mock_whisper_response = Mock()
|
||||
mock_whisper_response.text = "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep."
|
||||
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
|
||||
|
||||
mock_gpt_response = Mock()
|
||||
mock_gpt_response.choices = [Mock()]
|
||||
mock_gpt_response.choices[0].message.content = json.dumps(
|
||||
{
|
||||
"is_voicemail": True,
|
||||
"confidence": 0.98,
|
||||
"reasoning": "Clear voicemail greeting with request to leave message",
|
||||
}
|
||||
)
|
||||
mock_openai.chat.completions.create.return_value = mock_gpt_response
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
mock_engine.task = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
|
||||
return_value=mock_openai,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.s3_fs"
|
||||
) as mock_s3,
|
||||
):
|
||||
# Mock S3 upload to return None (simulating successful upload)
|
||||
mock_s3.aupload_file = AsyncMock(return_value=True)
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
assert detector.is_detecting == True
|
||||
|
||||
# Simulate audio data (16kHz, mono, 5 seconds)
|
||||
sample_rate = 16000
|
||||
duration = 5.0
|
||||
audio_data = b"\x00\x00" * int(sample_rate * duration) # Silent audio
|
||||
|
||||
# Process audio in chunks
|
||||
chunk_size = 1600 # 100ms chunks
|
||||
for i in range(0, len(audio_data), chunk_size):
|
||||
chunk = audio_data[i : i + chunk_size]
|
||||
await detector.handle_audio_data(None, chunk, sample_rate, 1)
|
||||
|
||||
# Wait for detection to complete
|
||||
if detector._detection_task:
|
||||
await detector._detection_task
|
||||
|
||||
# Verify OpenAI calls
|
||||
mock_openai.audio.transcriptions.create.assert_called_once()
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
|
||||
# Verify send_end_task_frame was called with voicemail detection
|
||||
mock_engine.send_end_task_frame.assert_called_once_with(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
additional_metadata={
|
||||
"voicemail_transcript": "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep.",
|
||||
"voicemail_confidence": 0.98,
|
||||
"voicemail_reasoning": "Clear voicemail greeting with request to leave message",
|
||||
"voicemail_detection_duration": 5.0,
|
||||
"voicemail_audio_s3_path": "voicemail_detections/123_voicemail_98_5.wav", # S3 upload returns True, so filename is used
|
||||
},
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
async def test_voicemail_detector_no_detection(self):
|
||||
"""Test VoicemailDetector when voicemail is not detected."""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=124)
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_openai = AsyncMock()
|
||||
mock_whisper_response = Mock()
|
||||
mock_whisper_response.text = "Hello? Hello? Can you hear me?"
|
||||
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
|
||||
|
||||
mock_gpt_response = Mock()
|
||||
mock_gpt_response.choices = [Mock()]
|
||||
mock_gpt_response.choices[0].message.content = json.dumps(
|
||||
{
|
||||
"is_voicemail": False,
|
||||
"confidence": 0.95,
|
||||
"reasoning": "Live person speaking, asking if caller can hear them",
|
||||
}
|
||||
)
|
||||
mock_openai.chat.completions.create.return_value = mock_gpt_response
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
mock_engine.task = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
|
||||
return_value=mock_openai,
|
||||
):
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
|
||||
# Simulate audio data
|
||||
sample_rate = 16000
|
||||
duration = 5.0
|
||||
audio_data = b"\x00\x00" * int(sample_rate * duration)
|
||||
|
||||
# Process audio
|
||||
await detector.handle_audio_data(None, audio_data, sample_rate, 1)
|
||||
|
||||
# Wait for detection
|
||||
if detector._detection_task:
|
||||
await detector._detection_task
|
||||
|
||||
# Verify send_end_task_frame was NOT called
|
||||
mock_engine.send_end_task_frame.assert_not_called()
|
||||
|
||||
async def test_voicemail_detector_cancellation(self):
|
||||
"""Test VoicemailDetector cancellation before completion."""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=10.0, workflow_run_id=125)
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
assert detector.is_detecting == True
|
||||
|
||||
# Cancel detection immediately
|
||||
await detector.stop_detection()
|
||||
assert detector._is_cancelled == True
|
||||
|
||||
# Try to add audio data after cancellation
|
||||
await detector.handle_audio_data(None, b"\x00\x00" * 1000, 16000, 1)
|
||||
|
||||
# Verify buffer didn't grow (no audio accepted after cancellation)
|
||||
assert len(detector.audio_buffer) == 0
|
||||
|
||||
async def test_disconnect_reason_propagation(self):
|
||||
"""Test that voicemail disconnect reason is properly propagated."""
|
||||
# Create disconnect reason info directly
|
||||
disconnect_info = {
|
||||
"disposition_code": EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
"details": "Voicemail detected after 5 seconds of audio",
|
||||
"is_remote": False,
|
||||
"is_user_initiated": False,
|
||||
"is_successful_transfer": False,
|
||||
"transport_metadata": {
|
||||
"voicemail_confidence": 0.97,
|
||||
"voicemail_transcript": "You've reached voicemail...",
|
||||
},
|
||||
}
|
||||
|
||||
# Verify attributes
|
||||
assert disconnect_info["disposition_code"] == "voicemail_detected"
|
||||
assert disconnect_info["is_remote"] == False
|
||||
assert disconnect_info["is_user_initiated"] == False
|
||||
assert disconnect_info["is_successful_transfer"] == False
|
||||
assert (
|
||||
disconnect_info["details"] == "Voicemail detected after 5 seconds of audio"
|
||||
)
|
||||
assert disconnect_info["transport_metadata"]["voicemail_confidence"] == 0.97
|
||||
|
||||
async def test_voicemail_detection_end_to_end(self):
|
||||
"""
|
||||
Complete end-to-end test covering:
|
||||
1. on_client_connected event
|
||||
2. Engine initialization with voicemail detection
|
||||
3. Audio processing and voicemail detection
|
||||
4. Engine setting disconnect reason
|
||||
5. on_client_disconnected event
|
||||
6. Proper disconnect reason in workflow run update
|
||||
"""
|
||||
# Create comprehensive mocks
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
# Mock transport
|
||||
mock_transport = AsyncMock()
|
||||
handlers = {}
|
||||
|
||||
def track_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport.event_handler = track_handler
|
||||
|
||||
# Mock audio buffer
|
||||
mock_audio_buffer = Mock()
|
||||
mock_audio_buffer.start_recording = AsyncMock()
|
||||
mock_audio_buffer.stop_recording = AsyncMock()
|
||||
|
||||
# Mock task
|
||||
mock_task = AsyncMock()
|
||||
|
||||
# Mock aggregator
|
||||
mock_aggregator = Mock()
|
||||
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Create a mock engine with voicemail detection
|
||||
mock_engine = Mock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
|
||||
# Mock voicemail detector
|
||||
mock_voicemail_detector = Mock()
|
||||
mock_engine.voicemail_detector = mock_voicemail_detector
|
||||
mock_engine._voicemail_detector = mock_voicemail_detector
|
||||
|
||||
# Initially no disconnect reason
|
||||
mock_engine.get_call_disposition = Mock(return_value=None)
|
||||
mock_engine.get_gathered_context = Mock(return_value={})
|
||||
|
||||
# Mock db_client
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_enqueue_job,
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
side_effect=lambda value, org_id: value, # Return value unchanged
|
||||
),
|
||||
):
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
mock_transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=mock_audio_buffer,
|
||||
task=mock_task,
|
||||
engine=mock_engine,
|
||||
usage_metrics_aggregator=mock_aggregator,
|
||||
)
|
||||
|
||||
# Verify handlers were registered
|
||||
assert "on_client_connected" in handlers
|
||||
assert "on_client_disconnected" in handlers
|
||||
|
||||
# Step 1: Client connects
|
||||
await handlers["on_client_connected"](
|
||||
mock_transport, {"id": "participant_1"}
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
mock_audio_buffer.start_recording.assert_called_once()
|
||||
mock_engine.initialize.assert_called_once()
|
||||
|
||||
# Step 2-3: Simulate voicemail detection occurs
|
||||
# Update engine state to reflect voicemail was detected
|
||||
mock_engine.get_call_disposition = Mock(
|
||||
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
|
||||
)
|
||||
mock_engine.get_gathered_context = Mock(
|
||||
return_value={
|
||||
"voicemail_transcript": "You've reached voicemail, leave a message",
|
||||
"voicemail_confidence": 0.95,
|
||||
}
|
||||
)
|
||||
|
||||
# Step 5: Client disconnects
|
||||
await handlers["on_client_disconnected"](
|
||||
mock_transport, {"id": "participant_1"}, None
|
||||
)
|
||||
|
||||
# Verify engine cleanup
|
||||
mock_engine.cleanup.assert_called_once()
|
||||
|
||||
# Step 6: Verify proper disconnect reason in workflow run update
|
||||
mock_db_client.update_workflow_run.assert_called()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
|
||||
# Check the gathered context includes disconnect reason
|
||||
gathered_context = call_args[1]["gathered_context"]
|
||||
assert gathered_context["mapped_call_disposition"] == "voicemail_detected"
|
||||
assert gathered_context["voicemail_confidence"] == 0.95
|
||||
assert (
|
||||
gathered_context["voicemail_transcript"]
|
||||
== "You've reached voicemail, leave a message"
|
||||
)
|
||||
|
||||
# Verify task was NOT cancelled (engine-initiated disconnect)
|
||||
mock_task.cancel.assert_not_called()
|
||||
|
||||
# Verify audio buffer was stopped
|
||||
mock_audio_buffer.stop_recording.assert_called_once()
|
||||
|
||||
# Verify background jobs were enqueued
|
||||
assert (
|
||||
mock_enqueue_job.call_count >= 3
|
||||
) # At least 3 jobs should be enqueued
|
||||
|
|
@ -1,667 +0,0 @@
|
|||
"""
|
||||
Tests for workflow API routes.
|
||||
|
||||
This module tests the create, update, get, and validate workflow endpoints.
|
||||
The fixtures for database setup, test client, and utilities are in conftest.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_definition():
|
||||
"""Sample workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {"x": 427, "y": 23},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": True,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 305, "y": 340},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent.",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 90, "y": 650},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 520, "y": 650},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {"x": 305, "y": 960},
|
||||
"data": {
|
||||
"prompt": "Thank you!",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"is_static": True,
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
"viewport": {"x": 0, "y": 0, "zoom": 1},
|
||||
}
|
||||
|
||||
|
||||
class TestCreateWorkflow:
|
||||
"""Test cases for creating workflows."""
|
||||
|
||||
async def test_create_workflow_success(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test successful workflow creation."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_create_success"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "Test Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "id" in data
|
||||
assert data["name"] == "Test Workflow"
|
||||
assert data["workflow_definition"] == sample_workflow_definition
|
||||
assert "created_at" in data
|
||||
assert "current_definition_id" in data
|
||||
|
||||
async def test_create_workflow_invalid_definition(
|
||||
self, test_client_factory, db_session
|
||||
):
|
||||
"""Test workflow creation with invalid definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_invalid_def"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "Invalid Workflow",
|
||||
"workflow_definition": {"invalid": "structure"},
|
||||
}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
# The API should still create the workflow even with invalid definition
|
||||
# Validation happens in the validate endpoint
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_missing_name(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test workflow creation without name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_missing_name"
|
||||
)
|
||||
|
||||
request_data = {"workflow_definition": sample_workflow_definition}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_missing_definition(
|
||||
self, test_client_factory, db_session
|
||||
):
|
||||
"""Test workflow creation without workflow definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_missing_definition"
|
||||
)
|
||||
|
||||
request_data = {"name": "Test Workflow"}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestGetWorkflows:
|
||||
"""Test cases for fetching workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_workflows_empty(self, test_client_factory, db_session):
|
||||
"""Test getting all workflows when none exist."""
|
||||
# Create a test user within the test function
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_empty_workflows"
|
||||
)
|
||||
|
||||
# Create a test client for this specific user
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.get("/api/v1/workflow/fetch")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_workflows_with_data(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test getting all workflows when some exist."""
|
||||
# Create a test user within the test function
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_with_workflows"
|
||||
)
|
||||
|
||||
# Create a test client for this specific user
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Test Workflow 1",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Create another workflow
|
||||
create_response2 = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Test Workflow 2",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response2.status_code == status.HTTP_200_OK
|
||||
|
||||
# Get all workflows
|
||||
response = await client.get("/api/v1/workflow/fetch")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
# Check that both workflows are returned
|
||||
workflow_names = [w["name"] for w in data]
|
||||
assert "Test Workflow 1" in workflow_names
|
||||
assert "Test Workflow 2" in workflow_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_specific_workflow(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test getting a specific workflow by ID."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_specific_workflow"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Specific Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
created_workflow = create_response.json()
|
||||
workflow_id = created_workflow["id"]
|
||||
|
||||
# Get the specific workflow
|
||||
response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Specific Workflow"
|
||||
assert data["workflow_definition"] == sample_workflow_definition
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test getting a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_nonexistent"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.get("/api/v1/workflow/fetch?workflow_id=99999")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestUpdateWorkflow:
|
||||
"""Test cases for updating workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_name_only(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating only the workflow name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_name"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Update the workflow name
|
||||
update_data = {"name": "Updated Name"}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Updated Name"
|
||||
assert (
|
||||
data["workflow_definition"] == sample_workflow_definition
|
||||
) # Should remain unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_name_and_definition(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating both workflow name and definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_both"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Create new workflow definition
|
||||
new_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"position": {"x": 50, "y": 50},
|
||||
"data": {"label": "New Start"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
# Update the workflow
|
||||
update_data = {
|
||||
"name": "Updated Name",
|
||||
"workflow_definition": new_definition,
|
||||
}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["workflow_definition"] == new_definition
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test updating a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_nonexistent"
|
||||
)
|
||||
|
||||
update_data = {"name": "Updated Name"}
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.put("/api/v1/workflow/99999", json=update_data)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_missing_name(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating a workflow without providing a name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_missing_name"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Try to update without providing name
|
||||
update_data = {"workflow_definition": sample_workflow_definition}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestWorkflowValidation:
|
||||
"""Test cases for workflow validation endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_workflow_success(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test successful workflow validation."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_validate_success"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Valid Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Validate the workflow
|
||||
response = await client.post(f"/api/v1/workflow/{workflow_id}/validate")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["is_valid"] is True
|
||||
assert data["errors"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test validating a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_validate_nonexistent"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/99999/validate")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestWorkflowIntegration:
|
||||
"""Integration tests for workflow operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_workflow_lifecycle(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test the complete lifecycle of a workflow: create, get, update, validate."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_lifecycle"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# 1. Create workflow
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Lifecycle Test Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# 2. Get the created workflow
|
||||
get_response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
assert get_response.status_code == status.HTTP_200_OK
|
||||
workflow_data = get_response.json()
|
||||
assert workflow_data["name"] == "Lifecycle Test Workflow"
|
||||
|
||||
# 3. Add a new node in the workflow definition
|
||||
new_node = {
|
||||
"id": "6919_new",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 520, "y": 650},
|
||||
"data": {
|
||||
"prompt": "Something new",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
}
|
||||
new_edges = [
|
||||
{
|
||||
"source": "6919",
|
||||
"target": "6919_new",
|
||||
"id": "xy-edge__6919-6919_new",
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"source": "6919_new",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919_new-1802",
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
new_definition = {
|
||||
"nodes": [
|
||||
*sample_workflow_definition["nodes"],
|
||||
new_node,
|
||||
],
|
||||
"edges": [
|
||||
*sample_workflow_definition["edges"],
|
||||
*new_edges,
|
||||
],
|
||||
}
|
||||
|
||||
update_response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}",
|
||||
json={
|
||||
"name": "Updated Lifecycle Workflow",
|
||||
"workflow_definition": new_definition,
|
||||
},
|
||||
)
|
||||
assert update_response.status_code == status.HTTP_200_OK
|
||||
assert update_response.json()["name"] == "Updated Lifecycle Workflow"
|
||||
|
||||
# 4. Validate the updated workflow
|
||||
validate_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow_id}/validate"
|
||||
)
|
||||
assert validate_response.status_code == status.HTTP_200_OK
|
||||
assert validate_response.json()["is_valid"] is True
|
||||
|
||||
# 5. Verify the update by getting the workflow again
|
||||
final_get_response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
assert final_get_response.status_code == status.HTTP_200_OK
|
||||
final_data = final_get_response.json()
|
||||
assert final_data["name"] == "Updated Lifecycle Workflow"
|
||||
assert final_data["workflow_definition"] == new_definition
|
||||
95
api/utils/credential_auth.py
Normal file
95
api/utils/credential_auth.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Build HTTP authentication headers from ExternalCredentialModel.
|
||||
|
||||
This module provides functions for constructing HTTP authentication headers
|
||||
from ExternalCredentialModel instances. Used by both webhook integrations
|
||||
and custom tool execution.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.db.models import ExternalCredentialModel
|
||||
|
||||
|
||||
def build_auth_header(credential: "ExternalCredentialModel") -> Dict[str, str]:
|
||||
"""Build authentication header based on credential type.
|
||||
|
||||
Supports the following credential types:
|
||||
- bearer_token: Authorization: Bearer <token>
|
||||
- api_key: Custom header with API key
|
||||
- basic_auth: Authorization: Basic <base64(username:password)>
|
||||
- custom_header: Any custom header name/value pair
|
||||
|
||||
Args:
|
||||
credential: The ExternalCredentialModel instance
|
||||
|
||||
Returns:
|
||||
Dict with header name and value, or empty dict if credential type
|
||||
is not recognized or is 'none'
|
||||
"""
|
||||
cred_type = credential.credential_type
|
||||
cred_data = credential.credential_data or {}
|
||||
|
||||
if cred_type == "bearer_token":
|
||||
token = cred_data.get("token", "")
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
elif cred_type == "api_key":
|
||||
header_name = cred_data.get("header_name", "X-API-Key")
|
||||
api_key = cred_data.get("api_key", "")
|
||||
return {header_name: api_key}
|
||||
|
||||
elif cred_type == "basic_auth":
|
||||
username = cred_data.get("username", "")
|
||||
password = cred_data.get("password", "")
|
||||
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
return {"Authorization": f"Basic {encoded}"}
|
||||
|
||||
elif cred_type == "custom_header":
|
||||
header_name = cred_data.get("header_name", "X-Custom")
|
||||
header_value = cred_data.get("header_value", "")
|
||||
return {header_name: header_value}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def build_auth_header_from_data(
|
||||
credential_type: str,
|
||||
credential_data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Build authentication header from raw credential data.
|
||||
|
||||
This is a convenience function when you have credential data
|
||||
directly rather than a full ExternalCredentialModel.
|
||||
|
||||
Args:
|
||||
credential_type: Type of credential (bearer_token, api_key, etc.)
|
||||
credential_data: Dict containing credential-specific fields
|
||||
|
||||
Returns:
|
||||
Dict with header name and value
|
||||
"""
|
||||
cred_data = credential_data or {}
|
||||
|
||||
if credential_type == "bearer_token":
|
||||
token = cred_data.get("token", "")
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
elif credential_type == "api_key":
|
||||
header_name = cred_data.get("header_name", "X-API-Key")
|
||||
api_key = cred_data.get("api_key", "")
|
||||
return {header_name: api_key}
|
||||
|
||||
elif credential_type == "basic_auth":
|
||||
username = cred_data.get("username", "")
|
||||
password = cred_data.get("password", "")
|
||||
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
return {"Authorization": f"Basic {encoded}"}
|
||||
|
||||
elif credential_type == "custom_header":
|
||||
header_name = cred_data.get("header_name", "X-Custom")
|
||||
header_value = cred_data.get("header_value", "")
|
||||
return {header_name: header_value}
|
||||
|
||||
return {}
|
||||
498
ui/src/app/tools/[toolUuid]/page.tsx
Normal file
498
ui/src/app/tools/[toolUuid]/page.tsx
Normal file
|
|
@ -0,0 +1,498 @@
|
|||
"use client";
|
||||
|
||||
import { ArrowLeft, Code, Globe, Loader2, Save } from "lucide-react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import {
|
||||
getToolApiV1ToolsToolUuidGet,
|
||||
updateToolApiV1ToolsToolUuidPut,
|
||||
} from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
|
||||
// Extended HttpApiConfig with parameters (until client types are regenerated)
|
||||
interface HttpApiConfigWithParams {
|
||||
method?: string;
|
||||
url?: string;
|
||||
headers?: Record<string, string>;
|
||||
credential_uuid?: string;
|
||||
parameters?: ToolParameter[];
|
||||
timeout_ms?: number;
|
||||
}
|
||||
import {
|
||||
CredentialSelector,
|
||||
type HttpMethod,
|
||||
HttpMethodSelector,
|
||||
KeyValueEditor,
|
||||
type KeyValueItem,
|
||||
ParameterEditor,
|
||||
type ToolParameter,
|
||||
} from "@/components/http";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
export default function ToolDetailPage() {
|
||||
const { toolUuid } = useParams<{ toolUuid: string }>();
|
||||
const { user, getAccessToken, redirectToLogin, loading } = useAuth();
|
||||
const router = useRouter();
|
||||
|
||||
const [tool, setTool] = useState<ToolResponse | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [saveSuccess, setSaveSuccess] = useState(false);
|
||||
const [showCodeDialog, setShowCodeDialog] = useState(false);
|
||||
|
||||
// Form state
|
||||
const [name, setName] = useState("");
|
||||
const [description, setDescription] = useState("");
|
||||
const [httpMethod, setHttpMethod] = useState<HttpMethod>("POST");
|
||||
const [url, setUrl] = useState("");
|
||||
const [credentialUuid, setCredentialUuid] = useState("");
|
||||
const [headers, setHeaders] = useState<KeyValueItem[]>([]);
|
||||
const [parameters, setParameters] = useState<ToolParameter[]>([]);
|
||||
const [timeoutMs, setTimeoutMs] = useState(5000);
|
||||
|
||||
// Redirect if not authenticated
|
||||
useEffect(() => {
|
||||
if (!loading && !user) {
|
||||
redirectToLogin();
|
||||
}
|
||||
}, [loading, user, redirectToLogin]);
|
||||
|
||||
const fetchTool = useCallback(async () => {
|
||||
if (loading || !user || !toolUuid) return;
|
||||
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
const response = await getToolApiV1ToolsToolUuidGet({
|
||||
path: { tool_uuid: toolUuid },
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
setTool(response.data);
|
||||
populateFormFromTool(response.data);
|
||||
}
|
||||
} catch (err) {
|
||||
setError("Failed to fetch tool");
|
||||
console.error("Error fetching tool:", err);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [loading, user, toolUuid, getAccessToken]);
|
||||
|
||||
const populateFormFromTool = (tool: ToolResponse) => {
|
||||
setName(tool.name);
|
||||
setDescription(tool.description || "");
|
||||
|
||||
const config = tool.definition?.config as HttpApiConfigWithParams | undefined;
|
||||
if (config) {
|
||||
setHttpMethod((config.method as HttpMethod) || "POST");
|
||||
setUrl(config.url || "");
|
||||
setCredentialUuid(config.credential_uuid || "");
|
||||
setTimeoutMs(config.timeout_ms || 5000);
|
||||
|
||||
// Convert headers object to array
|
||||
if (config.headers) {
|
||||
setHeaders(
|
||||
Object.entries(config.headers).map(([key, value]) => ({
|
||||
key,
|
||||
value: value as string,
|
||||
}))
|
||||
);
|
||||
} else {
|
||||
setHeaders([]);
|
||||
}
|
||||
|
||||
// Load parameters
|
||||
if (config.parameters && Array.isArray(config.parameters)) {
|
||||
setParameters(
|
||||
config.parameters.map((p: ToolParameter) => ({
|
||||
name: p.name || "",
|
||||
type: p.type || "string",
|
||||
description: p.description || "",
|
||||
required: p.required ?? true,
|
||||
}))
|
||||
);
|
||||
} else {
|
||||
setParameters([]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchTool();
|
||||
}, [fetchTool]);
|
||||
|
||||
const handleSave = async () => {
|
||||
// Validate URL
|
||||
if (!url.trim()) {
|
||||
setError("URL is required");
|
||||
return;
|
||||
}
|
||||
|
||||
// Validate parameters have names
|
||||
const invalidParams = parameters.filter((p) => !p.name.trim());
|
||||
if (invalidParams.length > 0) {
|
||||
setError("All parameters must have a name");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setIsSaving(true);
|
||||
setError(null);
|
||||
setSaveSuccess(false);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
// Convert headers array to object
|
||||
const headersObject: Record<string, string> = {};
|
||||
headers.filter((h) => h.key && h.value).forEach((h) => {
|
||||
headersObject[h.key] = h.value;
|
||||
});
|
||||
|
||||
// Filter out empty parameters
|
||||
const validParameters = parameters.filter((p) => p.name.trim());
|
||||
|
||||
// Build the request body (cast needed until client types are regenerated)
|
||||
const requestBody = {
|
||||
name,
|
||||
description: description || undefined,
|
||||
definition: {
|
||||
schema_version: 1,
|
||||
type: "http_api",
|
||||
config: {
|
||||
method: httpMethod,
|
||||
url,
|
||||
credential_uuid: credentialUuid || undefined,
|
||||
headers:
|
||||
Object.keys(headersObject).length > 0
|
||||
? headersObject
|
||||
: undefined,
|
||||
parameters:
|
||||
validParameters.length > 0 ? validParameters : undefined,
|
||||
timeout_ms: timeoutMs,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const response = await updateToolApiV1ToolsToolUuidPut({
|
||||
path: { tool_uuid: toolUuid },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
body: requestBody as any,
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
setTool(response.data);
|
||||
setSaveSuccess(true);
|
||||
setTimeout(() => setSaveSuccess(false), 3000);
|
||||
}
|
||||
} catch (err) {
|
||||
setError("Failed to save tool");
|
||||
console.error("Error saving tool:", err);
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const getCodeSnippet = () => {
|
||||
if (!tool) return "";
|
||||
|
||||
const headersObj: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
headers.filter((h) => h.key && h.value).forEach((h) => {
|
||||
headersObj[h.key] = h.value;
|
||||
});
|
||||
|
||||
// Build example body from parameters
|
||||
const exampleBody: Record<string, unknown> = {};
|
||||
parameters.forEach((p) => {
|
||||
if (p.type === "number") {
|
||||
exampleBody[p.name] = 0;
|
||||
} else if (p.type === "boolean") {
|
||||
exampleBody[p.name] = true;
|
||||
} else {
|
||||
exampleBody[p.name] = `<${p.name}>`;
|
||||
}
|
||||
});
|
||||
|
||||
const hasBody = httpMethod !== "GET" && httpMethod !== "DELETE" && parameters.length > 0;
|
||||
|
||||
return `// ${tool.name}
|
||||
// ${tool.description || "HTTP API Tool"}
|
||||
|
||||
const response = await fetch("${url}", {
|
||||
method: "${httpMethod}",
|
||||
headers: ${JSON.stringify(headersObj, null, 4)},${hasBody ? `
|
||||
body: JSON.stringify(${JSON.stringify(exampleBody, null, 4)}),` : ""}
|
||||
});
|
||||
|
||||
const data = await response.json();`;
|
||||
};
|
||||
|
||||
if (loading || !user) {
|
||||
return (
|
||||
<div className="min-h-screen bg-background flex items-center justify-center">
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-12 w-64" />
|
||||
<Skeleton className="h-64 w-96" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="max-w-4xl mx-auto space-y-6">
|
||||
<Skeleton className="h-8 w-48" />
|
||||
<Skeleton className="h-64 w-full" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!tool) {
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="max-w-4xl mx-auto text-center">
|
||||
<h1 className="text-2xl font-bold mb-4">Tool not found</h1>
|
||||
<Button onClick={() => router.push("/tools")}>
|
||||
<ArrowLeft className="w-4 h-4 mr-2" />
|
||||
Back to Tools
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between mb-6">
|
||||
<div className="flex items-center gap-4">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => router.push("/tools")}
|
||||
>
|
||||
<ArrowLeft className="w-4 h-4 mr-2" />
|
||||
Back
|
||||
</Button>
|
||||
<div className="flex items-center gap-3">
|
||||
<div
|
||||
className="w-10 h-10 rounded-lg flex items-center justify-center"
|
||||
style={{
|
||||
backgroundColor: tool.icon_color || "#3B82F6",
|
||||
}}
|
||||
>
|
||||
<Globe className="w-5 h-5 text-white" />
|
||||
</div>
|
||||
<div>
|
||||
<h1 className="text-xl font-bold">{name}</h1>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
HTTP API Tool
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setShowCodeDialog(true)}
|
||||
>
|
||||
<Code className="w-4 h-4 mr-2" />
|
||||
View Code
|
||||
</Button>
|
||||
<Button onClick={handleSave} disabled={isSaving}>
|
||||
{isSaving ? (
|
||||
<>
|
||||
<Loader2 className="w-4 h-4 mr-2 animate-spin" />
|
||||
Saving...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Save className="w-4 h-4 mr-2" />
|
||||
Save
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mb-4 p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{saveSuccess && (
|
||||
<div className="mb-4 p-4 bg-green-500/10 border border-green-500/20 rounded-lg text-green-600">
|
||||
Tool saved successfully!
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Tool Configuration</CardTitle>
|
||||
<CardDescription>
|
||||
Configure the HTTP API endpoint and request settings
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<Tabs defaultValue="settings" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-3">
|
||||
<TabsTrigger value="settings">Settings</TabsTrigger>
|
||||
<TabsTrigger value="auth">Authentication</TabsTrigger>
|
||||
<TabsTrigger value="parameters">Parameters</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="settings" className="space-y-4 mt-4">
|
||||
<div className="grid gap-2">
|
||||
<Label>Tool Name</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Use a descriptive name, like "Get Weather using API" for a tool that fetches weather
|
||||
</Label>
|
||||
<Input
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="e.g., Book Appointment"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label>Description</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Provide a description which makes it easy for LLM to understand what this tool does
|
||||
</Label>
|
||||
<Textarea
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="What does this tool do?"
|
||||
rows={3}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="grid gap-2">
|
||||
<Label>HTTP Method</Label>
|
||||
<HttpMethodSelector
|
||||
value={httpMethod}
|
||||
onChange={setHttpMethod}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label>Timeout (ms)</Label>
|
||||
<Input
|
||||
type="number"
|
||||
value={timeoutMs}
|
||||
onChange={(e) =>
|
||||
setTimeoutMs(parseInt(e.target.value) || 5000)
|
||||
}
|
||||
min={1000}
|
||||
max={30000}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label>Endpoint URL</Label>
|
||||
<Input
|
||||
value={url}
|
||||
onChange={(e) => setUrl(e.target.value)}
|
||||
placeholder="https://api.example.com/appointments"
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="auth" className="space-y-4 mt-4">
|
||||
<CredentialSelector
|
||||
value={credentialUuid}
|
||||
onChange={setCredentialUuid}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="parameters" className="space-y-4 mt-4">
|
||||
<div className="grid gap-2">
|
||||
<Label>Tool Parameters</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Define the parameters that the LLM will provide when calling this tool.
|
||||
These will be sent as JSON body for POST/PUT/PATCH or as URL query params for GET/DELETE.
|
||||
</Label>
|
||||
<ParameterEditor
|
||||
parameters={parameters}
|
||||
onChange={setParameters}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2 pt-4 border-t">
|
||||
<Label>Custom Headers</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Add custom headers to include in the request (optional)
|
||||
</Label>
|
||||
<KeyValueEditor
|
||||
items={headers}
|
||||
onChange={setHeaders}
|
||||
keyPlaceholder="Header name"
|
||||
valuePlaceholder="Header value"
|
||||
addButtonText="Add Header"
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Code View Dialog */}
|
||||
<Dialog open={showCodeDialog} onOpenChange={setShowCodeDialog}>
|
||||
<DialogContent className="max-w-2xl">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Code Preview</DialogTitle>
|
||||
<DialogDescription>
|
||||
JavaScript code to make this API call
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="bg-muted rounded-lg p-4 font-mono text-sm overflow-auto max-h-96">
|
||||
<pre>{getCodeSnippet()}</pre>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
431
ui/src/app/tools/page.tsx
Normal file
431
ui/src/app/tools/page.tsx
Normal file
|
|
@ -0,0 +1,431 @@
|
|||
"use client";
|
||||
|
||||
import { Globe, Plus, Search, Trash2 } from "lucide-react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import {
|
||||
createToolApiV1ToolsPost,
|
||||
deleteToolApiV1ToolsToolUuidDelete,
|
||||
listToolsApiV1ToolsGet,
|
||||
} from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
type ToolCategory = "http_api" | "native" | "integration";
|
||||
|
||||
const TOOL_CATEGORIES: { value: ToolCategory; label: string; description: string; disabled?: boolean }[] = [
|
||||
{
|
||||
value: "http_api",
|
||||
label: "External HTTP API",
|
||||
description: "Make HTTP requests to external APIs",
|
||||
},
|
||||
{
|
||||
value: "native",
|
||||
label: "Native (Coming Soon)",
|
||||
description: "Built-in tools like call transfer, DTMF input",
|
||||
disabled: true,
|
||||
},
|
||||
{
|
||||
value: "integration",
|
||||
label: "Integration (Coming Soon)",
|
||||
description: "Third-party integrations like Google Calendar",
|
||||
disabled: true,
|
||||
},
|
||||
];
|
||||
|
||||
export default function ToolsPage() {
|
||||
const { user, getAccessToken, redirectToLogin, loading } = useAuth();
|
||||
const router = useRouter();
|
||||
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
|
||||
const [newToolName, setNewToolName] = useState("");
|
||||
const [newToolDescription, setNewToolDescription] = useState("");
|
||||
const [newToolCategory, setNewToolCategory] = useState<ToolCategory>("http_api");
|
||||
const [isCreating, setIsCreating] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Redirect if not authenticated
|
||||
useEffect(() => {
|
||||
if (!loading && !user) {
|
||||
redirectToLogin();
|
||||
}
|
||||
}, [loading, user, redirectToLogin]);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
if (loading || !user) return;
|
||||
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
}
|
||||
} catch (err) {
|
||||
setError("Failed to fetch tools");
|
||||
console.error("Error fetching tools:", err);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [loading, user, getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchTools();
|
||||
}, [fetchTools]);
|
||||
|
||||
const handleCreateTool = async () => {
|
||||
if (!newToolName.trim()) {
|
||||
setError("Please enter a name for the tool");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setIsCreating(true);
|
||||
setError(null);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
const response = await createToolApiV1ToolsPost({
|
||||
body: {
|
||||
name: newToolName,
|
||||
description: newToolDescription || undefined,
|
||||
category: newToolCategory,
|
||||
icon: "globe",
|
||||
icon_color: "#3B82F6",
|
||||
definition: {
|
||||
schema_version: 1,
|
||||
type: newToolCategory,
|
||||
config: {
|
||||
method: "POST",
|
||||
url: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
setIsCreateDialogOpen(false);
|
||||
setNewToolName("");
|
||||
setNewToolDescription("");
|
||||
setNewToolCategory("http_api");
|
||||
// Navigate to the new tool's detail page
|
||||
router.push(`/tools/${response.data.tool_uuid}`);
|
||||
}
|
||||
} catch (err) {
|
||||
setError("Failed to create tool");
|
||||
console.error("Error creating tool:", err);
|
||||
} finally {
|
||||
setIsCreating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeleteTool = async (toolUuid: string, e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
if (!confirm("Are you sure you want to archive this tool?")) return;
|
||||
|
||||
try {
|
||||
setError(null);
|
||||
const accessToken = await getAccessToken();
|
||||
|
||||
await deleteToolApiV1ToolsToolUuidDelete({
|
||||
path: {
|
||||
tool_uuid: toolUuid,
|
||||
},
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
fetchTools();
|
||||
} catch (err) {
|
||||
setError("Failed to delete tool");
|
||||
console.error("Error deleting tool:", err);
|
||||
}
|
||||
};
|
||||
|
||||
const filteredTools = tools.filter(
|
||||
(tool) =>
|
||||
tool.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
||||
tool.description?.toLowerCase().includes(searchQuery.toLowerCase())
|
||||
);
|
||||
|
||||
const getCategoryBadge = (category: string) => {
|
||||
switch (category) {
|
||||
case "http_api":
|
||||
return <Badge variant="default">HTTP API</Badge>;
|
||||
case "native":
|
||||
return <Badge variant="secondary">Native</Badge>;
|
||||
case "integration":
|
||||
return <Badge variant="outline">Integration</Badge>;
|
||||
default:
|
||||
return <Badge variant="outline">{category}</Badge>;
|
||||
}
|
||||
};
|
||||
|
||||
const getStatusBadge = (status: string) => {
|
||||
switch (status) {
|
||||
case "active":
|
||||
return <Badge className="bg-green-500">Active</Badge>;
|
||||
case "draft":
|
||||
return <Badge variant="secondary">Draft</Badge>;
|
||||
case "archived":
|
||||
return <Badge variant="destructive">Archived</Badge>;
|
||||
default:
|
||||
return <Badge variant="outline">{status}</Badge>;
|
||||
}
|
||||
};
|
||||
|
||||
if (loading || !user) {
|
||||
return (
|
||||
<div className="min-h-screen bg-background flex items-center justify-center">
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-12 w-64" />
|
||||
<Skeleton className="h-64 w-96" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="max-w-6xl mx-auto">
|
||||
<div className="mb-8">
|
||||
<h1 className="text-3xl font-bold mb-2">Tools</h1>
|
||||
<p className="text-muted-foreground">
|
||||
Manage reusable HTTP API tools that can be used across your workflows
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mb-4 p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Card className="mb-6">
|
||||
<CardHeader>
|
||||
<div className="flex justify-between items-center">
|
||||
<div>
|
||||
<CardTitle>Your Tools</CardTitle>
|
||||
<CardDescription>
|
||||
Create and manage HTTP API tools for your organization
|
||||
</CardDescription>
|
||||
</div>
|
||||
<Button onClick={() => setIsCreateDialogOpen(true)}>
|
||||
<Plus className="w-4 h-4 mr-2" />
|
||||
Create Tool
|
||||
</Button>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{/* Search */}
|
||||
<div className="relative mb-4">
|
||||
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||
<Input
|
||||
placeholder="Search tools..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
className="pl-10"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="space-y-4">
|
||||
{[1, 2, 3].map((i) => (
|
||||
<div
|
||||
key={i}
|
||||
className="flex items-center justify-between p-4 border rounded-lg"
|
||||
>
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-4 w-32" />
|
||||
<Skeleton className="h-3 w-48" />
|
||||
</div>
|
||||
<Skeleton className="h-8 w-20" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : filteredTools.length === 0 ? (
|
||||
<div className="text-center py-12">
|
||||
<Globe className="w-12 h-12 text-muted-foreground mx-auto mb-4" />
|
||||
<p className="text-muted-foreground mb-4">
|
||||
{searchQuery
|
||||
? "No tools match your search"
|
||||
: "No tools found"}
|
||||
</p>
|
||||
{!searchQuery && (
|
||||
<Button onClick={() => setIsCreateDialogOpen(true)}>
|
||||
Create Your First Tool
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
{filteredTools.map((tool) => (
|
||||
<div
|
||||
key={tool.tool_uuid}
|
||||
className="flex items-center justify-between p-4 border rounded-lg hover:bg-muted/50 cursor-pointer transition-colors"
|
||||
onClick={() =>
|
||||
router.push(`/tools/${tool.tool_uuid}`)
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-4">
|
||||
<div
|
||||
className="w-10 h-10 rounded-lg flex items-center justify-center"
|
||||
style={{
|
||||
backgroundColor:
|
||||
tool.icon_color || "#3B82F6",
|
||||
}}
|
||||
>
|
||||
<Globe className="w-5 h-5 text-white" />
|
||||
</div>
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium">
|
||||
{tool.name}
|
||||
</span>
|
||||
{getCategoryBadge(tool.category)}
|
||||
{getStatusBadge(tool.status)}
|
||||
</div>
|
||||
{tool.description && (
|
||||
<p className="text-sm text-muted-foreground mt-1">
|
||||
{tool.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={(e) =>
|
||||
handleDeleteTool(tool.tool_uuid, e)
|
||||
}
|
||||
className="text-destructive hover:text-destructive/90"
|
||||
>
|
||||
<Trash2 className="w-4 h-4" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Create Tool Dialog */}
|
||||
<Dialog open={isCreateDialogOpen} onOpenChange={setIsCreateDialogOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create New Tool</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create a new tool that can be used in your workflows.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label>Tool Type</Label>
|
||||
<Select
|
||||
value={newToolCategory}
|
||||
onValueChange={(v) => setNewToolCategory(v as ToolCategory)}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{TOOL_CATEGORIES.map((category) => (
|
||||
<SelectItem
|
||||
key={category.value}
|
||||
value={category.value}
|
||||
disabled={category.disabled}
|
||||
>
|
||||
{category.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{TOOL_CATEGORIES.find(c => c.value === newToolCategory)?.description}
|
||||
</p>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="name">Tool Name</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Use a descriptive name, like "Get Weather using API" for a tool that fetches weather
|
||||
</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={newToolName}
|
||||
onChange={(e) => setNewToolName(e.target.value)}
|
||||
placeholder="e.g., Book Appointment, Check Inventory"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="description">Description (Optional)</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Provide a description which makes it easy for LLM to understand what this tool does
|
||||
</Label>
|
||||
<Input
|
||||
id="description"
|
||||
value={newToolDescription}
|
||||
onChange={(e) => setNewToolDescription(e.target.value)}
|
||||
placeholder="What does this tool do?"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setIsCreateDialogOpen(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleCreateTool} disabled={isCreating}>
|
||||
{isCreating ? "Creating..." : "Create Tool"}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -221,7 +221,7 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
</ReactFlow>
|
||||
|
||||
{/* Bottom-left controls - horizontal layout with custom buttons */}
|
||||
<div className="absolute bottom-12 left-8 z-[1000] flex gap-2">
|
||||
<div className="absolute bottom-12 left-8 z-10 flex gap-2">
|
||||
<TooltipProvider>
|
||||
{/* Zoom In */}
|
||||
<Tooltip>
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ export default function WorkflowRunPage() {
|
|||
<svg className="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
|
||||
</svg>
|
||||
Back to Agent
|
||||
Customize Agent
|
||||
</Button>
|
||||
</Link>
|
||||
</CardHeader>
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -144,6 +144,18 @@ export type CreateTestSessionRequest = {
|
|||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for creating a tool.
|
||||
*/
|
||||
export type CreateToolRequest = {
|
||||
name: string;
|
||||
description?: string | null;
|
||||
category?: string;
|
||||
icon?: string | null;
|
||||
icon_color?: string | null;
|
||||
definition: ToolDefinition;
|
||||
};
|
||||
|
||||
export type CreateWorkflowRequest = {
|
||||
name: string;
|
||||
workflow_definition: {
|
||||
|
|
@ -174,6 +186,14 @@ export type CreateWorkflowTemplateRequest = {
|
|||
activity_description: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for the user who created a tool.
|
||||
*/
|
||||
export type CreatedByResponse = {
|
||||
id: number;
|
||||
provider_id: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for a webhook credential (never includes sensitive data).
|
||||
*/
|
||||
|
|
@ -309,6 +329,52 @@ export type HttpValidationError = {
|
|||
detail?: Array<ValidationError>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Configuration for HTTP API tools.
|
||||
*/
|
||||
export type HttpApiConfig = {
|
||||
/**
|
||||
* HTTP method (GET, POST, PUT, PATCH, DELETE)
|
||||
*/
|
||||
method: string;
|
||||
/**
|
||||
* Target URL (supports {{variable}} placeholders)
|
||||
*/
|
||||
url: string;
|
||||
/**
|
||||
* Static headers to include
|
||||
*/
|
||||
headers?: {
|
||||
[key: string]: string;
|
||||
} | null;
|
||||
/**
|
||||
* Reference to ExternalCredentialModel for auth
|
||||
*/
|
||||
credential_uuid?: string | null;
|
||||
/**
|
||||
* Request body with {{variable}} placeholders
|
||||
*/
|
||||
body_template?: {
|
||||
[key: string]: unknown;
|
||||
} | null;
|
||||
/**
|
||||
* Request timeout in milliseconds
|
||||
*/
|
||||
timeout_ms?: number | null;
|
||||
/**
|
||||
* Retry configuration
|
||||
*/
|
||||
retry_config?: {
|
||||
[key: string]: unknown;
|
||||
} | null;
|
||||
/**
|
||||
* JSONPath mappings for response extraction
|
||||
*/
|
||||
response_mapping?: {
|
||||
[key: string]: string;
|
||||
} | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request payload for superadmin impersonation.
|
||||
*
|
||||
|
|
@ -500,6 +566,44 @@ export type TestSessionResponse = {
|
|||
completed_at: string | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Tool definition schema.
|
||||
*/
|
||||
export type ToolDefinition = {
|
||||
/**
|
||||
* Schema version for compatibility
|
||||
*/
|
||||
schema_version?: number;
|
||||
/**
|
||||
* Tool type (http_api)
|
||||
*/
|
||||
type: string;
|
||||
/**
|
||||
* Tool configuration
|
||||
*/
|
||||
config: HttpApiConfig;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for a tool.
|
||||
*/
|
||||
export type ToolResponse = {
|
||||
id: number;
|
||||
tool_uuid: string;
|
||||
name: string;
|
||||
description: string | null;
|
||||
category: string;
|
||||
icon: string | null;
|
||||
icon_color: string | null;
|
||||
status: string;
|
||||
definition: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
created_at: string;
|
||||
updated_at: string | null;
|
||||
created_by?: CreatedByResponse | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request model for triggering a call via API
|
||||
*/
|
||||
|
|
@ -566,6 +670,18 @@ export type UpdateIntegrationRequest = {
|
|||
}>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for updating a tool.
|
||||
*/
|
||||
export type UpdateToolRequest = {
|
||||
name?: string | null;
|
||||
description?: string | null;
|
||||
icon?: string | null;
|
||||
icon_color?: string | null;
|
||||
definition?: ToolDefinition | null;
|
||||
status?: string | null;
|
||||
};
|
||||
|
||||
export type UpdateWorkflowRequest = {
|
||||
name: string;
|
||||
workflow_definition?: {
|
||||
|
|
@ -2347,6 +2463,177 @@ export type UpdateCredentialApiV1CredentialsCredentialUuidPutResponses = {
|
|||
|
||||
export type UpdateCredentialApiV1CredentialsCredentialUuidPutResponse = UpdateCredentialApiV1CredentialsCredentialUuidPutResponses[keyof UpdateCredentialApiV1CredentialsCredentialUuidPutResponses];
|
||||
|
||||
export type ListToolsApiV1ToolsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: {
|
||||
status?: string | null;
|
||||
category?: string | null;
|
||||
};
|
||||
url: '/api/v1/tools/';
|
||||
};
|
||||
|
||||
export type ListToolsApiV1ToolsGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type ListToolsApiV1ToolsGetError = ListToolsApiV1ToolsGetErrors[keyof ListToolsApiV1ToolsGetErrors];
|
||||
|
||||
export type ListToolsApiV1ToolsGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: Array<ToolResponse>;
|
||||
};
|
||||
|
||||
export type ListToolsApiV1ToolsGetResponse = ListToolsApiV1ToolsGetResponses[keyof ListToolsApiV1ToolsGetResponses];
|
||||
|
||||
export type CreateToolApiV1ToolsPostData = {
|
||||
body: CreateToolRequest;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/tools/';
|
||||
};
|
||||
|
||||
export type CreateToolApiV1ToolsPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type CreateToolApiV1ToolsPostError = CreateToolApiV1ToolsPostErrors[keyof CreateToolApiV1ToolsPostErrors];
|
||||
|
||||
export type CreateToolApiV1ToolsPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: ToolResponse;
|
||||
};
|
||||
|
||||
export type CreateToolApiV1ToolsPostResponse = CreateToolApiV1ToolsPostResponses[keyof CreateToolApiV1ToolsPostResponses];
|
||||
|
||||
export type DeleteToolApiV1ToolsToolUuidDeleteData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
};
|
||||
path: {
|
||||
tool_uuid: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/tools/{tool_uuid}';
|
||||
};
|
||||
|
||||
export type DeleteToolApiV1ToolsToolUuidDeleteErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type DeleteToolApiV1ToolsToolUuidDeleteError = DeleteToolApiV1ToolsToolUuidDeleteErrors[keyof DeleteToolApiV1ToolsToolUuidDeleteErrors];
|
||||
|
||||
export type DeleteToolApiV1ToolsToolUuidDeleteResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
|
||||
export type DeleteToolApiV1ToolsToolUuidDeleteResponse = DeleteToolApiV1ToolsToolUuidDeleteResponses[keyof DeleteToolApiV1ToolsToolUuidDeleteResponses];
|
||||
|
||||
export type GetToolApiV1ToolsToolUuidGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
};
|
||||
path: {
|
||||
tool_uuid: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/tools/{tool_uuid}';
|
||||
};
|
||||
|
||||
export type GetToolApiV1ToolsToolUuidGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetToolApiV1ToolsToolUuidGetError = GetToolApiV1ToolsToolUuidGetErrors[keyof GetToolApiV1ToolsToolUuidGetErrors];
|
||||
|
||||
export type GetToolApiV1ToolsToolUuidGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: ToolResponse;
|
||||
};
|
||||
|
||||
export type GetToolApiV1ToolsToolUuidGetResponse = GetToolApiV1ToolsToolUuidGetResponses[keyof GetToolApiV1ToolsToolUuidGetResponses];
|
||||
|
||||
export type UpdateToolApiV1ToolsToolUuidPutData = {
|
||||
body: UpdateToolRequest;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
};
|
||||
path: {
|
||||
tool_uuid: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/tools/{tool_uuid}';
|
||||
};
|
||||
|
||||
export type UpdateToolApiV1ToolsToolUuidPutErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type UpdateToolApiV1ToolsToolUuidPutError = UpdateToolApiV1ToolsToolUuidPutErrors[keyof UpdateToolApiV1ToolsToolUuidPutErrors];
|
||||
|
||||
export type UpdateToolApiV1ToolsToolUuidPutResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: ToolResponse;
|
||||
};
|
||||
|
||||
export type UpdateToolApiV1ToolsToolUuidPutResponse = UpdateToolApiV1ToolsToolUuidPutResponses[keyof UpdateToolApiV1ToolsToolUuidPutResponses];
|
||||
|
||||
export type GetIntegrationsApiV1IntegrationGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
|
|
|
|||
64
ui/src/components/flow/ToolBadges.tsx
Normal file
64
ui/src/components/flow/ToolBadges.tsx
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface ToolBadgesProps {
|
||||
toolUuids: string[];
|
||||
}
|
||||
|
||||
export function ToolBadges({ toolUuids }: ToolBadgesProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch tools:", error);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
if (toolUuids.length > 0) {
|
||||
fetchTools();
|
||||
}
|
||||
}, [toolUuids.length, fetchTools]);
|
||||
|
||||
const selectedTools = tools.filter((tool) => toolUuids.includes(tool.tool_uuid));
|
||||
|
||||
if (selectedTools.length === 0 && toolUuids.length > 0) {
|
||||
// Still loading or tools not found
|
||||
return (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
Loading...
|
||||
</Badge>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedTools.map((tool) => (
|
||||
<Badge
|
||||
key={tool.tool_uuid}
|
||||
variant="outline"
|
||||
className="text-xs"
|
||||
>
|
||||
{tool.name}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
161
ui/src/components/flow/ToolSelector.tsx
Normal file
161
ui/src/components/flow/ToolSelector.tsx
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"use client";
|
||||
|
||||
import { ExternalLink, Globe, Loader2 } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { listToolsApiV1ToolsGet } from "@/client/sdk.gen";
|
||||
import type { ToolResponse } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface ToolSelectorProps {
|
||||
value: string[];
|
||||
onChange: (uuids: string[]) => void;
|
||||
disabled?: boolean;
|
||||
label?: string;
|
||||
description?: string;
|
||||
showLabel?: boolean;
|
||||
}
|
||||
|
||||
export function ToolSelector({
|
||||
value,
|
||||
onChange,
|
||||
disabled = false,
|
||||
label = "Tools",
|
||||
description = "Select tools that the agent can use during the conversation.",
|
||||
showLabel = true,
|
||||
}: ToolSelectorProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
|
||||
const [tools, setTools] = useState<ToolResponse[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const fetchTools = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listToolsApiV1ToolsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
query: { status: "active" },
|
||||
});
|
||||
if (response.error) {
|
||||
console.error("Failed to fetch tools:", response.error);
|
||||
setTools([]);
|
||||
return;
|
||||
}
|
||||
if (response.data) {
|
||||
setTools(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch tools:", error);
|
||||
setTools([]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchTools();
|
||||
}, [fetchTools]);
|
||||
|
||||
const handleToggle = (toolUuid: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
onChange([...value, toolUuid]);
|
||||
} else {
|
||||
onChange(value.filter((id) => id !== toolUuid));
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
{showLabel && (
|
||||
<>
|
||||
<Label>{label}</Label>
|
||||
{description && (
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
{description}
|
||||
</Label>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{loading ? (
|
||||
<div className="flex items-center gap-2 p-3 border rounded-md">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span className="text-sm text-muted-foreground">Loading tools...</span>
|
||||
</div>
|
||||
) : tools.length === 0 ? (
|
||||
<div className="p-4 border rounded-md text-center">
|
||||
<p className="text-sm text-muted-foreground mb-2">
|
||||
No tools available.
|
||||
</p>
|
||||
<Button variant="outline" size="sm" asChild>
|
||||
<Link href="/tools" target="_blank">
|
||||
<ExternalLink className="h-4 w-4 mr-2" />
|
||||
Create a Tool
|
||||
</Link>
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="border rounded-md divide-y">
|
||||
{tools.map((tool) => {
|
||||
const isSelected = value.includes(tool.tool_uuid);
|
||||
return (
|
||||
<label
|
||||
key={tool.tool_uuid}
|
||||
className={`flex items-center gap-3 p-3 cursor-pointer hover:bg-muted/50 ${
|
||||
disabled ? "opacity-50 cursor-not-allowed" : ""
|
||||
}`}
|
||||
>
|
||||
<Checkbox
|
||||
checked={isSelected}
|
||||
disabled={disabled}
|
||||
onCheckedChange={(checked) => {
|
||||
handleToggle(tool.tool_uuid, checked === true);
|
||||
}}
|
||||
/>
|
||||
<div
|
||||
className="w-6 h-6 rounded flex items-center justify-center shrink-0"
|
||||
style={{
|
||||
backgroundColor: tool.icon_color || "#3B82F6",
|
||||
}}
|
||||
>
|
||||
<Globe className="h-3 w-3 text-white" />
|
||||
</div>
|
||||
<div className="flex flex-col min-w-0 flex-1">
|
||||
<span className="text-sm font-medium truncate">
|
||||
{tool.name}
|
||||
</span>
|
||||
{tool.description && (
|
||||
<span className="text-xs text-muted-foreground truncate">
|
||||
{tool.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</label>
|
||||
);
|
||||
})}
|
||||
<div className="p-2 bg-muted/30">
|
||||
<Link
|
||||
href="/tools"
|
||||
target="_blank"
|
||||
className="flex items-center gap-2 text-sm text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<ExternalLink className="h-4 w-4" />
|
||||
Manage Tools
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{value.length > 0 && (
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{value.length} tool{value.length !== 1 ? "s" : ""} selected
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { Edit, Headset, PlusIcon,Trash2Icon } from "lucide-react";
|
||||
import { Edit, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-react";
|
||||
import { memo, useEffect, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import { ExtractionVariable,FlowNodeData } from "@/components/flow/types";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
import { ToolSelector } from "@/components/flow/ToolSelector";
|
||||
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
|
|
@ -30,6 +32,8 @@ interface AgentNodeEditFormProps {
|
|||
setVariables: (vars: ExtractionVariable[]) => void;
|
||||
addGlobalPrompt: boolean;
|
||||
setAddGlobalPrompt: (value: boolean) => void;
|
||||
toolUuids: string[];
|
||||
setToolUuids: (value: string[]) => void;
|
||||
}
|
||||
|
||||
interface AgentNodeProps extends NodeProps {
|
||||
|
|
@ -50,6 +54,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
const [extractionPrompt, setExtractionPrompt] = useState(data.extraction_prompt ?? "");
|
||||
const [variables, setVariables] = useState<ExtractionVariable[]>(data.extraction_variables ?? []);
|
||||
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
|
||||
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
|
||||
|
||||
const handleSave = async () => {
|
||||
handleSaveNodeData({
|
||||
|
|
@ -61,6 +66,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
extraction_prompt: extractionPrompt,
|
||||
extraction_variables: variables,
|
||||
add_global_prompt: addGlobalPrompt,
|
||||
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
|
||||
});
|
||||
setOpen(false);
|
||||
// Save the workflow after updating node data with a small delay to ensure state is updated
|
||||
|
|
@ -79,6 +85,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setExtractionPrompt(data.extraction_prompt ?? "");
|
||||
setVariables(data.extraction_variables ?? []);
|
||||
setAddGlobalPrompt(data.add_global_prompt ?? true);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
}
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
|
@ -93,6 +100,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setExtractionPrompt(data.extraction_prompt ?? "");
|
||||
setVariables(data.extraction_variables ?? []);
|
||||
setAddGlobalPrompt(data.add_global_prompt ?? true);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
}
|
||||
}, [data, open]);
|
||||
|
||||
|
|
@ -114,6 +122,15 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
<p className="text-sm text-muted-foreground line-clamp-5 leading-relaxed">
|
||||
{data.prompt || 'No prompt configured'}
|
||||
</p>
|
||||
{data.tool_uuids && data.tool_uuids.length > 0 && (
|
||||
<div className="mt-3 pt-3 border-t border-border/50">
|
||||
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
|
||||
<Wrench className="h-3 w-3" />
|
||||
<span>Tools:</span>
|
||||
</div>
|
||||
<ToolBadges toolUuids={data.tool_uuids} />
|
||||
</div>
|
||||
)}
|
||||
</NodeContent>
|
||||
|
||||
<NodeToolbar isVisible={selected} position={Position.Right}>
|
||||
|
|
@ -151,6 +168,8 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setVariables={setVariables}
|
||||
addGlobalPrompt={addGlobalPrompt}
|
||||
setAddGlobalPrompt={setAddGlobalPrompt}
|
||||
toolUuids={toolUuids}
|
||||
setToolUuids={setToolUuids}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -173,6 +192,8 @@ const AgentNodeEditForm = ({
|
|||
setVariables,
|
||||
addGlobalPrompt,
|
||||
setAddGlobalPrompt,
|
||||
toolUuids,
|
||||
setToolUuids,
|
||||
}: AgentNodeEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -307,6 +328,15 @@ const AgentNodeEditForm = ({
|
|||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Tools Section */}
|
||||
<div className="pt-4 border-t mt-4">
|
||||
<ToolSelector
|
||||
value={toolUuids}
|
||||
onChange={setToolUuids}
|
||||
description="Select tools that the agent can invoke during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { Edit, Play } from "lucide-react";
|
||||
import { Edit, Play, Wrench } from "lucide-react";
|
||||
import { memo, useEffect, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
import { ToolSelector } from "@/components/flow/ToolSelector";
|
||||
import { FlowNodeData } from "@/components/flow/types";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
|
@ -31,6 +33,8 @@ interface StartCallEditFormProps {
|
|||
setDelayedStart: (value: boolean) => void;
|
||||
delayedStartDuration: number;
|
||||
setDelayedStartDuration: (value: number) => void;
|
||||
toolUuids: string[];
|
||||
setToolUuids: (value: string[]) => void;
|
||||
}
|
||||
|
||||
interface StartCallNodeProps extends NodeProps {
|
||||
|
|
@ -52,6 +56,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
const [detectVoicemail, setDetectVoicemail] = useState(data.detect_voicemail ?? false);
|
||||
const [delayedStart, setDelayedStart] = useState(data.delayed_start ?? false);
|
||||
const [delayedStartDuration, setDelayedStartDuration] = useState(data.delayed_start_duration ?? 2);
|
||||
const [toolUuids, setToolUuids] = useState<string[]>(data.tool_uuids ?? []);
|
||||
|
||||
const handleSave = async () => {
|
||||
handleSaveNodeData({
|
||||
|
|
@ -62,7 +67,8 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
add_global_prompt: addGlobalPrompt,
|
||||
detect_voicemail: detectVoicemail,
|
||||
delayed_start: delayedStart,
|
||||
delayed_start_duration: delayedStart ? delayedStartDuration : undefined
|
||||
delayed_start_duration: delayedStart ? delayedStartDuration : undefined,
|
||||
tool_uuids: toolUuids.length > 0 ? toolUuids : undefined,
|
||||
});
|
||||
setOpen(false);
|
||||
// Save the workflow after updating node data with a small delay to ensure state is updated
|
||||
|
|
@ -81,6 +87,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setDetectVoicemail(data.detect_voicemail ?? false);
|
||||
setDelayedStart(data.delayed_start ?? false);
|
||||
setDelayedStartDuration(data.delayed_start_duration ?? 3);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
}
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
|
@ -95,6 +102,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setDetectVoicemail(data.detect_voicemail ?? false);
|
||||
setDelayedStart(data.delayed_start ?? false);
|
||||
setDelayedStartDuration(data.delayed_start_duration ?? 3);
|
||||
setToolUuids(data.tool_uuids ?? []);
|
||||
}
|
||||
}, [data, open]);
|
||||
|
||||
|
|
@ -115,6 +123,15 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
<p className="text-sm text-muted-foreground line-clamp-5 leading-relaxed">
|
||||
{data.prompt || 'No prompt configured'}
|
||||
</p>
|
||||
{data.tool_uuids && data.tool_uuids.length > 0 && (
|
||||
<div className="mt-3 pt-3 border-t border-border/50">
|
||||
<div className="flex items-center gap-1.5 text-xs text-muted-foreground mb-2">
|
||||
<Wrench className="h-3 w-3" />
|
||||
<span>Tools:</span>
|
||||
</div>
|
||||
<ToolBadges toolUuids={data.tool_uuids} />
|
||||
</div>
|
||||
)}
|
||||
</NodeContent>
|
||||
|
||||
<NodeToolbar isVisible={selected} position={Position.Right}>
|
||||
|
|
@ -147,6 +164,8 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setDelayedStart={setDelayedStart}
|
||||
delayedStartDuration={delayedStartDuration}
|
||||
setDelayedStartDuration={setDelayedStartDuration}
|
||||
toolUuids={toolUuids}
|
||||
setToolUuids={setToolUuids}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -168,7 +187,9 @@ const StartCallEditForm = ({
|
|||
delayedStart,
|
||||
setDelayedStart,
|
||||
delayedStartDuration,
|
||||
setDelayedStartDuration
|
||||
setDelayedStartDuration,
|
||||
toolUuids,
|
||||
setToolUuids,
|
||||
}: StartCallEditFormProps) => {
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
|
|
@ -258,6 +279,15 @@ const StartCallEditForm = ({
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Tools Section */}
|
||||
<div className="pt-4 border-t mt-4">
|
||||
<ToolSelector
|
||||
value={toolUuids}
|
||||
onChange={setToolUuids}
|
||||
description="Select tools that the agent can invoke during this conversation step."
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,36 +1,22 @@
|
|||
import { NodeProps, NodeToolbar, Position } from "@xyflow/react";
|
||||
import { AlertCircle, Circle, Edit, Link2, Loader2, PlusIcon, Trash2Icon } from "lucide-react";
|
||||
import { memo, useCallback, useEffect, useState } from "react";
|
||||
import { Circle, Edit, Link2, Trash2Icon } from "lucide-react";
|
||||
import { memo, useEffect, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import {
|
||||
createCredentialApiV1CredentialsPost,
|
||||
listCredentialsApiV1CredentialsGet,
|
||||
} from "@/client";
|
||||
import { CredentialResponse, WebhookCredentialType } from "@/client/types.gen";
|
||||
import { FlowNodeData } from "@/components/flow/types";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
CredentialSelector,
|
||||
type HttpMethod,
|
||||
HttpMethodSelector,
|
||||
KeyValueEditor,
|
||||
type KeyValueItem,
|
||||
} from "@/components/http";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { JsonEditor, validateJson } from "@/components/ui/json-editor";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
import { NodeContent } from "./common/NodeContent";
|
||||
import { NodeEditDialog } from "./common/NodeEditDialog";
|
||||
|
|
@ -40,17 +26,9 @@ interface WebhookNodeProps extends NodeProps {
|
|||
data: FlowNodeData;
|
||||
}
|
||||
|
||||
type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
|
||||
|
||||
interface CustomHeader {
|
||||
key: string;
|
||||
value: string;
|
||||
}
|
||||
|
||||
export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
|
||||
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { getAccessToken } = useAuth();
|
||||
|
||||
// Form state
|
||||
const [name, setName] = useState(data.name || "Webhook");
|
||||
|
|
@ -58,41 +36,13 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
|
|||
const [httpMethod, setHttpMethod] = useState<HttpMethod>(data.http_method || "POST");
|
||||
const [endpointUrl, setEndpointUrl] = useState(data.endpoint_url || "");
|
||||
const [credentialUuid, setCredentialUuid] = useState(data.credential_uuid || "");
|
||||
const [customHeaders, setCustomHeaders] = useState<CustomHeader[]>(
|
||||
const [customHeaders, setCustomHeaders] = useState<KeyValueItem[]>(
|
||||
data.custom_headers || []
|
||||
);
|
||||
const [payloadTemplate, setPayloadTemplate] = useState(
|
||||
data.payload_template ? JSON.stringify(data.payload_template, null, 2) : "{}"
|
||||
);
|
||||
|
||||
// Credentials state
|
||||
const [credentials, setCredentials] = useState<CredentialResponse[]>([]);
|
||||
const [credentialsLoading, setCredentialsLoading] = useState(false);
|
||||
|
||||
// Fetch credentials when dialog opens
|
||||
const fetchCredentials = useCallback(async () => {
|
||||
setCredentialsLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listCredentialsApiV1CredentialsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (response.error) {
|
||||
console.error("Failed to fetch credentials:", response.error);
|
||||
setCredentials([]);
|
||||
return;
|
||||
}
|
||||
if (response.data) {
|
||||
setCredentials(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch credentials:", error);
|
||||
setCredentials([]);
|
||||
} finally {
|
||||
setCredentialsLoading(false);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
// Validation state - only shown on save attempt
|
||||
const [jsonError, setJsonError] = useState<string | null>(null);
|
||||
const [endpointError, setEndpointError] = useState<string | null>(null);
|
||||
|
|
@ -143,8 +93,6 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
|
|||
// Clear any previous errors
|
||||
setJsonError(null);
|
||||
setEndpointError(null);
|
||||
// Fetch credentials when dialog opens
|
||||
fetchCredentials();
|
||||
}
|
||||
setOpen(newOpen);
|
||||
};
|
||||
|
|
@ -233,10 +181,6 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
|
|||
setEndpointUrl={setEndpointUrl}
|
||||
credentialUuid={credentialUuid}
|
||||
setCredentialUuid={setCredentialUuid}
|
||||
credentials={credentials}
|
||||
credentialsLoading={credentialsLoading}
|
||||
onRefreshCredentials={fetchCredentials}
|
||||
getAccessToken={getAccessToken}
|
||||
customHeaders={customHeaders}
|
||||
setCustomHeaders={setCustomHeaders}
|
||||
payloadTemplate={payloadTemplate}
|
||||
|
|
@ -259,16 +203,23 @@ interface WebhookNodeEditFormProps {
|
|||
setEndpointUrl: (value: string) => void;
|
||||
credentialUuid: string;
|
||||
setCredentialUuid: (value: string) => void;
|
||||
credentials: CredentialResponse[];
|
||||
credentialsLoading: boolean;
|
||||
onRefreshCredentials: () => Promise<void>;
|
||||
getAccessToken: () => Promise<string>;
|
||||
customHeaders: CustomHeader[];
|
||||
setCustomHeaders: (value: CustomHeader[]) => void;
|
||||
customHeaders: KeyValueItem[];
|
||||
setCustomHeaders: (value: KeyValueItem[]) => void;
|
||||
payloadTemplate: string;
|
||||
setPayloadTemplate: (value: string) => void;
|
||||
}
|
||||
|
||||
const availableVariables = [
|
||||
{ name: "workflow_run_id", description: "Unique ID of the workflow run" },
|
||||
{ name: "workflow_id", description: "ID of the workflow" },
|
||||
{ name: "workflow_name", description: "Name of the workflow" },
|
||||
{ name: "initial_context.*", description: "Initial context variables" },
|
||||
{ name: "gathered_context.*", description: "Extracted variables" },
|
||||
{ name: "cost_info.call_duration_seconds", description: "Call duration" },
|
||||
{ name: "recording_url", description: "Call recording URL" },
|
||||
{ name: "transcript_url", description: "Transcript URL" },
|
||||
];
|
||||
|
||||
const WebhookNodeEditForm = ({
|
||||
name,
|
||||
setName,
|
||||
|
|
@ -280,130 +231,11 @@ const WebhookNodeEditForm = ({
|
|||
setEndpointUrl,
|
||||
credentialUuid,
|
||||
setCredentialUuid,
|
||||
credentials,
|
||||
credentialsLoading,
|
||||
onRefreshCredentials,
|
||||
getAccessToken,
|
||||
customHeaders,
|
||||
setCustomHeaders,
|
||||
payloadTemplate,
|
||||
setPayloadTemplate,
|
||||
}: WebhookNodeEditFormProps) => {
|
||||
// Add Credential Dialog state
|
||||
const [isAddCredentialOpen, setIsAddCredentialOpen] = useState(false);
|
||||
const [newCredName, setNewCredName] = useState("");
|
||||
const [newCredDescription, setNewCredDescription] = useState("");
|
||||
const [newCredType, setNewCredType] = useState<WebhookCredentialType>("bearer_token");
|
||||
const [newCredData, setNewCredData] = useState<Record<string, string>>({});
|
||||
const [isCreatingCredential, setIsCreatingCredential] = useState(false);
|
||||
const [credentialError, setCredentialError] = useState<string | null>(null);
|
||||
|
||||
const handleCreateCredential = async () => {
|
||||
if (!newCredName.trim()) return;
|
||||
|
||||
setIsCreatingCredential(true);
|
||||
setCredentialError(null);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await createCredentialApiV1CredentialsPost({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
body: {
|
||||
name: newCredName,
|
||||
description: newCredDescription || undefined,
|
||||
credential_type: newCredType,
|
||||
credential_data: newCredData,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
const errorDetail = (response.error as { detail?: string })?.detail
|
||||
|| "Failed to create credential";
|
||||
setCredentialError(errorDetail);
|
||||
return;
|
||||
}
|
||||
|
||||
if (response.data) {
|
||||
// Refresh credentials list
|
||||
await onRefreshCredentials();
|
||||
// Select the newly created credential
|
||||
setCredentialUuid(response.data.uuid);
|
||||
// Close dialog and reset form
|
||||
setIsAddCredentialOpen(false);
|
||||
setNewCredName("");
|
||||
setNewCredDescription("");
|
||||
setNewCredType("bearer_token");
|
||||
setNewCredData({});
|
||||
setCredentialError(null);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to create credential:", error);
|
||||
setCredentialError(
|
||||
error instanceof Error ? error.message : "An unexpected error occurred"
|
||||
);
|
||||
} finally {
|
||||
setIsCreatingCredential(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAddCredentialDialogChange = (open: boolean) => {
|
||||
setIsAddCredentialOpen(open);
|
||||
if (!open) {
|
||||
// Reset error when closing dialog
|
||||
setCredentialError(null);
|
||||
}
|
||||
};
|
||||
|
||||
const getCredentialDataFields = (type: WebhookCredentialType) => {
|
||||
switch (type) {
|
||||
case "api_key":
|
||||
return [
|
||||
{ key: "header_name", label: "Header Name", placeholder: "X-API-Key" },
|
||||
{ key: "api_key", label: "API Key", placeholder: "your-api-key", isSecret: true },
|
||||
];
|
||||
case "bearer_token":
|
||||
return [
|
||||
{ key: "token", label: "Token", placeholder: "your-bearer-token", isSecret: true },
|
||||
];
|
||||
case "basic_auth":
|
||||
return [
|
||||
{ key: "username", label: "Username", placeholder: "username" },
|
||||
{ key: "password", label: "Password", placeholder: "password", isSecret: true },
|
||||
];
|
||||
case "custom_header":
|
||||
return [
|
||||
{ key: "header_name", label: "Header Name", placeholder: "X-Custom-Header" },
|
||||
{ key: "header_value", label: "Header Value", placeholder: "header-value", isSecret: true },
|
||||
];
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
const addHeader = () => {
|
||||
setCustomHeaders([...customHeaders, { key: "", value: "" }]);
|
||||
};
|
||||
|
||||
const updateHeader = (index: number, field: "key" | "value", value: string) => {
|
||||
const newHeaders = [...customHeaders];
|
||||
newHeaders[index] = { ...newHeaders[index], [field]: value };
|
||||
setCustomHeaders(newHeaders);
|
||||
};
|
||||
|
||||
const removeHeader = (index: number) => {
|
||||
setCustomHeaders(customHeaders.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
const availableVariables = [
|
||||
{ name: "workflow_run_id", description: "Unique ID of the workflow run" },
|
||||
{ name: "workflow_id", description: "ID of the workflow" },
|
||||
{ name: "workflow_name", description: "Name of the workflow" },
|
||||
{ name: "initial_context.*", description: "Initial context variables" },
|
||||
{ name: "gathered_context.*", description: "Extracted variables" },
|
||||
{ name: "cost_info.call_duration_seconds", description: "Call duration" },
|
||||
{ name: "recording_url", description: "Call recording URL" },
|
||||
{ name: "transcript_url", description: "Transcript URL" },
|
||||
];
|
||||
|
||||
return (
|
||||
<Tabs defaultValue="basic" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-4">
|
||||
|
|
@ -432,18 +264,10 @@ const WebhookNodeEditForm = ({
|
|||
|
||||
<div className="grid gap-2">
|
||||
<Label>HTTP Method</Label>
|
||||
<Select value={httpMethod} onValueChange={(v) => setHttpMethod(v as HttpMethod)}>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="GET">GET</SelectItem>
|
||||
<SelectItem value="POST">POST</SelectItem>
|
||||
<SelectItem value="PUT">PUT</SelectItem>
|
||||
<SelectItem value="PATCH">PATCH</SelectItem>
|
||||
<SelectItem value="DELETE">DELETE</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<HttpMethodSelector
|
||||
value={httpMethod}
|
||||
onChange={setHttpMethod}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
|
|
@ -460,154 +284,10 @@ const WebhookNodeEditForm = ({
|
|||
</TabsContent>
|
||||
|
||||
<TabsContent value="auth" className="space-y-4 mt-4">
|
||||
<div className="grid gap-2">
|
||||
<Label>Credential</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Select a credential for authentication, or leave empty for no auth.
|
||||
</Label>
|
||||
<div className="flex gap-2">
|
||||
<Select
|
||||
value={credentialUuid || "none"}
|
||||
onValueChange={(v) => setCredentialUuid(v === "none" ? "" : v)}
|
||||
disabled={credentialsLoading}
|
||||
>
|
||||
<SelectTrigger className="flex-1">
|
||||
{credentialsLoading ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span>Loading...</span>
|
||||
</div>
|
||||
) : (
|
||||
<SelectValue placeholder="No authentication" />
|
||||
)}
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="none">No authentication</SelectItem>
|
||||
{credentials.map((cred) => (
|
||||
<SelectItem key={cred.uuid} value={cred.uuid}>
|
||||
{cred.name} ({cred.credential_type})
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => setIsAddCredentialOpen(true)}
|
||||
title="Add new credential"
|
||||
>
|
||||
<PlusIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{credentials.length === 0 && !credentialsLoading && (
|
||||
<div className="p-3 border rounded-md bg-muted/20">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
No credentials found. Click the + button to create one.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Add Credential Dialog */}
|
||||
<Dialog open={isAddCredentialOpen} onOpenChange={handleAddCredentialDialogChange}>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add Credential</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create a new credential for webhook authentication.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{/* Error display */}
|
||||
{credentialError && (
|
||||
<div className="flex items-start gap-2 p-3 text-sm text-red-600 bg-red-50 border border-red-200 rounded-md">
|
||||
<AlertCircle className="h-4 w-4 mt-0.5 flex-shrink-0" />
|
||||
<span>{credentialError}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="cred-name">Name *</Label>
|
||||
<Input
|
||||
id="cred-name"
|
||||
value={newCredName}
|
||||
onChange={(e) => setNewCredName(e.target.value)}
|
||||
placeholder="My API Key"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="cred-description">Description</Label>
|
||||
<Input
|
||||
id="cred-description"
|
||||
value={newCredDescription}
|
||||
onChange={(e) => setNewCredDescription(e.target.value)}
|
||||
placeholder="Optional description"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label>Credential Type</Label>
|
||||
<Select
|
||||
value={newCredType}
|
||||
onValueChange={(v) => {
|
||||
setNewCredType(v as WebhookCredentialType);
|
||||
setNewCredData({});
|
||||
}}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="bearer_token">Bearer Token</SelectItem>
|
||||
<SelectItem value="api_key">API Key</SelectItem>
|
||||
<SelectItem value="basic_auth">Basic Auth</SelectItem>
|
||||
<SelectItem value="custom_header">Custom Header</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
{getCredentialDataFields(newCredType).map((field) => (
|
||||
<div key={field.key} className="grid gap-2">
|
||||
<Label htmlFor={`cred-${field.key}`}>{field.label}</Label>
|
||||
<Input
|
||||
id={`cred-${field.key}`}
|
||||
type={field.isSecret ? "password" : "text"}
|
||||
value={newCredData[field.key] || ""}
|
||||
onChange={(e) =>
|
||||
setNewCredData((prev) => ({
|
||||
...prev,
|
||||
[field.key]: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder={field.placeholder}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setIsAddCredentialOpen(false)}
|
||||
disabled={isCreatingCredential}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleCreateCredential}
|
||||
disabled={!newCredName.trim() || isCreatingCredential}
|
||||
>
|
||||
{isCreatingCredential ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||
Creating...
|
||||
</>
|
||||
) : (
|
||||
"Create"
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
<CredentialSelector
|
||||
value={credentialUuid}
|
||||
onChange={setCredentialUuid}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="headers" className="space-y-4 mt-4">
|
||||
|
|
@ -616,34 +296,13 @@ const WebhookNodeEditForm = ({
|
|||
<Label className="text-xs text-muted-foreground">
|
||||
Add custom headers to include in the webhook request.
|
||||
</Label>
|
||||
|
||||
{customHeaders.map((header, index) => (
|
||||
<div key={index} className="flex items-center gap-2">
|
||||
<Input
|
||||
placeholder="Header name"
|
||||
value={header.key}
|
||||
onChange={(e) => updateHeader(index, "key", e.target.value)}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Input
|
||||
placeholder="Header value"
|
||||
value={header.value}
|
||||
onChange={(e) => updateHeader(index, "value", e.target.value)}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => removeHeader(index)}
|
||||
>
|
||||
<Trash2Icon className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Button variant="outline" size="sm" onClick={addHeader} className="w-fit">
|
||||
<PlusIcon className="h-4 w-4 mr-1" /> Add Header
|
||||
</Button>
|
||||
<KeyValueEditor
|
||||
items={customHeaders}
|
||||
onChange={setCustomHeaders}
|
||||
keyPlaceholder="Header name"
|
||||
valuePlaceholder="Header value"
|
||||
addButtonText="Add Header"
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,8 @@ export type FlowNodeData = {
|
|||
max_retries: number;
|
||||
retry_delay_seconds: number;
|
||||
};
|
||||
// Tools - array of tool UUIDs that can be invoked by this node
|
||||
tool_uuids?: string[];
|
||||
}
|
||||
|
||||
export type FlowNode = {
|
||||
|
|
|
|||
242
ui/src/components/http/create-credential-dialog.tsx
Normal file
242
ui/src/components/http/create-credential-dialog.tsx
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"use client";
|
||||
|
||||
import { AlertCircle, Loader2 } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
|
||||
import { createCredentialApiV1CredentialsPost } from "@/client";
|
||||
import { CredentialResponse, WebhookCredentialType } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface CreateCredentialDialogProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onCreated?: (credential: CredentialResponse) => void;
|
||||
}
|
||||
|
||||
interface CredentialField {
|
||||
key: string;
|
||||
label: string;
|
||||
placeholder: string;
|
||||
isSecret?: boolean;
|
||||
}
|
||||
|
||||
const getCredentialDataFields = (type: WebhookCredentialType): CredentialField[] => {
|
||||
switch (type) {
|
||||
case "api_key":
|
||||
return [
|
||||
{ key: "header_name", label: "Header Name", placeholder: "X-API-Key" },
|
||||
{ key: "api_key", label: "API Key", placeholder: "your-api-key", isSecret: true },
|
||||
];
|
||||
case "bearer_token":
|
||||
return [
|
||||
{ key: "token", label: "Token", placeholder: "your-bearer-token", isSecret: true },
|
||||
];
|
||||
case "basic_auth":
|
||||
return [
|
||||
{ key: "username", label: "Username", placeholder: "username" },
|
||||
{ key: "password", label: "Password", placeholder: "password", isSecret: true },
|
||||
];
|
||||
case "custom_header":
|
||||
return [
|
||||
{ key: "header_name", label: "Header Name", placeholder: "X-Custom-Header" },
|
||||
{ key: "header_value", label: "Header Value", placeholder: "header-value", isSecret: true },
|
||||
];
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
export function CreateCredentialDialog({
|
||||
open,
|
||||
onOpenChange,
|
||||
onCreated,
|
||||
}: CreateCredentialDialogProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
|
||||
const [name, setName] = useState("");
|
||||
const [description, setDescription] = useState("");
|
||||
const [credentialType, setCredentialType] = useState<WebhookCredentialType>("bearer_token");
|
||||
const [credentialData, setCredentialData] = useState<Record<string, string>>({});
|
||||
const [isCreating, setIsCreating] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const handleCreate = async () => {
|
||||
if (!name.trim()) return;
|
||||
|
||||
setIsCreating(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await createCredentialApiV1CredentialsPost({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
body: {
|
||||
name,
|
||||
description: description || undefined,
|
||||
credential_type: credentialType,
|
||||
credential_data: credentialData,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
const errorDetail = (response.error as { detail?: string })?.detail
|
||||
|| "Failed to create credential";
|
||||
setError(errorDetail);
|
||||
return;
|
||||
}
|
||||
|
||||
if (response.data) {
|
||||
onCreated?.(response.data);
|
||||
handleClose();
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to create credential:", err);
|
||||
setError(
|
||||
err instanceof Error ? err.message : "An unexpected error occurred"
|
||||
);
|
||||
} finally {
|
||||
setIsCreating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleClose = () => {
|
||||
onOpenChange(false);
|
||||
// Reset form
|
||||
setName("");
|
||||
setDescription("");
|
||||
setCredentialType("bearer_token");
|
||||
setCredentialData({});
|
||||
setError(null);
|
||||
};
|
||||
|
||||
const handleOpenChange = (newOpen: boolean) => {
|
||||
if (!newOpen) {
|
||||
setError(null);
|
||||
}
|
||||
onOpenChange(newOpen);
|
||||
};
|
||||
|
||||
const fields = getCredentialDataFields(credentialType);
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleOpenChange}>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add Credential</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create a new credential for authentication.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{error && (
|
||||
<div className="flex items-start gap-2 p-3 text-sm text-red-600 bg-red-50 border border-red-200 rounded-md">
|
||||
<AlertCircle className="h-4 w-4 mt-0.5 flex-shrink-0" />
|
||||
<span>{error}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="cred-name">Name *</Label>
|
||||
<Input
|
||||
id="cred-name"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="My API Key"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="cred-description">Description</Label>
|
||||
<Input
|
||||
id="cred-description"
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="Optional description"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<Label>Credential Type</Label>
|
||||
<Select
|
||||
value={credentialType}
|
||||
onValueChange={(v) => {
|
||||
setCredentialType(v as WebhookCredentialType);
|
||||
setCredentialData({});
|
||||
}}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="bearer_token">Bearer Token</SelectItem>
|
||||
<SelectItem value="api_key">API Key</SelectItem>
|
||||
<SelectItem value="basic_auth">Basic Auth</SelectItem>
|
||||
<SelectItem value="custom_header">Custom Header</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{fields.map((field) => (
|
||||
<div key={field.key} className="grid gap-2">
|
||||
<Label htmlFor={`cred-${field.key}`}>{field.label}</Label>
|
||||
<Input
|
||||
id={`cred-${field.key}`}
|
||||
type={field.isSecret ? "password" : "text"}
|
||||
value={credentialData[field.key] || ""}
|
||||
onChange={(e) =>
|
||||
setCredentialData((prev) => ({
|
||||
...prev,
|
||||
[field.key]: e.target.value,
|
||||
}))
|
||||
}
|
||||
placeholder={field.placeholder}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={handleClose}
|
||||
disabled={isCreating}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleCreate}
|
||||
disabled={!name.trim() || isCreating}
|
||||
>
|
||||
{isCreating ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||
Creating...
|
||||
</>
|
||||
) : (
|
||||
"Create"
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
140
ui/src/components/http/credential-selector.tsx
Normal file
140
ui/src/components/http/credential-selector.tsx
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"use client";
|
||||
|
||||
import { Loader2, PlusIcon } from "lucide-react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { listCredentialsApiV1CredentialsGet } from "@/client";
|
||||
import { CredentialResponse } from "@/client/types.gen";
|
||||
import { CreateCredentialDialog } from "@/components/http/create-credential-dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
interface CredentialSelectorProps {
|
||||
value: string;
|
||||
onChange: (uuid: string) => void;
|
||||
disabled?: boolean;
|
||||
placeholder?: string;
|
||||
label?: string;
|
||||
description?: string;
|
||||
showLabel?: boolean;
|
||||
}
|
||||
|
||||
export function CredentialSelector({
|
||||
value,
|
||||
onChange,
|
||||
disabled = false,
|
||||
placeholder = "No authentication",
|
||||
label = "Credential",
|
||||
description = "Select a credential for authentication, or leave empty for no auth.",
|
||||
showLabel = true,
|
||||
}: CredentialSelectorProps) {
|
||||
const { getAccessToken } = useAuth();
|
||||
|
||||
const [credentials, setCredentials] = useState<CredentialResponse[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [isAddDialogOpen, setIsAddDialogOpen] = useState(false);
|
||||
|
||||
const fetchCredentials = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const accessToken = await getAccessToken();
|
||||
const response = await listCredentialsApiV1CredentialsGet({
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
});
|
||||
if (response.error) {
|
||||
console.error("Failed to fetch credentials:", response.error);
|
||||
setCredentials([]);
|
||||
return;
|
||||
}
|
||||
if (response.data) {
|
||||
setCredentials(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch credentials:", error);
|
||||
setCredentials([]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [getAccessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchCredentials();
|
||||
}, [fetchCredentials]);
|
||||
|
||||
const handleCredentialCreated = async (credential: CredentialResponse) => {
|
||||
await fetchCredentials();
|
||||
onChange(credential.uuid);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
{showLabel && (
|
||||
<>
|
||||
<Label>{label}</Label>
|
||||
{description && (
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
{description}
|
||||
</Label>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div className="flex gap-2">
|
||||
<Select
|
||||
value={value || "none"}
|
||||
onValueChange={(v) => onChange(v === "none" ? "" : v)}
|
||||
disabled={disabled || loading}
|
||||
>
|
||||
<SelectTrigger className="flex-1">
|
||||
{loading ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span>Loading...</span>
|
||||
</div>
|
||||
) : (
|
||||
<SelectValue placeholder={placeholder} />
|
||||
)}
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="none">{placeholder}</SelectItem>
|
||||
{credentials.map((cred) => (
|
||||
<SelectItem key={cred.uuid} value={cred.uuid}>
|
||||
{cred.name} ({cred.credential_type})
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => setIsAddDialogOpen(true)}
|
||||
title="Add new credential"
|
||||
disabled={disabled}
|
||||
>
|
||||
<PlusIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{credentials.length === 0 && !loading && (
|
||||
<div className="p-3 border rounded-md bg-muted/20">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
No credentials found. Click the + button to create one.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<CreateCredentialDialog
|
||||
open={isAddDialogOpen}
|
||||
onOpenChange={setIsAddDialogOpen}
|
||||
onCreated={handleCredentialCreated}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
44
ui/src/components/http/http-method-selector.tsx
Normal file
44
ui/src/components/http/http-method-selector.tsx
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
|
||||
export type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
|
||||
|
||||
interface HttpMethodSelectorProps {
|
||||
value: HttpMethod;
|
||||
onChange: (method: HttpMethod) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const HTTP_METHODS: HttpMethod[] = ["GET", "POST", "PUT", "PATCH", "DELETE"];
|
||||
|
||||
export function HttpMethodSelector({
|
||||
value,
|
||||
onChange,
|
||||
disabled = false,
|
||||
}: HttpMethodSelectorProps) {
|
||||
return (
|
||||
<Select
|
||||
value={value}
|
||||
onValueChange={(v) => onChange(v as HttpMethod)}
|
||||
disabled={disabled}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{HTTP_METHODS.map((method) => (
|
||||
<SelectItem key={method} value={method}>
|
||||
{method}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
5
ui/src/components/http/index.ts
Normal file
5
ui/src/components/http/index.ts
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
export { CreateCredentialDialog } from "./create-credential-dialog";
|
||||
export { CredentialSelector } from "./credential-selector";
|
||||
export { type HttpMethod, HttpMethodSelector } from "./http-method-selector";
|
||||
export { KeyValueEditor, type KeyValueItem } from "./key-value-editor";
|
||||
export { ParameterEditor, type ParameterType,type ToolParameter } from "./parameter-editor";
|
||||
85
ui/src/components/http/key-value-editor.tsx
Normal file
85
ui/src/components/http/key-value-editor.tsx
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"use client";
|
||||
|
||||
import { PlusIcon, Trash2Icon } from "lucide-react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
||||
export interface KeyValueItem {
|
||||
key: string;
|
||||
value: string;
|
||||
}
|
||||
|
||||
interface KeyValueEditorProps {
|
||||
items: KeyValueItem[];
|
||||
onChange: (items: KeyValueItem[]) => void;
|
||||
keyPlaceholder?: string;
|
||||
valuePlaceholder?: string;
|
||||
addButtonText?: string;
|
||||
emptyMessage?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function KeyValueEditor({
|
||||
items,
|
||||
onChange,
|
||||
keyPlaceholder = "Key",
|
||||
valuePlaceholder = "Value",
|
||||
addButtonText = "Add",
|
||||
disabled = false,
|
||||
}: KeyValueEditorProps) {
|
||||
const addItem = () => {
|
||||
onChange([...items, { key: "", value: "" }]);
|
||||
};
|
||||
|
||||
const updateItem = (index: number, field: "key" | "value", value: string) => {
|
||||
const newItems = [...items];
|
||||
newItems[index] = { ...newItems[index], [field]: value };
|
||||
onChange(newItems);
|
||||
};
|
||||
|
||||
const removeItem = (index: number) => {
|
||||
onChange(items.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{items.map((item, index) => (
|
||||
<div key={index} className="flex items-center gap-2">
|
||||
<Input
|
||||
placeholder={keyPlaceholder}
|
||||
value={item.key}
|
||||
onChange={(e) => updateItem(index, "key", e.target.value)}
|
||||
className="flex-1"
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Input
|
||||
placeholder={valuePlaceholder}
|
||||
value={item.value}
|
||||
onChange={(e) => updateItem(index, "value", e.target.value)}
|
||||
className="flex-1"
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => removeItem(index)}
|
||||
disabled={disabled}
|
||||
>
|
||||
<Trash2Icon className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={addItem}
|
||||
className="w-fit"
|
||||
disabled={disabled}
|
||||
>
|
||||
<PlusIcon className="h-4 w-4 mr-1" /> {addButtonText}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
167
ui/src/components/http/parameter-editor.tsx
Normal file
167
ui/src/components/http/parameter-editor.tsx
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"use client";
|
||||
|
||||
import { PlusIcon, Trash2Icon } from "lucide-react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
|
||||
export type ParameterType = "string" | "number" | "boolean";
|
||||
|
||||
export interface ToolParameter {
|
||||
name: string;
|
||||
type: ParameterType;
|
||||
description: string;
|
||||
required: boolean;
|
||||
}
|
||||
|
||||
interface ParameterEditorProps {
|
||||
parameters: ToolParameter[];
|
||||
onChange: (parameters: ToolParameter[]) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function ParameterEditor({
|
||||
parameters,
|
||||
onChange,
|
||||
disabled = false,
|
||||
}: ParameterEditorProps) {
|
||||
const addParameter = () => {
|
||||
onChange([
|
||||
...parameters,
|
||||
{ name: "", type: "string", description: "", required: true },
|
||||
]);
|
||||
};
|
||||
|
||||
const updateParameter = (
|
||||
index: number,
|
||||
field: keyof ToolParameter,
|
||||
value: string | boolean
|
||||
) => {
|
||||
const newParams = [...parameters];
|
||||
newParams[index] = { ...newParams[index], [field]: value };
|
||||
onChange(newParams);
|
||||
};
|
||||
|
||||
const removeParameter = (index: number) => {
|
||||
onChange(parameters.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{parameters.length === 0 && (
|
||||
<div className="text-sm text-muted-foreground py-4 text-center border border-dashed rounded-md">
|
||||
No parameters defined. Add a parameter to specify what data this tool needs.
|
||||
</div>
|
||||
)}
|
||||
|
||||
{parameters.map((param, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="border rounded-lg p-4 space-y-3 bg-muted/20"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm font-medium text-muted-foreground">
|
||||
Parameter {index + 1}
|
||||
</span>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => removeParameter(index)}
|
||||
disabled={disabled}
|
||||
className="h-8 w-8"
|
||||
>
|
||||
<Trash2Icon className="h-4 w-4 text-muted-foreground hover:text-destructive" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
<div className="space-y-1.5">
|
||||
<Label className="text-xs">Name</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Name of the parameter, like "order_id" or "customer_name"
|
||||
</Label>
|
||||
<Input
|
||||
placeholder="e.g., customer_name"
|
||||
value={param.name}
|
||||
onChange={(e) =>
|
||||
updateParameter(index, "name", e.target.value)
|
||||
}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-1.5">
|
||||
<Label className="text-xs">Type</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Type of the parameter, like "string" or "number" or "boolean"
|
||||
</Label>
|
||||
<Select
|
||||
value={param.type}
|
||||
onValueChange={(value: ParameterType) =>
|
||||
updateParameter(index, "type", value)
|
||||
}
|
||||
disabled={disabled}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select type" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="string">String</SelectItem>
|
||||
<SelectItem value="number">Number</SelectItem>
|
||||
<SelectItem value="boolean">Boolean</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-1.5">
|
||||
<Label className="text-xs">Description</Label>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Description of the parameter, which makes it easy for LLM to understand, like "The ID of the Customer to fetch Order Details"
|
||||
</Label>
|
||||
<Input
|
||||
placeholder="Describe what this parameter is for..."
|
||||
value={param.description}
|
||||
onChange={(e) =>
|
||||
updateParameter(index, "description", e.target.value)
|
||||
}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
<Switch
|
||||
id={`required-${index}`}
|
||||
checked={param.required}
|
||||
onCheckedChange={(checked) =>
|
||||
updateParameter(index, "required", checked)
|
||||
}
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Label htmlFor={`required-${index}`} className="text-sm">
|
||||
Required
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={addParameter}
|
||||
className="w-fit"
|
||||
disabled={disabled}
|
||||
>
|
||||
<PlusIcon className="h-4 w-4 mr-1" /> Add Parameter
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ import {
|
|||
Star,
|
||||
TrendingUp,
|
||||
Workflow,
|
||||
Wrench,
|
||||
Zap,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
|
|
@ -108,6 +109,11 @@ export function AppSidebar() {
|
|||
url: "/telephony-configurations",
|
||||
icon: Phone,
|
||||
},
|
||||
{
|
||||
title: "Tools",
|
||||
url: "/tools",
|
||||
icon: Wrench,
|
||||
},
|
||||
// {
|
||||
// title: "Integrations",
|
||||
// url: "/integrations",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue