From 3e55af9256ecbde4b92a547fc9bca8a8c5e57f15 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Fri, 2 Jan 2026 13:11:02 +0530 Subject: [PATCH] feat: user defined custom tools as part of workflow execution (#94) * feat: add custom tools functionality * Show tools in nodes * integrate tool calling with pipeline engine --- .../versions/ebc80cea7965_add_tools_model.py | 92 ++ api/conftest.py | 395 +++++-- api/db/db_client.py | 3 + api/db/models.py | 84 ++ api/db/tool_client.py | 276 +++++ api/enums.py | 18 + api/pytest.ini | 7 +- api/requirements.dev.txt | 3 +- api/routes/main.py | 2 + api/routes/tool.py | 336 ++++++ api/services/workflow/dto.py | 1 + api/services/workflow/pipecat_engine.py | 29 +- .../workflow/pipecat_engine_custom_tools.py | 189 +++ api/services/workflow/tools/custom_tool.py | 180 +++ api/services/workflow/workflow.py | 1 + api/tasks/run_integrations.py | 42 +- .../test_assistant_context_aggregator.py | 138 --- api/tests/test_audio_transcript_buffers.py | 120 -- api/tests/test_concurrent_call_limiting.py | 330 ------ api/tests/test_configuration_masking_merge.py | 78 -- api/tests/test_custom_tools.py | 1041 +++++++++++++++++ .../test_custom_tools_context_integration.py | 512 ++++++++ api/tests/test_default_user_configuration.py | 33 - api/tests/test_disposition_mapper.py | 122 -- .../test_event_handler_disposition_mapping.py | 370 ------ api/tests/test_event_handlers_refactor.py | 184 --- api/tests/test_filters.py | 162 --- api/tests/test_global_prompt.py | 249 ---- api/tests/test_global_prompt_unit.py | 175 --- api/tests/test_leave_counter.py | 248 ---- api/tests/test_llm_response_reorder.py | 99 -- api/tests/test_looptalk_routes.py | 506 -------- api/tests/test_mock_llm_service.py | 142 --- api/tests/test_pipecat_disposition_mapping.py | 236 ---- api/tests/test_pipecat_engine.py | 206 ---- api/tests/test_provider_switching.py | 295 ----- api/tests/test_run_integrations_db_client.py | 266 ----- api/tests/test_run_integrations_template.py | 330 ------ api/tests/test_s3_signed_url.py | 117 -- api/tests/test_s3_upload_tasks.py | 129 -- api/tests/test_template_renderer.py | 89 -- api/tests/test_usage_concurrency.py | 152 --- api/tests/test_variable_extraction.py | 140 --- api/tests/test_voicemail_detection_rtc.py | 547 --------- api/tests/test_workflow_routes.py | 667 ----------- api/utils/credential_auth.py | 95 ++ ui/src/app/tools/[toolUuid]/page.tsx | 498 ++++++++ ui/src/app/tools/page.tsx | 431 +++++++ .../workflow/[workflowId]/RenderWorkflow.tsx | 2 +- .../[workflowId]/run/[runId]/page.tsx | 2 +- ui/src/client/sdk.gen.ts | 97 +- ui/src/client/types.gen.ts | 287 +++++ ui/src/components/flow/ToolBadges.tsx | 64 + ui/src/components/flow/ToolSelector.tsx | 161 +++ ui/src/components/flow/nodes/AgentNode.tsx | 34 +- ui/src/components/flow/nodes/StartCall.tsx | 36 +- ui/src/components/flow/nodes/WebhookNode.tsx | 417 +------ ui/src/components/flow/types.ts | 2 + .../http/create-credential-dialog.tsx | 242 ++++ .../components/http/credential-selector.tsx | 140 +++ .../components/http/http-method-selector.tsx | 44 + ui/src/components/http/index.ts | 5 + ui/src/components/http/key-value-editor.tsx | 85 ++ ui/src/components/http/parameter-editor.tsx | 167 +++ ui/src/components/layout/AppSidebar.tsx | 6 + 65 files changed, 5483 insertions(+), 6673 deletions(-) create mode 100644 api/alembic/versions/ebc80cea7965_add_tools_model.py create mode 100644 api/db/tool_client.py create mode 100644 api/routes/tool.py create mode 100644 api/services/workflow/pipecat_engine_custom_tools.py create mode 100644 api/services/workflow/tools/custom_tool.py delete mode 100644 api/tests/test_assistant_context_aggregator.py delete mode 100644 api/tests/test_audio_transcript_buffers.py delete mode 100644 api/tests/test_concurrent_call_limiting.py delete mode 100644 api/tests/test_configuration_masking_merge.py create mode 100644 api/tests/test_custom_tools.py create mode 100644 api/tests/test_custom_tools_context_integration.py delete mode 100644 api/tests/test_default_user_configuration.py delete mode 100644 api/tests/test_disposition_mapper.py delete mode 100644 api/tests/test_event_handler_disposition_mapping.py delete mode 100644 api/tests/test_event_handlers_refactor.py delete mode 100644 api/tests/test_filters.py delete mode 100644 api/tests/test_global_prompt.py delete mode 100644 api/tests/test_global_prompt_unit.py delete mode 100644 api/tests/test_leave_counter.py delete mode 100644 api/tests/test_llm_response_reorder.py delete mode 100644 api/tests/test_looptalk_routes.py delete mode 100644 api/tests/test_mock_llm_service.py delete mode 100644 api/tests/test_pipecat_disposition_mapping.py delete mode 100644 api/tests/test_pipecat_engine.py delete mode 100644 api/tests/test_provider_switching.py delete mode 100644 api/tests/test_run_integrations_db_client.py delete mode 100644 api/tests/test_run_integrations_template.py delete mode 100644 api/tests/test_s3_signed_url.py delete mode 100644 api/tests/test_s3_upload_tasks.py delete mode 100644 api/tests/test_template_renderer.py delete mode 100644 api/tests/test_usage_concurrency.py delete mode 100644 api/tests/test_variable_extraction.py delete mode 100644 api/tests/test_voicemail_detection_rtc.py delete mode 100644 api/tests/test_workflow_routes.py create mode 100644 api/utils/credential_auth.py create mode 100644 ui/src/app/tools/[toolUuid]/page.tsx create mode 100644 ui/src/app/tools/page.tsx create mode 100644 ui/src/components/flow/ToolBadges.tsx create mode 100644 ui/src/components/flow/ToolSelector.tsx create mode 100644 ui/src/components/http/create-credential-dialog.tsx create mode 100644 ui/src/components/http/credential-selector.tsx create mode 100644 ui/src/components/http/http-method-selector.tsx create mode 100644 ui/src/components/http/index.ts create mode 100644 ui/src/components/http/key-value-editor.tsx create mode 100644 ui/src/components/http/parameter-editor.tsx diff --git a/api/alembic/versions/ebc80cea7965_add_tools_model.py b/api/alembic/versions/ebc80cea7965_add_tools_model.py new file mode 100644 index 0000000..22f34e8 --- /dev/null +++ b/api/alembic/versions/ebc80cea7965_add_tools_model.py @@ -0,0 +1,92 @@ +"""add tools model + +Revision ID: ebc80cea7965 +Revises: 36b5dbf670e4 +Create Date: 2026-01-01 10:13:50.807135 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "ebc80cea7965" +down_revision: Union[str, None] = "36b5dbf670e4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum("active", "archived", "draft", name="tool_status").create(op.get_bind()) + sa.Enum("http_api", "native", "integration", name="tool_category").create( + op.get_bind() + ) + op.create_table( + "tools", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("tool_uuid", sa.String(length=36), nullable=False), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "category", + postgresql.ENUM( + "http_api", + "native", + "integration", + name="tool_category", + create_type=False, + ), + nullable=False, + ), + sa.Column("icon", sa.String(length=50), nullable=True), + sa.Column("icon_color", sa.String(length=7), nullable=True), + sa.Column( + "status", + postgresql.ENUM( + "active", "archived", "draft", name="tool_status", create_type=False + ), + server_default=sa.text("'active'::tool_status"), + nullable=False, + ), + sa.Column("definition", sa.JSON(), nullable=False), + sa.Column("created_by", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["created_by"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("organization_id", "name", name="unique_org_tool_name"), + ) + op.create_index("ix_tools_category", "tools", ["category"], unique=False) + op.create_index( + "ix_tools_organization_id", "tools", ["organization_id"], unique=False + ) + op.create_index("ix_tools_status", "tools", ["status"], unique=False) + op.create_index(op.f("ix_tools_tool_uuid"), "tools", ["tool_uuid"], unique=True) + op.create_index("ix_tools_uuid", "tools", ["tool_uuid"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_tools_uuid", table_name="tools") + op.drop_index(op.f("ix_tools_tool_uuid"), table_name="tools") + op.drop_index("ix_tools_status", table_name="tools") + op.drop_index("ix_tools_organization_id", table_name="tools") + op.drop_index("ix_tools_category", table_name="tools") + op.drop_table("tools") + sa.Enum("http_api", "native", "integration", name="tool_category").drop( + op.get_bind() + ) + sa.Enum("active", "archived", "draft", name="tool_status").drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/api/conftest.py b/api/conftest.py index 88d8caf..6fcfa2c 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -1,143 +1,315 @@ """ -Shared pytest fixtures for the API tests. -This file contains database setup, test client configuration, and utility fixtures -that can be reused across all test files. +Pytest configuration and fixtures for async database testing. + +This module sets up the test infrastructure using: +- A separate test database (appends _test to the database name) +- Alembic migrations run once per test session +- Transaction-based isolation for each test (savepoint pattern) + +References: +- https://www.core27.co/post/transactional-unit-tests-with-pytest-and-async-sqlalchemy +- https://docs.sqlalchemy.org/en/20/orm/session_transaction.html """ import os -import subprocess -import uuid +from pathlib import Path +from typing import AsyncGenerator +from urllib.parse import urlparse, urlunparse -import pytest_asyncio -from httpx import ASGITransport, AsyncClient -from loguru import logger -from sqlalchemy import text -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +# Load environment variables before importing anything else +from dotenv import load_dotenv -from api.app import app -from api.db import db_client +# Load .env.test from api directory for test configuration +env_path = Path(__file__).parent / ".env.test" +load_dotenv(env_path) -# Test database setup globals -TEST_DATABASE_NAME = None -TEST_DATABASE_URL = None +import pytest +from sqlalchemy import event, text +from sqlalchemy.ext.asyncio import ( + AsyncConnection, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import SessionTransaction +from sqlalchemy.pool import NullPool -@pytest_asyncio.fixture -async def test_database(): +def get_test_database_url() -> str: """ - Set up a temporary PostgreSQL database for testing. - This fixture creates a unique test database, runs migrations, and cleans up afterward. + Get the test database URL by appending _test to the database name. + + Example: + postgresql+asyncpg://user:pass@host/mydb + -> postgresql+asyncpg://user:pass@host/mydb_test """ - global TEST_DATABASE_NAME, TEST_DATABASE_URL + original_url = os.environ.get("DATABASE_URL") + if not original_url: + raise ValueError("DATABASE_URL environment variable is not set") - # Generate a unique test database name - TEST_DATABASE_NAME = f"test_dograh_{uuid.uuid4().hex[:8]}" + parsed = urlparse(original_url) + # Append _test to the database name (path without leading slash) + original_db_name = parsed.path.lstrip("/") + test_db_name = f"{original_db_name}_test" - # Get the base DATABASE_URL and parse it - base_url = os.environ.get("DATABASE_URL") - # Extract connection parts and replace database name - url_parts = base_url.split("/") - base_connection = "/".join(url_parts[:-1]) - TEST_DATABASE_URL = f"{base_connection}/{TEST_DATABASE_NAME}" + # Reconstruct the URL with the new database name + test_url = urlunparse( + ( + parsed.scheme, + parsed.netloc, + f"/{test_db_name}", + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + return test_url - # Create a connection to the default postgres database to create our test database - default_engine = create_async_engine(base_url) + +def get_base_database_url() -> str: + """ + Get base database URL (postgres) for creating/dropping test database. + """ + original_url = os.environ.get("DATABASE_URL") + parsed = urlparse(original_url) + # Connect to 'postgres' database for admin operations + base_url = urlunparse( + ( + parsed.scheme, + parsed.netloc, + "/postgres", + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + return base_url + + +def get_test_db_name() -> str: + """Extract the test database name.""" + original_url = os.environ.get("DATABASE_URL") + parsed = urlparse(original_url) + original_db_name = parsed.path.lstrip("/") + return f"{original_db_name}_test" + + +@pytest.fixture(scope="session") +async def setup_test_database(): + """ + Session-scoped fixture that creates the test database and runs migrations. + + This runs once at the start of the test session. + """ + test_db_name = get_test_db_name() + base_url = get_base_database_url() + test_url = get_test_database_url() + + # Create engine to connect to postgres database (for admin operations) + admin_engine = create_async_engine( + base_url, + poolclass=NullPool, + isolation_level="AUTOCOMMIT", # Required for CREATE DATABASE + ) + + # Create test database if it doesn't exist + async with admin_engine.connect() as conn: + # Check if database exists + result = await conn.execute( + text("SELECT 1 FROM pg_database WHERE datname = :dbname"), + {"dbname": test_db_name}, + ) + exists = result.scalar() is not None + + if not exists: + print(f"\n Creating test database: {test_db_name}") + # Use template0 to avoid collation version mismatch issues + await conn.execute( + text(f'CREATE DATABASE "{test_db_name}" TEMPLATE template0') + ) + else: + print(f"\n Using existing test database: {test_db_name}") + + await admin_engine.dispose() + + # Run alembic migrations on the test database + print(f" Running migrations on {test_db_name}...") + await run_migrations(test_url) + print(" Migrations complete!") + + yield test_url + + # Cleanup: Optionally drop the test database after tests + # Commented out to allow inspection of test data after failures + # async with admin_engine.connect() as conn: + # await conn.execute(text(f'DROP DATABASE IF EXISTS "{test_db_name}"')) + + +async def run_migrations(database_url: str): + """ + Run alembic migrations programmatically on the given database. + """ + from alembic import command + from alembic.config import Config + + # Get alembic.ini path + alembic_ini_path = Path(__file__).parent / "alembic.ini" + + # Create alembic config + alembic_cfg = Config(str(alembic_ini_path)) + + # Override the database URL - need to patch both os.environ AND api.constants + # because api.constants.DATABASE_URL is cached at import time + original_env_url = os.environ.get("DATABASE_URL") + os.environ["DATABASE_URL"] = database_url + alembic_cfg.set_main_option("sqlalchemy.url", database_url) + + # Also patch the cached value in api.constants + import api.constants + + original_constants_url = api.constants.DATABASE_URL + + api.constants.DATABASE_URL = database_url + + # Run migrations in a thread to avoid blocking the event loop + import asyncio + + def _run_upgrade(): + command.upgrade(alembic_cfg, "head") try: - # Create the test database - async with default_engine.connect() as conn: - # Use autocommit mode to create database - await conn.execute(text("COMMIT")) - await conn.execute(text(f"CREATE DATABASE {TEST_DATABASE_NAME}")) - - await default_engine.dispose() - - # Run migrations on the test database - env = os.environ.copy() - env["DATABASE_URL"] = TEST_DATABASE_URL - # Add the parent directory to PYTHONPATH so alembic can find the api module - parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - env["PYTHONPATH"] = parent_dir + ":" + env.get("PYTHONPATH", "") - - # Run alembic upgrade to create all tables - result = subprocess.run( - [ - "conda", - "run", - "-n", - "dograh", - "python", - "-m", - "alembic", - "-c", - "alembic.ini", - "upgrade", - "head", - ], - env=env, - capture_output=True, - text=True, - ) - - if result.returncode != 0: - logger.error(f"Alembic stderr: {result.stderr}") - logger.error(f"Alembic stdout: {result.stdout}") - raise RuntimeError(f"Alembic migration failed: {result.stderr}") - - logger.info(f"Created test database: {TEST_DATABASE_NAME}") - yield TEST_DATABASE_URL - + await asyncio.get_event_loop().run_in_executor(None, _run_upgrade) finally: - # Cleanup: Drop the test database - cleanup_engine = create_async_engine(base_url) - try: - async with cleanup_engine.connect() as conn: - # Terminate any connections to the test database - await conn.execute(text("COMMIT")) - await conn.execute( - text(f""" - SELECT pg_terminate_backend(pid) - FROM pg_stat_activity - WHERE datname = '{TEST_DATABASE_NAME}' AND pid <> pg_backend_pid() - """) - ) - await conn.execute( - text(f"DROP DATABASE IF EXISTS {TEST_DATABASE_NAME}") - ) - logger.info(f"Cleaned up test database: {TEST_DATABASE_NAME}") - except Exception as e: - logger.error( - f"Warning: Could not clean up test database {TEST_DATABASE_NAME}: {e}" - ) - finally: - await cleanup_engine.dispose() + # Restore original DATABASE_URL + if original_env_url: + os.environ["DATABASE_URL"] = original_env_url + api.constants.DATABASE_URL = original_constants_url -@pytest_asyncio.fixture -async def db_session(test_database): +@pytest.fixture(scope="session") +async def test_engine(setup_test_database): """ - Create a test database client that uses the temporary database. - This fixture replaces the global db_client with a test version. + Create a test database engine (session-scoped). + + Uses NullPool to avoid connection issues with async tests. """ + test_url = setup_test_database + engine = create_async_engine( + test_url, + poolclass=NullPool, + echo=False, # Set to True for SQL debugging + ) + yield engine + await engine.dispose() + + +@pytest.fixture(scope="function") +async def db_connection(test_engine) -> AsyncGenerator[AsyncConnection, None]: + """ + Create a database connection for each test. + + This connection wraps all operations in a transaction that + will be rolled back at the end of the test. + """ + async with test_engine.connect() as connection: + yield connection + + +@pytest.fixture(scope="function") +async def async_session( + db_connection: AsyncConnection, +) -> AsyncGenerator[AsyncSession, None]: + """ + Create a database session with transaction isolation for each test. + + This fixture: + 1. Begins a transaction on the connection + 2. Creates a savepoint (nested transaction) + 3. Yields the session for test use + 4. Rolls back all changes after the test + + Tests can call session.commit() and it will only commit to the savepoint, + not to the actual database. The outer transaction rollback ensures + complete isolation between tests. + """ + # Begin outer transaction + trans = await db_connection.begin() + + # Create session bound to this connection + async_session_maker = async_sessionmaker( + bind=db_connection, + expire_on_commit=False, + autoflush=False, + ) + + async with async_session_maker() as session: + # Begin a nested transaction (savepoint) + nested = await session.begin_nested() + + # Set up event listener to restart savepoint after commits + @event.listens_for(session.sync_session, "after_transaction_end") + def reopen_nested_transaction(session_sync, transaction: SessionTransaction): + nonlocal nested + if not nested.is_active: + nested = session.sync_session.begin_nested() + + yield session + + # Rollback everything + await trans.rollback() + + +class _TestSessionContext: + """ + Context manager wrapper for test session. + + Mimics the behavior of async_sessionmaker() context manager + but uses the existing test session without closing it. + """ + + def __init__(self, session: AsyncSession): + self._session = session + + async def __aenter__(self) -> AsyncSession: + return self._session + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + await self._session.flush() + return False + + +@pytest.fixture(scope="function") +async def db_session(async_session: AsyncSession): + """ + Create a DBClient instance that uses the test session. + + This patches the DBClient's async_session to use our test session, + ensuring all database operations go through the transactional test session. + + Note: This fixture yields a DBClient (not a raw session) for backward + compatibility with existing tests that call db_session.get_or_create_user_by_provider_id(), etc. + """ + from api.db import db_client + + def test_session_maker(): + return _TestSessionContext(async_session) + + # Store originals original_engine = db_client.engine - original_session = db_client.async_session + original_async_session = db_client.async_session - # Replace the database client's engine and session with test ones - test_engine = create_async_engine(test_database) - test_session_maker = async_sessionmaker(bind=test_engine) - - db_client.engine = test_engine + # Patch the db_client to use our test session db_client.async_session = test_session_maker yield db_client - # Restore original database client - await test_engine.dispose() + # Restore originals db_client.engine = original_engine - db_client.async_session = original_session + db_client.async_session = original_async_session -@pytest_asyncio.fixture +@pytest.fixture async def test_client_factory(db_session): """ Factory fixture that creates test clients for specific users. @@ -155,6 +327,9 @@ async def test_client_factory(db_session): """ from contextlib import asynccontextmanager + from httpx import ASGITransport, AsyncClient + + from api.app import app from api.services.auth.depends import get_user @asynccontextmanager diff --git a/api/db/db_client.py b/api/db/db_client.py index 793abc9..6a6c713 100644 --- a/api/db/db_client.py +++ b/api/db/db_client.py @@ -8,6 +8,7 @@ from api.db.organization_client import OrganizationClient from api.db.organization_configuration_client import OrganizationConfigurationClient from api.db.organization_usage_client import OrganizationUsageClient from api.db.reports_client import ReportsClient +from api.db.tool_client import ToolClient from api.db.user_client import UserClient from api.db.webhook_credential_client import WebhookCredentialClient from api.db.workflow_client import WorkflowClient @@ -31,6 +32,7 @@ class DBClient( EmbedTokenClient, AgentTriggerClient, WebhookCredentialClient, + ToolClient, ): """ Unified database client that combines all specialized database operations. @@ -51,6 +53,7 @@ class DBClient( - EmbedTokenClient: handles embed token and session operations - AgentTriggerClient: handles agent trigger operations for API-based call triggering - WebhookCredentialClient: handles webhook credential operations + - ToolClient: handles tool operations for reusable HTTP API tools """ pass diff --git a/api/db/models.py b/api/db/models.py index 29f4e65..5bbce61 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -22,6 +22,8 @@ from sqlalchemy.orm import declarative_base, relationship from ..enums import ( IntegrationAction, + ToolCategory, + ToolStatus, TriggerState, WebhookCredentialType, WorkflowRunMode, @@ -800,3 +802,85 @@ class ExternalCredentialModel(Base): Index("ix_webhook_credentials_uuid", "credential_uuid"), UniqueConstraint("organization_id", "name", name="unique_org_credential_name"), ) + + +class ToolModel(Base): + """Model for storing reusable tools that can be invoked during workflows. + + Tools provide a standardized way to integrate external functionality - from + HTTP API calls to native integrations. + """ + + __tablename__ = "tools" + + id = Column(Integer, primary_key=True, index=True) + + # Public identifier (used in APIs and workflow references) + tool_uuid = Column( + String(36), + unique=True, + nullable=False, + index=True, + default=lambda: str(uuid.uuid4()), + ) + + # Organization scoping + organization_id = Column( + Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False + ) + + # Tool metadata + name = Column(String(255), nullable=False) + description = Column(String, nullable=True) + + # Tool category - uses enum from api/enums.py + category = Column( + Enum( + *[c.value for c in ToolCategory], + name="tool_category", + ), + nullable=False, + default=ToolCategory.HTTP_API.value, + ) + + # Icon configuration (for UI display) + icon = Column(String(50), nullable=True) # Icon identifier + icon_color = Column(String(7), nullable=True) # Hex color code + + # Status management + status = Column( + Enum( + *[s.value for s in ToolStatus], + name="tool_status", + ), + nullable=False, + default=ToolStatus.ACTIVE.value, + server_default=text("'active'::tool_status"), + ) + + # The tool definition (JSONB) - contains schema_version for compatibility + # Structure depends on category: + # - http_api: {"schema_version": 1, "type": "http_api", "config": {...}} + definition = Column(JSON, nullable=False, default=dict) + + # Audit fields + created_by = Column(Integer, ForeignKey("users.id"), nullable=False) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + ) + + # Relationships + organization = relationship("OrganizationModel") + created_by_user = relationship("UserModel") + + # Indexes and constraints + __table_args__ = ( + Index("ix_tools_organization_id", "organization_id"), + Index("ix_tools_uuid", "tool_uuid"), + Index("ix_tools_status", "status"), + Index("ix_tools_category", "category"), + UniqueConstraint("organization_id", "name", name="unique_org_tool_name"), + ) diff --git a/api/db/tool_client.py b/api/db/tool_client.py new file mode 100644 index 0000000..6c96c78 --- /dev/null +++ b/api/db/tool_client.py @@ -0,0 +1,276 @@ +"""Database client for managing tools.""" + +from datetime import UTC, datetime +from typing import List, Optional + +from loguru import logger +from sqlalchemy import select, update +from sqlalchemy.orm import selectinload + +from api.db.base_client import BaseDBClient +from api.db.models import ToolModel +from api.enums import ToolCategory, ToolStatus + + +class ToolClient(BaseDBClient): + """Client for managing tools (organization-scoped, UUID-referenced).""" + + async def create_tool( + self, + organization_id: int, + user_id: int, + name: str, + definition: dict, + category: str = ToolCategory.HTTP_API.value, + description: Optional[str] = None, + icon: Optional[str] = None, + icon_color: Optional[str] = None, + ) -> ToolModel: + """Create a new tool. + + Args: + organization_id: ID of the organization + user_id: ID of the user creating the tool + name: Display name for the tool + definition: JSON definition of the tool + category: Tool category (http_api, native, integration) + description: Optional description + icon: Optional icon identifier + icon_color: Optional hex color code + + Returns: + The created ToolModel with auto-generated UUID + """ + async with self.async_session() as session: + tool = ToolModel( + organization_id=organization_id, + created_by=user_id, + name=name, + description=description, + category=category, + icon=icon, + icon_color=icon_color, + definition=definition, + status=ToolStatus.ACTIVE.value, + ) + + session.add(tool) + await session.commit() + await session.refresh(tool) + + logger.info( + f"Created tool '{name}' ({tool.tool_uuid}) " + f"for organization {organization_id}" + ) + return tool + + async def get_tools_for_organization( + self, + organization_id: int, + status: Optional[str] = None, + category: Optional[str] = None, + ) -> List[ToolModel]: + """Get all tools for an organization. + + Args: + organization_id: ID of the organization + status: Optional filter by status (active, archived, draft) + category: Optional filter by category (http_api, native, integration) + + Returns: + List of ToolModel instances + """ + async with self.async_session() as session: + query = select(ToolModel).where( + ToolModel.organization_id == organization_id + ) + + if status: + query = query.where(ToolModel.status == status) + else: + # By default, exclude archived tools + query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value) + + if category: + query = query.where(ToolModel.category == category) + + query = query.order_by(ToolModel.name) + + result = await session.execute(query) + return list(result.scalars().all()) + + async def get_tool_by_uuid( + self, + tool_uuid: str, + organization_id: int, + include_archived: bool = False, + ) -> Optional[ToolModel]: + """Get a tool by its UUID, scoped to organization. + + Args: + tool_uuid: The unique tool UUID + organization_id: ID of the organization (for authorization) + include_archived: If True, include archived tools + + Returns: + ToolModel if found and authorized, None otherwise + """ + async with self.async_session() as session: + query = ( + select(ToolModel) + .where( + ToolModel.tool_uuid == tool_uuid, + ToolModel.organization_id == organization_id, + ) + .options(selectinload(ToolModel.created_by_user)) + ) + + if not include_archived: + query = query.where(ToolModel.status != ToolStatus.ARCHIVED.value) + + result = await session.execute(query) + return result.scalar_one_or_none() + + async def update_tool( + self, + tool_uuid: str, + organization_id: int, + name: Optional[str] = None, + description: Optional[str] = None, + definition: Optional[dict] = None, + icon: Optional[str] = None, + icon_color: Optional[str] = None, + status: Optional[str] = None, + ) -> Optional[ToolModel]: + """Update a tool by UUID. + + Args: + tool_uuid: The unique tool UUID + organization_id: ID of the organization (for authorization) + name: New name (if provided) + description: New description (if provided) + definition: New definition (if provided) + icon: New icon (if provided) + icon_color: New icon color (if provided) + status: New status (if provided) + + Returns: + Updated ToolModel if found, None otherwise + """ + async with self.async_session() as session: + # First check if tool exists and belongs to organization + tool = await self.get_tool_by_uuid( + tool_uuid, organization_id, include_archived=True + ) + if not tool: + return None + + # Build update values + update_values = {"updated_at": datetime.now(UTC)} + if name is not None: + update_values["name"] = name + if description is not None: + update_values["description"] = description + if definition is not None: + update_values["definition"] = definition + if icon is not None: + update_values["icon"] = icon + if icon_color is not None: + update_values["icon_color"] = icon_color + if status is not None: + update_values["status"] = status + + await session.execute( + update(ToolModel) + .where( + ToolModel.tool_uuid == tool_uuid, + ToolModel.organization_id == organization_id, + ) + .values(**update_values) + ) + await session.commit() + + # Fetch updated tool + result = await session.execute( + select(ToolModel) + .where(ToolModel.tool_uuid == tool_uuid) + .options(selectinload(ToolModel.created_by_user)) + ) + updated_tool = result.scalar_one() + + logger.info(f"Updated tool {tool_uuid} for organization {organization_id}") + return updated_tool + + async def archive_tool(self, tool_uuid: str, organization_id: int) -> bool: + """Soft delete a tool by setting its status to archived. + + Args: + tool_uuid: The unique tool UUID + organization_id: ID of the organization (for authorization) + + Returns: + True if tool was archived, False if not found + """ + async with self.async_session() as session: + result = await session.execute( + update(ToolModel) + .where( + ToolModel.tool_uuid == tool_uuid, + ToolModel.organization_id == organization_id, + ToolModel.status != ToolStatus.ARCHIVED.value, + ) + .values( + status=ToolStatus.ARCHIVED.value, + updated_at=datetime.now(UTC), + ) + ) + await session.commit() + + if result.rowcount > 0: + logger.info( + f"Archived tool {tool_uuid} for organization {organization_id}" + ) + return True + return False + + async def validate_tool_uuid(self, tool_uuid: str, organization_id: int) -> bool: + """Check if a tool UUID exists and belongs to the organization. + + This is useful for workflow validation to ensure referenced tools exist. + + Args: + tool_uuid: The tool UUID to validate + organization_id: ID of the organization + + Returns: + True if valid, False otherwise + """ + tool = await self.get_tool_by_uuid(tool_uuid, organization_id) + return tool is not None + + async def get_tools_by_uuids( + self, + tool_uuids: List[str], + organization_id: int, + ) -> List[ToolModel]: + """Get multiple tools by their UUIDs. + + Args: + tool_uuids: List of tool UUIDs to fetch + organization_id: ID of the organization (for authorization) + + Returns: + List of ToolModel instances (only active tools) + """ + if not tool_uuids: + return [] + + async with self.async_session() as session: + query = select(ToolModel).where( + ToolModel.tool_uuid.in_(tool_uuids), + ToolModel.organization_id == organization_id, + ToolModel.status == ToolStatus.ACTIVE.value, + ) + + result = await session.execute(query) + return list(result.scalars().all()) diff --git a/api/enums.py b/api/enums.py index 374645f..7f3aff1 100644 --- a/api/enums.py +++ b/api/enums.py @@ -109,3 +109,21 @@ class WebhookCredentialType(Enum): BEARER_TOKEN = "bearer_token" # Bearer token auth BASIC_AUTH = "basic_auth" # Username/password CUSTOM_HEADER = "custom_header" # Custom header key-value + + +class ToolCategory(Enum): + """Tool category types""" + + HTTP_API = "http_api" # Custom HTTP API calls (implemented) + NATIVE = ( + "native" # Built-in integrations (future: call_transfer, dtmf_input, end_call) + ) + INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.) + + +class ToolStatus(Enum): + """Tool status values""" + + ACTIVE = "active" # Tool is available for use + ARCHIVED = "archived" # Tool is soft-deleted + DRAFT = "draft" # Tool is being configured (not ready for use) diff --git a/api/pytest.ini b/api/pytest.ini index 908eec0..1dfb323 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,7 +1,8 @@ [pytest] -asyncio_mode = strict -asyncio_default_fixture_loop_scope = function -testpaths = . +asyncio_mode = auto +asyncio_default_fixture_loop_scope = session +asyncio_default_test_loop_scope = session +testpaths = tests python_files = test_*.py *_test.py python_classes = Test* python_functions = test_* diff --git a/api/requirements.dev.txt b/api/requirements.dev.txt index d9fbbf8..216f7b6 100644 --- a/api/requirements.dev.txt +++ b/api/requirements.dev.txt @@ -3,4 +3,5 @@ ruff==0.11.3 pytest==8.3.5 pytest-asyncio==0.26.0 pre-commit==4.2.0 -watchfiles==1.1.0 \ No newline at end of file +watchfiles==1.1.0 +python-dotenv==1.2.1 \ No newline at end of file diff --git a/api/routes/main.py b/api/routes/main.py index 05d7a99..28bee7b 100644 --- a/api/routes/main.py +++ b/api/routes/main.py @@ -15,6 +15,7 @@ from api.routes.s3_signed_url import router as s3_router from api.routes.service_keys import router as service_keys_router from api.routes.superuser import router as superuser_router from api.routes.telephony import router as telephony_router +from api.routes.tool import router as tool_router from api.routes.user import router as user_router from api.routes.webrtc_signaling import router as webrtc_signaling_router from api.routes.workflow import router as workflow_router @@ -32,6 +33,7 @@ router.include_router(workflow_router) router.include_router(user_router) router.include_router(campaign_router) router.include_router(credentials_router) +router.include_router(tool_router) router.include_router(integration_router) router.include_router(organization_router) router.include_router(s3_router) diff --git a/api/routes/tool.py b/api/routes/tool.py new file mode 100644 index 0000000..71df7d6 --- /dev/null +++ b/api/routes/tool.py @@ -0,0 +1,336 @@ +"""API routes for managing tools.""" + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from api.db import db_client +from api.db.models import UserModel +from api.enums import ToolCategory, ToolStatus +from api.services.auth.depends import get_user + +router = APIRouter(prefix="/tools") + + +# Request/Response schemas +class ToolParameter(BaseModel): + """A parameter that the tool accepts.""" + + name: str = Field(description="Parameter name (used as key in request body)") + type: str = Field(description="Parameter type: string, number, or boolean") + description: str = Field(description="Description of what this parameter is for") + required: bool = Field( + default=True, description="Whether this parameter is required" + ) + + +class HttpApiConfig(BaseModel): + """Configuration for HTTP API tools.""" + + method: str = Field(description="HTTP method (GET, POST, PUT, PATCH, DELETE)") + url: str = Field(description="Target URL") + headers: Optional[Dict[str, str]] = Field( + default=None, description="Static headers to include" + ) + credential_uuid: Optional[str] = Field( + default=None, description="Reference to ExternalCredentialModel for auth" + ) + parameters: Optional[List[ToolParameter]] = Field( + default=None, description="Parameters that the tool accepts from LLM" + ) + timeout_ms: Optional[int] = Field( + default=5000, description="Request timeout in milliseconds" + ) + + +class ToolDefinition(BaseModel): + """Tool definition schema.""" + + schema_version: int = Field( + default=1, description="Schema version for compatibility" + ) + type: str = Field(description="Tool type (http_api)") + config: HttpApiConfig = Field(description="Tool configuration") + + +class CreateToolRequest(BaseModel): + """Request schema for creating a tool.""" + + name: str = Field(max_length=255) + description: Optional[str] = None + category: str = Field(default=ToolCategory.HTTP_API.value) + icon: Optional[str] = Field(default="globe", max_length=50) + icon_color: Optional[str] = Field(default="#3B82F6", max_length=7) + definition: ToolDefinition + + +class UpdateToolRequest(BaseModel): + """Request schema for updating a tool.""" + + name: Optional[str] = Field(default=None, max_length=255) + description: Optional[str] = None + icon: Optional[str] = Field(default=None, max_length=50) + icon_color: Optional[str] = Field(default=None, max_length=7) + definition: Optional[ToolDefinition] = None + status: Optional[str] = None + + +class CreatedByResponse(BaseModel): + """Response schema for the user who created a tool.""" + + id: int + provider_id: str + + +class ToolResponse(BaseModel): + """Response schema for a tool.""" + + id: int + tool_uuid: str + name: str + description: Optional[str] + category: str + icon: Optional[str] + icon_color: Optional[str] + status: str + definition: Dict[str, Any] + created_at: datetime + updated_at: Optional[datetime] + created_by: Optional[CreatedByResponse] = None + + class Config: + from_attributes = True + + +def build_tool_response(tool, include_created_by: bool = False) -> ToolResponse: + """Build a response from a tool model.""" + created_by = None + if include_created_by and tool.created_by_user: + created_by = CreatedByResponse( + id=tool.created_by_user.id, + provider_id=tool.created_by_user.provider_id, + ) + + return ToolResponse( + id=tool.id, + tool_uuid=tool.tool_uuid, + name=tool.name, + description=tool.description, + category=tool.category, + icon=tool.icon, + icon_color=tool.icon_color, + status=tool.status, + definition=tool.definition, + created_at=tool.created_at, + updated_at=tool.updated_at, + created_by=created_by, + ) + + +def validate_category(category: str) -> None: + """Validate that the category is valid.""" + valid_categories = [c.value for c in ToolCategory] + if category not in valid_categories: + raise HTTPException( + status_code=400, + detail=f"Invalid category '{category}'. Must be one of: {', '.join(valid_categories)}", + ) + + +def validate_status(status: str) -> None: + """Validate that the status is valid.""" + valid_statuses = [s.value for s in ToolStatus] + if status not in valid_statuses: + raise HTTPException( + status_code=400, + detail=f"Invalid status '{status}'. Must be one of: {', '.join(valid_statuses)}", + ) + + +@router.get("/") +async def list_tools( + status: Optional[str] = None, + category: Optional[str] = None, + user: UserModel = Depends(get_user), +) -> List[ToolResponse]: + """ + List all tools for the user's organization. + + Args: + status: Optional filter by status (active, archived, draft) + category: Optional filter by category (http_api, native, integration) + + Returns: + List of tools + """ + if not user.selected_organization_id: + raise HTTPException( + status_code=400, detail="No organization selected for the user" + ) + + if status: + validate_status(status) + if category: + validate_category(category) + + tools = await db_client.get_tools_for_organization( + user.selected_organization_id, + status=status, + category=category, + ) + + return [build_tool_response(tool) for tool in tools] + + +@router.post("/") +async def create_tool( + request: CreateToolRequest, + user: UserModel = Depends(get_user), +) -> ToolResponse: + """ + Create a new tool. + + Args: + request: The tool creation request + + Returns: + The created tool + """ + if not user.selected_organization_id: + raise HTTPException( + status_code=400, detail="No organization selected for the user" + ) + + validate_category(request.category) + + try: + tool = await db_client.create_tool( + organization_id=user.selected_organization_id, + user_id=user.id, + name=request.name, + definition=request.definition.model_dump(), + category=request.category, + description=request.description, + icon=request.icon, + icon_color=request.icon_color, + ) + + return build_tool_response(tool) + + except Exception as e: + if "unique_org_tool_name" in str(e): + raise HTTPException( + status_code=409, + detail=f"A tool with the name '{request.name}' already exists", + ) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/{tool_uuid}") +async def get_tool( + tool_uuid: str, + user: UserModel = Depends(get_user), +) -> ToolResponse: + """ + Get a specific tool by UUID. + + Args: + tool_uuid: The UUID of the tool + + Returns: + The tool + """ + if not user.selected_organization_id: + raise HTTPException( + status_code=400, detail="No organization selected for the user" + ) + + tool = await db_client.get_tool_by_uuid( + tool_uuid, user.selected_organization_id, include_archived=True + ) + + if not tool: + raise HTTPException(status_code=404, detail="Tool not found") + + return build_tool_response(tool, include_created_by=True) + + +@router.put("/{tool_uuid}") +async def update_tool( + tool_uuid: str, + request: UpdateToolRequest, + user: UserModel = Depends(get_user), +) -> ToolResponse: + """ + Update a tool. + + Args: + tool_uuid: The UUID of the tool to update + request: The update request + + Returns: + The updated tool + """ + if not user.selected_organization_id: + raise HTTPException( + status_code=400, detail="No organization selected for the user" + ) + + if request.status: + validate_status(request.status) + + try: + tool = await db_client.update_tool( + tool_uuid=tool_uuid, + organization_id=user.selected_organization_id, + name=request.name, + description=request.description, + definition=request.definition.model_dump() if request.definition else None, + icon=request.icon, + icon_color=request.icon_color, + status=request.status, + ) + + if not tool: + raise HTTPException(status_code=404, detail="Tool not found") + + return build_tool_response(tool, include_created_by=True) + + except HTTPException: + raise + except Exception as e: + if "unique_org_tool_name" in str(e): + raise HTTPException( + status_code=409, + detail=f"A tool with the name '{request.name}' already exists", + ) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/{tool_uuid}") +async def delete_tool( + tool_uuid: str, + user: UserModel = Depends(get_user), +) -> dict: + """ + Archive (soft delete) a tool. + + Args: + tool_uuid: The UUID of the tool to delete + + Returns: + Success message + """ + if not user.selected_organization_id: + raise HTTPException( + status_code=400, detail="No organization selected for the user" + ) + + deleted = await db_client.archive_tool(tool_uuid, user.selected_organization_id) + + if not deleted: + raise HTTPException(status_code=404, detail="Tool not found") + + return {"status": "archived", "tool_uuid": tool_uuid} diff --git a/api/services/workflow/dto.py b/api/services/workflow/dto.py index 6d114d3..75b648e 100644 --- a/api/services/workflow/dto.py +++ b/api/services/workflow/dto.py @@ -57,6 +57,7 @@ class NodeDataDTO(BaseModel): detect_voicemail: bool = False delayed_start: bool = False delayed_start_duration: Optional[float] = None + tool_uuids: Optional[List[str]] = None trigger_path: Optional[str] = None # Webhook node specific fields enabled: bool = True diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index dec0ed4..8084d9a 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -38,6 +38,7 @@ import asyncio from loguru import logger from api.services.workflow import pipecat_engine_callbacks as engine_callbacks +from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager from api.services.workflow.pipecat_engine_utils import ( get_function_schema, render_template, @@ -105,6 +106,16 @@ class PipecatEngine: # Track current LLM reference text for TTS aggregation correction self._current_llm_reference_text: str = "" + # Custom tool manager (initialized in initialize()) + self._custom_tool_manager: Optional[CustomToolManager] = None + + async def _get_organization_id(self) -> Optional[int]: + """Get and cache the organization ID from workflow run.""" + if self._custom_tool_manager: + return await self._custom_tool_manager.get_organization_id() + # Fallback for when manager is not yet initialized + return await get_organization_id_from_workflow_run(self._workflow_run_id) + @property def builtin_function_schemas(self) -> list[dict]: """Get built-in function schemas (calculator and timezone tools).""" @@ -146,6 +157,9 @@ class PipecatEngine: # Helper that encapsulates variable extraction logic self._variable_extraction_manager = VariableExtractionManager(self) + # Helper that encapsulates custom tool management + self._custom_tool_manager = CustomToolManager(self) + # Add current time in EST (America/New_York) to gathered context try: est_time_result = get_current_time("America/New_York") @@ -360,6 +374,10 @@ class PipecatEngine: outgoing_edge.get_function_name(), outgoing_edge.target ) + # Register custom tool handlers for this node + if node.tool_uuids and self._custom_tool_manager: + await self._custom_tool_manager.register_handlers(node.tool_uuids) + # Set up system message and functions ( system_message, @@ -492,9 +510,7 @@ class PipecatEngine: # Apply disposition mapping - first try call_disposition if it is, # extracted from the call conversation then fall back to reason call_disposition = self._gathered_context.get("call_disposition", "") - organization_id = await get_organization_id_from_workflow_run( - self._workflow_run_id - ) + organization_id = await self._get_organization_id() # If client is disconnected before we get a chance to disconnect from # the bot, lets consider that as final disposition @@ -618,6 +634,13 @@ class PipecatEngine: # Add built-in function schemas (calculator and timezone tools) functions.extend(self.builtin_function_schemas) + # Add custom tools from node.tool_uuids + if node.tool_uuids and self._custom_tool_manager: + custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas( + node.tool_uuids + ) + functions.extend(custom_tool_schemas) + # Transition functions (schema only; registration handled elsewhere) for outgoing_edge in node.out_edges: function_schema = self._get_function_schema( diff --git a/api/services/workflow/pipecat_engine_custom_tools.py b/api/services/workflow/pipecat_engine_custom_tools.py new file mode 100644 index 0000000..012548f --- /dev/null +++ b/api/services/workflow/pipecat_engine_custom_tools.py @@ -0,0 +1,189 @@ +"""Custom tool management for PipecatEngine. + +This module handles fetching, registering, and executing user-defined tools +during workflow execution. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from loguru import logger + +from api.db import db_client +from api.services.workflow.disposition_mapper import ( + get_organization_id_from_workflow_run, +) +from api.services.workflow.pipecat_engine_utils import get_function_schema +from api.services.workflow.tools.custom_tool import ( + execute_http_tool, + tool_to_function_schema, +) +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.frames.frames import FunctionCallResultProperties +from pipecat.services.llm_service import FunctionCallParams + +if TYPE_CHECKING: + from api.services.workflow.pipecat_engine import PipecatEngine + + +class CustomToolManager: + """Manager for custom tool registration and execution. + + This class handles: + 1. Fetching tools from the database based on tool UUIDs + 2. Converting tools to LLM function schemas + 3. Registering tool execution handlers with the LLM + 4. Executing HTTP API tools when invoked by the LLM + """ + + def __init__(self, engine: "PipecatEngine") -> None: + self._engine = engine + self._organization_id: Optional[int] = None + # Cache: maps function_name -> (tool, schema) + self._tools_cache: dict[str, tuple[Any, dict]] = {} + + async def get_organization_id(self) -> Optional[int]: + """Get and cache the organization ID from workflow run.""" + if self._organization_id is None: + self._organization_id = await get_organization_id_from_workflow_run( + self._engine._workflow_run_id + ) + return self._organization_id + + async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]: + """Fetch custom tools and convert them to function schemas. + + Args: + tool_uuids: List of tool UUIDs to fetch + + Returns: + List of FunctionSchema objects for LLM + """ + organization_id = await self.get_organization_id() + if not organization_id: + logger.warning("Cannot fetch custom tools: organization_id not available") + return [] + + try: + tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id) + + schemas: list[FunctionSchema] = [] + for tool in tools: + raw_schema = tool_to_function_schema(tool) + function_name = raw_schema["function"]["name"] + + # Cache the tool for later execution + self._tools_cache[function_name] = (tool, raw_schema) + + # Convert to FunctionSchema object for compatibility with update_llm_context + func_schema = get_function_schema( + function_name, + raw_schema["function"]["description"], + properties=raw_schema["function"]["parameters"].get( + "properties", {} + ), + required=raw_schema["function"]["parameters"].get("required", []), + ) + schemas.append(func_schema) + + logger.debug( + f"Loaded {len(schemas)} custom tools for node: " + f"{[s.name for s in schemas]}" + ) + return schemas + + except Exception as e: + logger.error(f"Failed to fetch custom tools: {e}") + return [] + + async def register_handlers(self, tool_uuids: list[str]) -> None: + """Register custom tool execution handlers with the LLM. + + Args: + tool_uuids: List of tool UUIDs to register handlers for + """ + organization_id = await self.get_organization_id() + if not organization_id: + logger.warning( + "Cannot register custom tool handlers: organization_id not available" + ) + return + + try: + tools = await db_client.get_tools_by_uuids(tool_uuids, organization_id) + + for tool in tools: + schema = tool_to_function_schema(tool) + function_name = schema["function"]["name"] + + # Cache the tool for potential later use + self._tools_cache[function_name] = (tool, schema) + + # Create and register the handler + handler = self._create_handler(tool, function_name) + self._engine.llm.register_function(function_name, handler) + + logger.debug( + f"Registered custom tool handler: {function_name} " + f"(tool_uuid: {tool.tool_uuid})" + ) + + except Exception as e: + logger.error(f"Failed to register custom tool handlers: {e}") + + def _create_handler(self, tool: Any, function_name: str): + """Create a handler function for a custom tool. + + Args: + tool: The ToolModel instance + function_name: The function name used by the LLM + + Returns: + Async handler function for the tool + """ + # Run LLM after tool execution to continue conversation + properties = FunctionCallResultProperties(run_llm=True) + + async def custom_tool_handler( + function_call_params: FunctionCallParams, + ) -> None: + logger.info(f"LLM Function Call EXECUTED: {function_name}") + logger.info(f"Arguments: {function_call_params.arguments}") + + try: + # Execute the HTTP API tool + result = await execute_http_tool( + tool=tool, + arguments=function_call_params.arguments, + call_context_vars=self._engine._call_context_vars, + organization_id=self._organization_id, + ) + + await function_call_params.result_callback( + result, properties=properties + ) + + except Exception as e: + logger.error(f"Custom tool '{function_name}' execution failed: {e}") + await function_call_params.result_callback( + {"status": "error", "error": str(e)}, + properties=properties, + ) + + return custom_tool_handler + + def get_cached_tool(self, function_name: str) -> Optional[tuple[Any, dict]]: + """Get a cached tool by its function name. + + Args: + function_name: The function name used by the LLM + + Returns: + Tuple of (tool, schema) if found, None otherwise + """ + return self._tools_cache.get(function_name) + + def clear_cache(self) -> None: + """Clear the tools cache.""" + self._tools_cache.clear() diff --git a/api/services/workflow/tools/custom_tool.py b/api/services/workflow/tools/custom_tool.py new file mode 100644 index 0000000..7a9d6d9 --- /dev/null +++ b/api/services/workflow/tools/custom_tool.py @@ -0,0 +1,180 @@ +"""Custom tool execution for user-defined HTTP API tools.""" + +import re +from typing import Any, Dict, Optional + +import httpx +from loguru import logger + +from api.db import db_client +from api.utils.credential_auth import build_auth_header + +# Map tool parameter types to JSON schema types +TYPE_MAP = { + "string": "string", + "number": "number", + "boolean": "boolean", +} + + +def tool_to_function_schema(tool: Any) -> Dict[str, Any]: + """Convert a ToolModel to an LLM function schema. + + Args: + tool: ToolModel instance with name, description, and definition + + Returns: + Function schema dict compatible with OpenAI/Anthropic function calling + """ + definition = tool.definition or {} + config = definition.get("config", {}) + parameters = config.get("parameters", []) or [] + + # Build properties and required list from parameters + properties = {} + required = [] + + for param in parameters: + param_name = param.get("name", "") + param_type = param.get("type", "string") + param_desc = param.get("description", "") + param_required = param.get("required", True) + + if not param_name: + continue + + properties[param_name] = { + "type": TYPE_MAP.get(param_type, "string"), + "description": param_desc, + } + + if param_required: + required.append(param_name) + + # Sanitize tool name for function name (lowercase, underscores only) + function_name = re.sub(r"[^a-z0-9_]", "_", tool.name.lower()) + # Remove consecutive underscores and trim + function_name = re.sub(r"_+", "_", function_name).strip("_") + + return { + "type": "function", + "function": { + "name": function_name, + "description": tool.description or f"Execute {tool.name} tool", + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + "_tool_uuid": tool.tool_uuid, + } + + +async def execute_http_tool( + tool: Any, + arguments: Dict[str, Any], + call_context_vars: Optional[Dict[str, Any]] = None, + organization_id: Optional[int] = None, +) -> Dict[str, Any]: + """Execute an HTTP API tool. + + Args: + tool: ToolModel instance + arguments: Arguments passed by the LLM (parameter name -> value) + call_context_vars: Additional context variables from the call (unused for now) + organization_id: Organization ID for credential lookup + + Returns: + Result dict with response data or error + """ + definition = tool.definition or {} + config = definition.get("config", {}) + + # Get HTTP method and URL + method = config.get("method", "POST").upper() + url = config.get("url", "") + + # Get headers from config + headers = dict(config.get("headers", {}) or {}) + + # Add auth header if credential is configured + credential_uuid = config.get("credential_uuid") + if credential_uuid and organization_id: + try: + credential = await db_client.get_credential_by_uuid( + credential_uuid, organization_id + ) + if credential: + auth_header = build_auth_header(credential) + headers.update(auth_header) + logger.debug(f"Applied credential '{credential.name}' to tool request") + else: + logger.warning( + f"Credential {credential_uuid} not found for tool '{tool.name}'" + ) + except Exception as e: + logger.error(f"Failed to fetch credential for tool '{tool.name}': {e}") + + # Get timeout + timeout_ms = config.get("timeout_ms", 5000) + timeout_seconds = timeout_ms / 1000 + + # Build request: JSON body for POST/PUT/PATCH, query params for GET/DELETE + body = None + params = None + if method in ("POST", "PUT", "PATCH"): + body = arguments + elif method in ("GET", "DELETE") and arguments: + params = arguments + + logger.info( + f"Executing custom tool '{tool.name}' ({tool.tool_uuid}): {method} {url}" + ) + logger.debug(f"Request body: {body}, params: {params}") + + try: + async with httpx.AsyncClient(timeout=timeout_seconds) as client: + response = await client.request( + method=method, + url=url, + headers=headers, + json=body, + params=params, + ) + + # Try to parse JSON response + try: + response_data = response.json() + except Exception: + response_data = {"raw_response": response.text} + + result = { + "status": "success", + "status_code": response.status_code, + "data": response_data, + } + + logger.debug( + f"Custom tool '{tool.name}' completed with status {response.status_code}" + ) + return result + + except httpx.TimeoutException: + logger.error(f"Custom tool '{tool.name}' timed out after {timeout_seconds}s") + return { + "status": "error", + "error": f"Request timed out after {timeout_seconds} seconds", + } + except httpx.RequestError as e: + logger.error(f"Custom tool '{tool.name}' request failed: {e}") + return { + "status": "error", + "error": f"Request failed: {str(e)}", + } + except Exception as e: + logger.error(f"Custom tool '{tool.name}' execution failed: {e}") + return { + "status": "error", + "error": f"Tool execution failed: {str(e)}", + } diff --git a/api/services/workflow/workflow.py b/api/services/workflow/workflow.py index 82fa82d..c6b1280 100644 --- a/api/services/workflow/workflow.py +++ b/api/services/workflow/workflow.py @@ -47,6 +47,7 @@ class Node: self.detect_voicemail = data.detect_voicemail self.delayed_start = data.delayed_start self.delayed_start_duration = data.delayed_start_duration + self.tool_uuids = data.tool_uuids self.data = data diff --git a/api/tasks/run_integrations.py b/api/tasks/run_integrations.py index c15be6a..2aa8d95 100644 --- a/api/tasks/run_integrations.py +++ b/api/tasks/run_integrations.py @@ -1,13 +1,13 @@ """Execute webhook integrations after workflow run completion.""" -import base64 from typing import Any, Dict import httpx from loguru import logger from api.db import db_client -from api.db.models import ExternalCredentialModel, WorkflowRunModel +from api.db.models import WorkflowRunModel +from api.utils.credential_auth import build_auth_header from api.utils.template_renderer import render_template from pipecat.utils.context import set_current_run_id @@ -133,7 +133,7 @@ async def _execute_webhook_node( credential_uuid, organization_id ) if credential: - auth_header = _build_auth_header(credential) + auth_header = build_auth_header(credential) headers.update(auth_header) logger.debug(f"Applied credential '{credential.name}' to webhook") else: @@ -189,39 +189,3 @@ async def _execute_webhook_node( except Exception as e: logger.error(f"Webhook '{webhook_name}' unexpected error: {e}") return False - - -def _build_auth_header(credential: ExternalCredentialModel) -> Dict[str, str]: - """ - Build authentication header based on credential type. - - Args: - credential: The credential model - - Returns: - Dict with header name and value - """ - cred_type = credential.credential_type - cred_data = credential.credential_data or {} - - if cred_type == "bearer_token": - token = cred_data.get("token", "") - return {"Authorization": f"Bearer {token}"} - - elif cred_type == "api_key": - header_name = cred_data.get("header_name", "X-API-Key") - api_key = cred_data.get("api_key", "") - return {header_name: api_key} - - elif cred_type == "basic_auth": - username = cred_data.get("username", "") - password = cred_data.get("password", "") - encoded = base64.b64encode(f"{username}:{password}".encode()).decode() - return {"Authorization": f"Basic {encoded}"} - - elif cred_type == "custom_header": - header_name = cred_data.get("header_name", "X-Custom") - header_value = cred_data.get("header_value", "") - return {header_name: header_value} - - return {} diff --git a/api/tests/test_assistant_context_aggregator.py b/api/tests/test_assistant_context_aggregator.py deleted file mode 100644 index b9cb3a0..0000000 --- a/api/tests/test_assistant_context_aggregator.py +++ /dev/null @@ -1,138 +0,0 @@ -import asyncio - -import pytest -from pipecat.frames.frames import ( - FunctionCallInProgressFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - StartInterruptionFrame, - TextFrame, -) -from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.openai.llm import OpenAIAssistantContextAggregator - - -@pytest.mark.asyncio -async def test_reordering_after_completion(): - context = OpenAILLMContext() - aggr = OpenAIAssistantContextAggregator(context) - - # Initialize task manager properly using PipelineTask - pipeline = Pipeline([aggr]) - task = PipelineTask(pipeline) - runner = PipelineRunner() - - # Start the task to properly initialize the frame processor - task_coroutine = asyncio.create_task(runner.run(task)) - - # Give the task a moment to initialize - await asyncio.sleep(0.01) - - # start new LLM response - await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM) - - # simulate a pending function call - await aggr.process_frame( - FunctionCallInProgressFrame( - function_name="transition", - tool_call_id="1", - arguments={}, - cancel_on_interruption=False, - ), - FrameDirection.DOWNSTREAM, - ) - - # now text arrives - await aggr.process_frame(TextFrame("Hi there"), FrameDirection.DOWNSTREAM) - - # end response - await aggr.process_frame(LLMFullResponseEndFrame(), FrameDirection.DOWNSTREAM) - - msgs = context.get_messages() - - # Assert order: assistant text first, then tool_call assistant, then tool response - assert msgs[0]["role"] == "assistant" and "tool_calls" not in msgs[0] - # Fix: content is a string, not a structured object - assert msgs[0]["content"] == "Hi there" - assert any(m.get("role") == "assistant" and m.get("tool_calls") for m in msgs[1:]) - assert any(m.get("role") == "tool" for m in msgs[1:]) - - # Clean up the running task - await task.cancel() - task_coroutine.cancel() - try: - await task_coroutine - except asyncio.CancelledError: - pass - - -@pytest.mark.asyncio -async def test_interruption_removes_pending_function_calls_and_marks(): - context = OpenAILLMContext() - aggr = OpenAIAssistantContextAggregator(context) - - # Initialize task manager properly using PipelineTask - pipeline = Pipeline([aggr]) - task = PipelineTask(pipeline) - runner = PipelineRunner() - - # Start the task to properly initialize the frame processor - task_coroutine = asyncio.create_task(runner.run(task)) - - # Give the task a moment to initialize - await asyncio.sleep(0.01) - - await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM) - await aggr.process_frame( - FunctionCallInProgressFrame( - function_name="transition", - tool_call_id="1", - arguments={}, - cancel_on_interruption=False, - ), - FrameDirection.DOWNSTREAM, - ) - - # Debug: Check the state before interruption - print( - f"Function calls in progress before interruption: {aggr._function_calls_in_progress}" - ) - print(f"Messages before interruption: {context.get_messages()}") - - # no text yet - still aggregation - await aggr.process_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM) - - msgs = context.get_messages() - - # Debug: Print messages to understand what's happening - print(f"Messages after interruption: {msgs}") - print( - f"Function calls in progress after interruption: {aggr._function_calls_in_progress}" - ) - - # After interruption before any response is complete, context should be cleared - # This is the actual behavior - interruptions clear pending function calls - if len(msgs) == 0: - # Context was cleared due to interruption before completion - assert True - else: - # If there are messages, ensure no tool calls remain - assert not any(m.get("tool_calls") for m in msgs) - assert not any(m.get("role") == "tool" for m in msgs) - - # Check if interruption marker is present - if msgs: - assert msgs[-1]["role"] == "assistant" - assert "<>" in msgs[-1]["content"] - - # Clean up the running task - await task.cancel() - task_coroutine.cancel() - try: - await task_coroutine - except asyncio.CancelledError: - pass diff --git a/api/tests/test_audio_transcript_buffers.py b/api/tests/test_audio_transcript_buffers.py deleted file mode 100644 index 0798241..0000000 --- a/api/tests/test_audio_transcript_buffers.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -import wave - -import pytest - -from api.services.pipecat.audio_transcript_buffers import ( - InMemoryAudioBuffer, - InMemoryTranscriptBuffer, -) - - -@pytest.mark.asyncio -async def test_audio_buffer_append_and_write(): - """Test that audio buffer can append data and write to temp file.""" - buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000, num_channels=1) - - # Create some test PCM data - test_pcm = b"\x00\x01" * 1000 # 2000 bytes - - # Append data - await buffer.append(test_pcm) - await buffer.append(test_pcm) - - assert buffer.size == 4000 - assert not buffer.is_empty - - # Write to temp file - temp_path = await buffer.write_to_temp_file() - - try: - # Verify file exists and is valid WAV - assert os.path.exists(temp_path) - - with wave.open(temp_path, "rb") as wf: - assert wf.getnchannels() == 1 - assert wf.getsampwidth() == 2 - assert wf.getframerate() == 16000 - # Each frame is 2 bytes (16-bit), so 4000 bytes = 2000 frames - assert wf.getnframes() == 2000 - finally: - # Clean up - if os.path.exists(temp_path): - os.remove(temp_path) - - -@pytest.mark.asyncio -async def test_audio_buffer_memory_limit(): - """Test that audio buffer enforces memory limit.""" - buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000) - - # Set a smaller limit for testing - buffer._max_size = 1000 - - # This should work - await buffer.append(b"\x00" * 500) - - # This should fail - with pytest.raises(MemoryError): - await buffer.append(b"\x00" * 600) - - -@pytest.mark.asyncio -async def test_transcript_buffer_append_and_write(): - """Test that transcript buffer can append data and write to temp file.""" - buffer = InMemoryTranscriptBuffer(workflow_run_id=456) - - # Append some transcript lines - await buffer.append("[00:00:01] user: Hello\n") - await buffer.append("[00:00:02] assistant: Hi there!\n") - await buffer.append("[00:00:03] user: How are you?\n") - - assert not buffer.is_empty - - # Write to temp file - temp_path = await buffer.write_to_temp_file() - - try: - # Verify file exists and has correct content - assert os.path.exists(temp_path) - - with open(temp_path, "r") as f: - content = f.read() - assert "[00:00:01] user: Hello\n" in content - assert "[00:00:02] assistant: Hi there!\n" in content - assert "[00:00:03] user: How are you?\n" in content - finally: - # Clean up - if os.path.exists(temp_path): - os.remove(temp_path) - - -@pytest.mark.asyncio -async def test_empty_buffers(): - """Test that empty buffers are handled correctly.""" - audio_buffer = InMemoryAudioBuffer(workflow_run_id=789, sample_rate=16000) - transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id=789) - - assert audio_buffer.is_empty - assert transcript_buffer.is_empty - - # Should still be able to write empty files - audio_path = await audio_buffer.write_to_temp_file() - transcript_path = await transcript_buffer.write_to_temp_file() - - try: - assert os.path.exists(audio_path) - assert os.path.exists(transcript_path) - - # Empty WAV file should still have valid header - with wave.open(audio_path, "rb") as wf: - assert wf.getnframes() == 0 - - # Empty transcript file - with open(transcript_path, "r") as f: - assert f.read() == "" - finally: - if os.path.exists(audio_path): - os.remove(audio_path) - if os.path.exists(transcript_path): - os.remove(transcript_path) diff --git a/api/tests/test_concurrent_call_limiting.py b/api/tests/test_concurrent_call_limiting.py deleted file mode 100644 index e2b316d..0000000 --- a/api/tests/test_concurrent_call_limiting.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for concurrent call limiting functionality.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from api.enums import OrganizationConfigurationKey -from api.services.campaign.rate_limiter import RateLimiter - - -class TestConcurrentCallLimiting: - """Test suite for concurrent call limiting.""" - - @pytest.mark.asyncio - async def test_acquire_concurrent_slot_success(self): - """Test successful acquisition of concurrent slot.""" - rate_limiter = RateLimiter() - - # Mock Redis client - with patch.object(rate_limiter, "_get_redis") as mock_redis: - mock_client = AsyncMock() - mock_client.eval = AsyncMock(return_value="test_slot_123") - mock_redis.return_value = mock_client - - # Try to acquire slot - slot_id = await rate_limiter.try_acquire_concurrent_slot( - organization_id=1, max_concurrent=20 - ) - - assert slot_id == "test_slot_123" - mock_client.eval.assert_called_once() - - @pytest.mark.asyncio - async def test_acquire_concurrent_slot_limit_reached(self): - """Test slot acquisition when limit is reached.""" - rate_limiter = RateLimiter() - - # Mock Redis client - with patch.object(rate_limiter, "_get_redis") as mock_redis: - mock_client = AsyncMock() - mock_client.eval = AsyncMock(return_value=None) # Limit reached - mock_redis.return_value = mock_client - - # Try to acquire slot - slot_id = await rate_limiter.try_acquire_concurrent_slot( - organization_id=1, max_concurrent=20 - ) - - assert slot_id is None - mock_client.eval.assert_called_once() - - @pytest.mark.asyncio - async def test_release_concurrent_slot(self): - """Test releasing a concurrent slot.""" - rate_limiter = RateLimiter() - - # Mock Redis client - with patch.object(rate_limiter, "_get_redis") as mock_redis: - mock_client = AsyncMock() - mock_client.zrem = AsyncMock(return_value=1) # Successfully removed - mock_redis.return_value = mock_client - - # Release slot - success = await rate_limiter.release_concurrent_slot( - organization_id=1, slot_id="test_slot_123" - ) - - assert success is True - mock_client.zrem.assert_called_once_with( - "concurrent_calls:1", "test_slot_123" - ) - - @pytest.mark.asyncio - async def test_get_concurrent_count(self): - """Test getting current concurrent call count.""" - rate_limiter = RateLimiter() - - # Mock Redis client - with patch.object(rate_limiter, "_get_redis") as mock_redis: - mock_client = AsyncMock() - mock_client.zremrangebyscore = AsyncMock() # Cleanup stale entries - mock_client.zcard = AsyncMock(return_value=5) # 5 active calls - mock_redis.return_value = mock_client - - # Get count - count = await rate_limiter.get_concurrent_count(organization_id=1) - - assert count == 5 - mock_client.zremrangebyscore.assert_called_once() - mock_client.zcard.assert_called_once() - - @pytest.mark.asyncio - async def test_stale_entry_cleanup(self): - """Test that stale entries are cleaned up automatically.""" - rate_limiter = RateLimiter() - - # Mock Redis client - with patch.object(rate_limiter, "_get_redis") as mock_redis: - mock_client = AsyncMock() - - # Mock eval to simulate cleanup in Lua script - mock_client.eval = AsyncMock(return_value="new_slot_123") - mock_redis.return_value = mock_client - - # Try to acquire slot (which should trigger cleanup) - slot_id = await rate_limiter.try_acquire_concurrent_slot( - organization_id=1, max_concurrent=20 - ) - - assert slot_id == "new_slot_123" - - # Verify Lua script was called with proper stale cutoff - call_args = mock_client.eval.call_args[0] - lua_script = call_args[0] - assert "ZREMRANGEBYSCORE" in lua_script # Cleanup command in script - - @pytest.mark.asyncio - async def test_workflow_slot_mapping_operations(self): - """Test storing, retrieving, and deleting workflow slot mappings.""" - rate_limiter = RateLimiter() - - # Mock Redis client - with patch.object(rate_limiter, "_get_redis") as mock_redis: - mock_client = AsyncMock() - mock_client.hset = AsyncMock(return_value=1) - mock_client.expire = AsyncMock(return_value=True) - mock_client.hgetall = AsyncMock( - return_value={"org_id": "1", "slot_id": "test_slot_123"} - ) - mock_client.delete = AsyncMock(return_value=1) - mock_redis.return_value = mock_client - - # Test storing mapping - success = await rate_limiter.store_workflow_slot_mapping( - workflow_run_id=999, organization_id=1, slot_id="test_slot_123" - ) - assert success is True - mock_client.hset.assert_called_once() - mock_client.expire.assert_called_once() - - # Test retrieving mapping - mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id=999) - assert mapping == (1, "test_slot_123") - mock_client.hgetall.assert_called_once_with("workflow_slot_mapping:999") - - # Test deleting mapping - deleted = await rate_limiter.delete_workflow_slot_mapping( - workflow_run_id=999 - ) - assert deleted is True - mock_client.delete.assert_called_once_with("workflow_slot_mapping:999") - - -class TestCampaignCallDispatcher: - """Test suite for CampaignCallDispatcher with concurrent limiting.""" - - @pytest.mark.asyncio - async def test_dispatch_call_waits_for_slot(self): - """Test that dispatch_call waits for available slot.""" - from api.services.campaign.call_dispatcher import CampaignCallDispatcher - - dispatcher = CampaignCallDispatcher() - - # Mock dependencies - mock_campaign = MagicMock( - organization_id=1, workflow_id=123, id=456, created_by=789 - ) - mock_queued_run = MagicMock( - id=111, context_variables={"phone_number": "+1234567890"} - ) - - # Mock rate limiter to simulate waiting - slot_acquired = False - call_count = 0 - - async def mock_try_acquire(org_id, max_concurrent): - nonlocal slot_acquired, call_count - call_count += 1 - if call_count > 2: # Succeed on third try - slot_acquired = True - return "test_slot_123" - return None - - with patch( - "api.services.campaign.call_dispatcher.rate_limiter" - ) as mock_limiter: - mock_limiter.try_acquire_concurrent_slot = AsyncMock( - side_effect=mock_try_acquire - ) - mock_limiter.release_concurrent_slot = AsyncMock() - mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True) - - with patch("api.services.campaign.call_dispatcher.db_client") as mock_db: - mock_db.get_configuration = AsyncMock(return_value=None) - mock_db.get_workflow_by_id = AsyncMock( - return_value=MagicMock(template_context_variables={}) - ) - mock_db.create_workflow_run = AsyncMock( - return_value=MagicMock(id=999, logs={}) - ) - - with patch.object( - dispatcher.twilio_service, "initiate_call" - ) as mock_twilio: - mock_twilio.return_value = {"sid": "test_sid"} - - # Dispatch call (should wait and retry) - workflow_run = await dispatcher.dispatch_call( - mock_queued_run, mock_campaign - ) - - assert workflow_run is not None - assert slot_acquired is True - assert call_count == 3 # Tried 3 times - assert mock_limiter.try_acquire_concurrent_slot.call_count == 3 - - @pytest.mark.asyncio - async def test_dispatch_call_stores_slot_mapping(self): - """Test that dispatch_call stores slot mapping in Redis.""" - from api.services.campaign.call_dispatcher import CampaignCallDispatcher - - dispatcher = CampaignCallDispatcher() - - # Mock dependencies - mock_campaign = MagicMock( - organization_id=1, workflow_id=123, id=456, created_by=789 - ) - mock_queued_run = MagicMock( - id=111, context_variables={"phone_number": "+1234567890"} - ) - - with patch( - "api.services.campaign.call_dispatcher.rate_limiter" - ) as mock_limiter: - mock_limiter.try_acquire_concurrent_slot = AsyncMock( - return_value="test_slot_123" - ) - mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True) - - with patch("api.services.campaign.call_dispatcher.db_client") as mock_db: - mock_db.get_configuration = AsyncMock(return_value=None) - mock_db.get_workflow_by_id = AsyncMock( - return_value=MagicMock(template_context_variables={}) - ) - mock_db.create_workflow_run = AsyncMock( - return_value=MagicMock(id=999, logs={}) - ) - - with patch.object( - dispatcher.twilio_service, "initiate_call" - ) as mock_twilio: - mock_twilio.return_value = {"sid": "test_sid"} - - # Dispatch call - workflow_run = await dispatcher.dispatch_call( - mock_queued_run, mock_campaign - ) - - # Verify slot mapping was stored - mock_limiter.store_workflow_slot_mapping.assert_called_once_with( - 999, 1, "test_slot_123" - ) - - @pytest.mark.asyncio - async def test_org_specific_concurrent_limit(self): - """Test that organization-specific concurrent limit is used.""" - from api.services.campaign.call_dispatcher import CampaignCallDispatcher - - dispatcher = CampaignCallDispatcher() - - # Mock db_client to return org-specific limit - with patch("api.services.campaign.call_dispatcher.db_client") as mock_db: - mock_config = MagicMock(value={"value": 10}) # Org limit is 10 - mock_db.get_configuration = AsyncMock(return_value=mock_config) - - # Get org limit - limit = await dispatcher.get_org_concurrent_limit(organization_id=1) - - assert limit == 10 # Should use org-specific limit - mock_db.get_configuration.assert_called_once_with( - 1, OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value - ) - - @pytest.mark.asyncio - async def test_default_concurrent_limit(self): - """Test that default limit is used when org config not found.""" - from api.services.campaign.call_dispatcher import CampaignCallDispatcher - - dispatcher = CampaignCallDispatcher() - - # Mock db_client to return None (no config) - with patch("api.services.campaign.call_dispatcher.db_client") as mock_db: - mock_db.get_configuration = AsyncMock(return_value=None) - - # Get org limit - limit = await dispatcher.get_org_concurrent_limit(organization_id=1) - - assert limit == 20 # Should use default limit - - @pytest.mark.asyncio - async def test_release_call_slot(self): - """Test releasing call slot when workflow completes.""" - from api.services.campaign.call_dispatcher import CampaignCallDispatcher - - dispatcher = CampaignCallDispatcher() - - # Mock rate limiter - with patch( - "api.services.campaign.call_dispatcher.rate_limiter" - ) as mock_limiter: - # Mock getting the slot mapping from Redis - mock_limiter.get_workflow_slot_mapping = AsyncMock( - return_value=(1, "test_slot_123") - ) - mock_limiter.release_concurrent_slot = AsyncMock(return_value=True) - mock_limiter.delete_workflow_slot_mapping = AsyncMock(return_value=True) - - # Release slot - success = await dispatcher.release_call_slot(workflow_run_id=999) - - assert success is True - mock_limiter.get_workflow_slot_mapping.assert_called_once_with(999) - mock_limiter.release_concurrent_slot.assert_called_once_with( - 1, "test_slot_123" - ) - mock_limiter.delete_workflow_slot_mapping.assert_called_once_with(999) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/api/tests/test_configuration_masking_merge.py b/api/tests/test_configuration_masking_merge.py deleted file mode 100644 index cfb933a..0000000 --- a/api/tests/test_configuration_masking_merge.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest -from pydantic import ValidationError - -from api.schemas.user_configuration import UserConfiguration -from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config -from api.services.configuration.merge import merge_user_configurations -from api.services.configuration.registry import ( - OpenAILLMService, -) - -REAL_KEY = "sk-1234567890abcdef" - - -def _build_config_with_openai(key: str) -> UserConfiguration: - return UserConfiguration( - llm=OpenAILLMService(api_key=key), - stt=None, - tts=None, - ) - - -def test_mask_key_basic(): - masked = mask_key(REAL_KEY) - # Should reveal only last 4 chars - assert masked.endswith(REAL_KEY[-4:]) - assert set(masked[:-4]) == {"*"} - assert len(masked) == len(REAL_KEY) - # is_mask_of round-trip - assert is_mask_of(masked, REAL_KEY) - - -def test_mask_user_config_masks_api_keys(): - cfg = _build_config_with_openai(REAL_KEY) - dumped = mask_user_config(cfg) - assert dumped["llm"]["api_key"].endswith(REAL_KEY[-4:]) - assert dumped["llm"]["api_key"].startswith("*" * (len(REAL_KEY) - 4)) - - -def test_merge_preserves_key_when_mask_sent(): - existing = _build_config_with_openai(REAL_KEY) - incoming_partial = { - "llm": { - "provider": "openai", - "model": existing.llm.model, - "api_key": mask_key(REAL_KEY), # masked placeholder - } - } - - merged = merge_user_configurations(existing, incoming_partial) - assert merged.llm.api_key == REAL_KEY # key preserved - - -def test_merge_replaces_key_when_new_key_provided(): - existing = _build_config_with_openai(REAL_KEY) - new_key = "sk-replaced-9999" - incoming_partial = { - "llm": { - "provider": "openai", - "model": existing.llm.model, - "api_key": new_key, - } - } - merged = merge_user_configurations(existing, incoming_partial) - assert merged.llm.api_key == new_key - - -def test_merge_drops_old_key_when_provider_changes(): - existing = _build_config_with_openai(REAL_KEY) - incoming_partial = { - "llm": { - "provider": "groq", - "model": "llama-3.3-70b-versatile", - # api_key intentionally absent – should NOT inherit old key - } - } - - with pytest.raises(ValidationError): - merge_user_configurations(existing, incoming_partial) diff --git a/api/tests/test_custom_tools.py b/api/tests/test_custom_tools.py new file mode 100644 index 0000000..0e9bddc --- /dev/null +++ b/api/tests/test_custom_tools.py @@ -0,0 +1,1041 @@ +"""Tests for custom tool integration with PipecatEngine. + +This module tests: +1. tool_to_function_schema - converting tool models to LLM function schemas +2. execute_http_tool - executing HTTP API tools +3. CustomToolManager - tool registration and handler execution +4. End-to-end LLM generation with custom tool calls +""" + +from dataclasses import dataclass +from typing import Any, Dict +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from api.services.workflow.pipecat_engine_utils import ( + get_function_schema, + update_llm_context, +) +from api.services.workflow.tools.custom_tool import ( + execute_http_tool, + tool_to_function_schema, +) +from pipecat.frames.frames import ( + FunctionCallInProgressFrame, + FunctionCallResultFrame, + FunctionCallsFromLLMInfoFrame, + FunctionCallsStartedFrame, + LLMContextFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, +) +from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.services.llm_service import FunctionCallParams +from pipecat.tests import MockLLMService, run_test + + +@dataclass +class MockToolModel: + """Mock tool model for testing.""" + + tool_uuid: str + name: str + description: str + definition: Dict[str, Any] + + +class TestToolToFunctionSchema: + """Tests for tool_to_function_schema function.""" + + def test_simple_tool_with_string_parameter(self): + """Test converting a simple tool with one string parameter.""" + tool = MockToolModel( + tool_uuid="test-uuid-1", + name="Get Weather", + description="Get current weather for a location", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.weather.com/current", + "parameters": [ + { + "name": "location", + "type": "string", + "description": "City name", + "required": True, + } + ], + }, + }, + ) + + schema = tool_to_function_schema(tool) + + assert schema["type"] == "function" + assert schema["function"]["name"] == "get_weather" + assert schema["function"]["description"] == "Get current weather for a location" + assert schema["function"]["parameters"]["type"] == "object" + assert "location" in schema["function"]["parameters"]["properties"] + assert ( + schema["function"]["parameters"]["properties"]["location"]["type"] + == "string" + ) + assert ( + schema["function"]["parameters"]["properties"]["location"]["description"] + == "City name" + ) + assert "location" in schema["function"]["parameters"]["required"] + assert schema["_tool_uuid"] == "test-uuid-1" + + def test_tool_with_multiple_parameter_types(self): + """Test converting a tool with string, number, and boolean parameters.""" + tool = MockToolModel( + tool_uuid="test-uuid-2", + name="Book Appointment", + description="Book an appointment with the service", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/appointments", + "parameters": [ + { + "name": "customer_name", + "type": "string", + "description": "Customer's full name", + "required": True, + }, + { + "name": "duration_minutes", + "type": "number", + "description": "Appointment duration in minutes", + "required": True, + }, + { + "name": "is_priority", + "type": "boolean", + "description": "Whether this is a priority appointment", + "required": False, + }, + ], + }, + }, + ) + + schema = tool_to_function_schema(tool) + + props = schema["function"]["parameters"]["properties"] + assert props["customer_name"]["type"] == "string" + assert props["duration_minutes"]["type"] == "number" + assert props["is_priority"]["type"] == "boolean" + + required = schema["function"]["parameters"]["required"] + assert "customer_name" in required + assert "duration_minutes" in required + assert "is_priority" not in required + + def test_tool_name_sanitization(self): + """Test that tool names with special characters are sanitized.""" + tool = MockToolModel( + tool_uuid="test-uuid-3", + name="Get User's Account Info!!!", + description="Get account information", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.example.com/account", + "parameters": [], + }, + }, + ) + + schema = tool_to_function_schema(tool) + + # Name should be lowercase with underscores only + assert schema["function"]["name"] == "get_user_s_account_info" + + def test_tool_with_no_parameters(self): + """Test converting a tool with no parameters.""" + tool = MockToolModel( + tool_uuid="test-uuid-4", + name="Ping Server", + description="Check if server is alive", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.example.com/ping", + }, + }, + ) + + schema = tool_to_function_schema(tool) + + assert schema["function"]["parameters"]["properties"] == {} + assert schema["function"]["parameters"]["required"] == [] + + def test_tool_without_description_uses_fallback(self): + """Test that tools without description use fallback.""" + tool = MockToolModel( + tool_uuid="test-uuid-5", + name="My Tool", + description=None, + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/tool", + }, + }, + ) + + schema = tool_to_function_schema(tool) + + assert schema["function"]["description"] == "Execute My Tool tool" + + +class TestExecuteHttpTool: + """Tests for execute_http_tool function.""" + + @pytest.mark.asyncio + async def test_post_request_sends_json_body(self): + """Test that POST requests send arguments as JSON body.""" + tool = MockToolModel( + tool_uuid="test-uuid", + name="Create User", + description="Create a new user", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/users", + "timeout_ms": 5000, + }, + }, + ) + + arguments = {"name": "John", "email": "john@example.com"} + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_response = Mock() + mock_response.status_code = 201 + mock_response.json.return_value = {"id": 123, "name": "John"} + mock_client.request.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await execute_http_tool(tool, arguments) + + # Verify request was made with JSON body + mock_client.request.assert_called_once() + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["method"] == "POST" + assert call_kwargs["url"] == "https://api.example.com/users" + assert call_kwargs["json"] == arguments + assert call_kwargs["params"] is None + + assert result["status"] == "success" + assert result["status_code"] == 201 + assert result["data"]["id"] == 123 + + @pytest.mark.asyncio + async def test_get_request_sends_query_params(self): + """Test that GET requests send arguments as query parameters.""" + tool = MockToolModel( + tool_uuid="test-uuid", + name="Search Users", + description="Search for users", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.example.com/users/search", + "timeout_ms": 5000, + }, + }, + ) + + arguments = {"query": "john", "limit": 10} + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"users": []} + mock_client.request.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await execute_http_tool(tool, arguments) + + # Verify request was made with query params + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["method"] == "GET" + assert call_kwargs["json"] is None + assert call_kwargs["params"] == arguments + + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_delete_request_sends_query_params(self): + """Test that DELETE requests send arguments as query parameters.""" + tool = MockToolModel( + tool_uuid="test-uuid", + name="Delete User", + description="Delete a user", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "DELETE", + "url": "https://api.example.com/users", + "timeout_ms": 5000, + }, + }, + ) + + arguments = {"user_id": "123"} + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_response = Mock() + mock_response.status_code = 204 + mock_response.json.return_value = {} + mock_client.request.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await execute_http_tool(tool, arguments) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["method"] == "DELETE" + assert call_kwargs["json"] is None + assert call_kwargs["params"] == arguments + + @pytest.mark.asyncio + async def test_timeout_error_handling(self): + """Test that timeout errors are handled gracefully.""" + import httpx + + tool = MockToolModel( + tool_uuid="test-uuid", + name="Slow API", + description="A slow API call", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/slow", + "timeout_ms": 1000, + }, + }, + ) + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.request.side_effect = httpx.TimeoutException( + "Request timed out" + ) + mock_client_class.return_value.__aenter__.return_value = mock_client + + result = await execute_http_tool(tool, {}) + + assert result["status"] == "error" + assert "timed out" in result["error"] + + @pytest.mark.asyncio + async def test_request_includes_custom_headers(self): + """Test that custom headers are included in the request.""" + tool = MockToolModel( + tool_uuid="test-uuid", + name="API with Headers", + description="API that requires headers", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/data", + "headers": { + "X-API-Key": "secret-key", + "X-Custom-Header": "custom-value", + }, + "timeout_ms": 5000, + }, + }, + ) + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + mock_client.request.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + await execute_http_tool(tool, {"data": "test"}) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["X-API-Key"] == "secret-key" + assert call_kwargs["headers"]["X-Custom-Header"] == "custom-value" + + @pytest.mark.asyncio + async def test_request_includes_auth_header_from_credential(self): + """Test that auth headers from credentials are included in the request.""" + tool = MockToolModel( + tool_uuid="test-uuid", + name="Authenticated API", + description="API that requires authentication", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/secure", + "credential_uuid": "cred-uuid-123", + "timeout_ms": 5000, + }, + }, + ) + + # Mock credential + mock_credential = Mock() + mock_credential.name = "API Token" + mock_credential.credential_type = "bearer_token" + mock_credential.credential_data = {"token": "my-secret-token"} + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + mock_client.request.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + with patch("api.services.workflow.tools.custom_tool.db_client") as mock_db: + mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential) + + await execute_http_tool(tool, {"data": "test"}, organization_id=1) + + # Verify credential was fetched + mock_db.get_credential_by_uuid.assert_called_once_with( + "cred-uuid-123", 1 + ) + + # Verify auth header was added + call_kwargs = mock_client.request.call_args.kwargs + assert ( + call_kwargs["headers"]["Authorization"] == "Bearer my-secret-token" + ) + + @pytest.mark.asyncio + async def test_no_credential_lookup_without_organization_id(self): + """Test that credential lookup is skipped without organization_id.""" + tool = MockToolModel( + tool_uuid="test-uuid", + name="API with Credential", + description="API with credential configured", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/secure", + "credential_uuid": "cred-uuid-123", + "timeout_ms": 5000, + }, + }, + ) + + with patch( + "api.services.workflow.tools.custom_tool.httpx.AsyncClient" + ) as mock_client_class: + mock_client = AsyncMock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + mock_client.request.return_value = mock_response + mock_client_class.return_value.__aenter__.return_value = mock_client + + with patch("api.services.workflow.tools.custom_tool.db_client") as mock_db: + # Call without organization_id + await execute_http_tool(tool, {"data": "test"}) + + # Verify credential lookup was NOT called + mock_db.get_credential_by_uuid.assert_not_called() + + +class TestAuthHeaders: + """Tests for auth header building utilities.""" + + def test_bearer_token_auth(self): + """Test building bearer token auth header.""" + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "bearer_token" + mock_credential.credential_data = {"token": "abc123"} + + header = build_auth_header(mock_credential) + + assert header == {"Authorization": "Bearer abc123"} + + def test_api_key_auth(self): + """Test building API key auth header.""" + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "api_key" + mock_credential.credential_data = { + "header_name": "X-API-Key", + "api_key": "secret-key-123", + } + + header = build_auth_header(mock_credential) + + assert header == {"X-API-Key": "secret-key-123"} + + def test_basic_auth(self): + """Test building basic auth header.""" + import base64 + + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "basic_auth" + mock_credential.credential_data = { + "username": "user", + "password": "pass123", + } + + header = build_auth_header(mock_credential) + + expected_encoded = base64.b64encode(b"user:pass123").decode() + assert header == {"Authorization": f"Basic {expected_encoded}"} + + def test_custom_header_auth(self): + """Test building custom header auth.""" + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "custom_header" + mock_credential.credential_data = { + "header_name": "X-Custom-Auth", + "header_value": "custom-value-123", + } + + header = build_auth_header(mock_credential) + + assert header == {"X-Custom-Auth": "custom-value-123"} + + def test_unknown_auth_type_returns_empty(self): + """Test that unknown auth types return empty dict.""" + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "unknown_type" + mock_credential.credential_data = {} + + header = build_auth_header(mock_credential) + + assert header == {} + + def test_none_credential_type_returns_empty(self): + """Test that 'none' credential type returns empty dict.""" + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "none" + mock_credential.credential_data = {} + + header = build_auth_header(mock_credential) + + assert header == {} + + def test_build_auth_header_from_data(self): + """Test building auth header from raw data.""" + from api.utils.credential_auth import build_auth_header_from_data + + header = build_auth_header_from_data( + credential_type="bearer_token", + credential_data={"token": "my-token"}, + ) + + assert header == {"Authorization": "Bearer my-token"} + + def test_api_key_default_header_name(self): + """Test that API key uses default header name if not specified.""" + from api.utils.credential_auth import build_auth_header + + mock_credential = Mock() + mock_credential.credential_type = "api_key" + mock_credential.credential_data = {"api_key": "key123"} + + header = build_auth_header(mock_credential) + + assert header == {"X-API-Key": "key123"} + + +class TestCustomToolManagerIntegration: + """Integration tests for CustomToolManager with MockLLMService.""" + + @pytest.mark.asyncio + async def test_llm_calls_custom_tool_handler(self): + """Test that when LLM makes a function call, the custom tool handler is executed.""" + # Create function call chunks that simulate LLM calling a custom tool + chunks = MockLLMService.create_function_call_chunks( + function_name="book_appointment", + arguments={"customer_name": "John Doe", "date": "2024-01-15"}, + tool_call_id="call_custom_123", + ) + + llm = MockLLMService(mock_chunks=chunks, chunk_delay=0.001) + + # Track if our handler was called + handler_called = False + received_arguments = None + + async def mock_book_appointment(params: FunctionCallParams): + nonlocal handler_called, received_arguments + handler_called = True + received_arguments = params.arguments + await params.result_callback({"status": "booked", "confirmation": "ABC123"}) + + # Register the function handler + llm.register_function("book_appointment", mock_book_appointment) + + # Create context and run + messages = [ + {"role": "user", "content": "Book an appointment for John Doe on Jan 15"} + ] + context = LLMContext(messages) + + pipeline = Pipeline([llm]) + frames_to_send = [LLMContextFrame(context)] + + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=[ + LLMFullResponseStartFrame, + FunctionCallsFromLLMInfoFrame, + FunctionCallsStartedFrame, + LLMFullResponseEndFrame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, + ], + ) + + # Verify handler was called with correct arguments + assert handler_called, "Custom tool handler should have been called" + assert received_arguments == {"customer_name": "John Doe", "date": "2024-01-15"} + + @pytest.mark.asyncio + async def test_multiple_custom_tools_can_be_registered(self): + """Test that multiple custom tools can be registered and called.""" + # Create chunks for calling multiple tools + functions = [ + { + "name": "get_weather", + "arguments": {"location": "NYC"}, + "tool_call_id": "call_weather", + }, + { + "name": "book_restaurant", + "arguments": {"restaurant": "Tavern", "party_size": 4}, + "tool_call_id": "call_restaurant", + }, + ] + chunks = MockLLMService.create_multiple_function_call_chunks(functions) + + llm = MockLLMService(mock_chunks=chunks, chunk_delay=0.001) + + # Track calls + calls_made = [] + + async def mock_get_weather(params: FunctionCallParams): + calls_made.append(("get_weather", params.arguments)) + await params.result_callback({"temp": 72, "condition": "sunny"}) + + async def mock_book_restaurant(params: FunctionCallParams): + calls_made.append(("book_restaurant", params.arguments)) + await params.result_callback({"confirmed": True}) + + llm.register_function("get_weather", mock_get_weather) + llm.register_function("book_restaurant", mock_book_restaurant) + + messages = [{"role": "user", "content": "Check weather and book restaurant"}] + context = LLMContext(messages) + + pipeline = Pipeline([llm]) + await run_test( + pipeline, + frames_to_send=[LLMContextFrame(context)], + expected_down_frames=None, + ) + + # Verify both handlers were called + assert len(calls_made) == 2 + tool_names = [call[0] for call in calls_made] + assert "get_weather" in tool_names + assert "book_restaurant" in tool_names + + +class TestCustomToolManagerUnit: + """Unit tests for CustomToolManager class.""" + + @pytest.mark.asyncio + async def test_get_tool_schemas_returns_correct_format(self): + """Test that get_tool_schemas returns FunctionSchema objects.""" + from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager + from pipecat.adapters.schemas.function_schema import FunctionSchema + + # Create a mock engine + mock_engine = Mock() + mock_engine._workflow_run_id = 1 + mock_engine._call_context_vars = {} + + manager = CustomToolManager(mock_engine) + + # Mock the database client + mock_tool = MockToolModel( + tool_uuid="uuid-1", + name="Test Tool", + description="A test tool", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/test", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Test param", + "required": True, + } + ], + }, + }, + ) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) + + schemas = await manager.get_tool_schemas(["uuid-1"]) + + assert len(schemas) == 1 + schema = schemas[0] + + # Schema should be a FunctionSchema object + assert isinstance(schema, FunctionSchema) + + # FunctionSchema should have correct attributes + assert schema.name == "test_tool" + assert "param1" in schema.properties + assert schema.properties["param1"]["type"] == "string" + assert "param1" in schema.required + + @pytest.mark.asyncio + async def test_register_handlers_creates_working_handler(self): + """Test that register_handlers creates handlers that can execute tools.""" + from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager + + # Create a mock engine with a mock LLM + mock_llm = Mock() + registered_handlers = {} + + def capture_register(name, handler, **kwargs): + registered_handlers[name] = handler + + mock_llm.register_function = capture_register + + mock_engine = Mock() + mock_engine._workflow_run_id = 1 + mock_engine._call_context_vars = {} + mock_engine.llm = mock_llm + + manager = CustomToolManager(mock_engine) + + mock_tool = MockToolModel( + tool_uuid="uuid-1", + name="API Call", + description="Make an API call", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/call", + "parameters": [], + }, + }, + ) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) + + await manager.register_handlers(["uuid-1"]) + + # Verify handler was registered + assert "api_call" in registered_handlers + + # Now test that the handler works + handler = registered_handlers["api_call"] + + result_received = None + + async def mock_result_callback(result, properties=None): + nonlocal result_received + result_received = result + + mock_params = Mock() + mock_params.arguments = {"key": "value"} + mock_params.result_callback = mock_result_callback + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.execute_http_tool" + ) as mock_execute: + mock_execute.return_value = { + "status": "success", + "data": {"response": "ok"}, + } + + await handler(mock_params) + + # Verify execute was called + mock_execute.assert_called_once() + + # Verify result was returned + assert result_received["status"] == "success" + + @pytest.mark.asyncio + async def test_tools_cache_prevents_duplicate_fetches(self): + """Test that tools are cached after first fetch.""" + from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager + + mock_engine = Mock() + mock_engine._workflow_run_id = 1 + mock_engine._call_context_vars = {} + mock_engine.llm = Mock() + mock_engine.llm.register_function = Mock() + + manager = CustomToolManager(mock_engine) + + mock_tool = MockToolModel( + tool_uuid="uuid-1", + name="Cached Tool", + description="A tool that should be cached", + definition={ + "schema_version": 1, + "type": "http_api", + "config": {"method": "GET", "url": "https://api.example.com"}, + }, + ) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) + + # First call should fetch from DB + await manager.get_tool_schemas(["uuid-1"]) + + # Verify tool is now in cache + cached = manager.get_cached_tool("cached_tool") + assert cached is not None + assert cached[0].tool_uuid == "uuid-1" + + # Clear cache and verify it's empty + manager.clear_cache() + cached = manager.get_cached_tool("cached_tool") + assert cached is None + + +class TestUpdateLLMContext: + """Tests for update_llm_context function.""" + + def test_replaces_system_message(self): + """Test that update_llm_context replaces existing system messages.""" + context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "Old system message"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + ) + + new_system = {"role": "system", "content": "New system message"} + update_llm_context(context, new_system, []) + + messages = context.messages + # Should have new system message at the start + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "New system message" + # Should preserve user and assistant messages + assert len(messages) == 3 + assert messages[1]["role"] == "user" + assert messages[2]["role"] == "assistant" + + def test_removes_multiple_old_system_messages(self): + """Test that all old system messages are removed.""" + context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "First system message"}, + {"role": "user", "content": "Question 1"}, + {"role": "system", "content": "Second system message"}, + {"role": "assistant", "content": "Answer 1"}, + ] + ) + + new_system = {"role": "system", "content": "Only system message"} + update_llm_context(context, new_system, []) + + messages = context.messages + # Should only have one system message now + system_messages = [m for m in messages if m["role"] == "system"] + assert len(system_messages) == 1 + assert system_messages[0]["content"] == "Only system message" + + def test_preserves_conversation_history(self): + """Test that user/assistant messages are preserved in order.""" + context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "Old prompt"}, + {"role": "user", "content": "First question"}, + {"role": "assistant", "content": "First answer"}, + {"role": "user", "content": "Second question"}, + {"role": "assistant", "content": "Second answer"}, + ] + ) + + new_system = {"role": "system", "content": "New prompt"} + update_llm_context(context, new_system, []) + + messages = context.messages + assert len(messages) == 5 + assert messages[1]["content"] == "First question" + assert messages[2]["content"] == "First answer" + assert messages[3]["content"] == "Second question" + assert messages[4]["content"] == "Second answer" + + def test_sets_tools_when_functions_provided(self): + """Test that tools are set on context when functions are provided.""" + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old"}]) + + # Create function schemas + functions = [ + get_function_schema("book_appointment", "Book an appointment"), + get_function_schema("cancel_appointment", "Cancel an appointment"), + ] + + new_system = {"role": "system", "content": "New prompt with tools"} + update_llm_context(context, new_system, functions) + + # Verify tools were set + tools = context.tools + assert tools is not None + assert len(tools.standard_tools) == 2 + + def test_does_not_set_tools_when_functions_empty(self): + """Test that tools are not set when functions list is empty.""" + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old"}]) + + new_system = {"role": "system", "content": "New prompt without tools"} + update_llm_context(context, new_system, []) + + # Tools should not be set (or remain None) + # Note: The function only calls set_tools if functions is truthy + # So we verify the context state is as expected + messages = context.messages + assert len(messages) == 1 + assert messages[0]["content"] == "New prompt without tools" + + def test_works_with_empty_context(self): + """Test that update works on a fresh context with no messages.""" + context = LLMContext() + + new_system = {"role": "system", "content": "Initial prompt"} + functions = [get_function_schema("test_func", "A test function")] + + update_llm_context(context, new_system, functions) + + messages = context.messages + assert len(messages) == 1 + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Initial prompt" + + def test_function_schema_structure(self): + """Test that get_function_schema creates correct structure.""" + schema = get_function_schema( + "search_products", + "Search for products in the catalog", + properties={ + "query": {"type": "string", "description": "Search query"}, + "limit": {"type": "integer", "description": "Max results"}, + }, + required=["query"], + ) + + assert schema.name == "search_products" + assert schema.description == "Search for products in the catalog" + assert "query" in schema.properties + assert "limit" in schema.properties + assert "query" in schema.required + assert "limit" not in schema.required + + def test_function_schema_with_no_parameters(self): + """Test get_function_schema with no properties or required.""" + schema = get_function_schema("ping", "Check if service is alive") + + assert schema.name == "ping" + assert schema.description == "Check if service is alive" + assert schema.properties == {} + assert schema.required == [] diff --git a/api/tests/test_custom_tools_context_integration.py b/api/tests/test_custom_tools_context_integration.py new file mode 100644 index 0000000..afdee4b --- /dev/null +++ b/api/tests/test_custom_tools_context_integration.py @@ -0,0 +1,512 @@ +"""Integration tests for CustomToolManager with update_llm_context. + +This module tests the full flow of: +1. CustomToolManager fetching and converting tool schemas +2. update_llm_context setting those tools on the LLM context +3. Verifying the context is properly configured for LLM generation +""" + +from dataclasses import dataclass +from typing import Any, Dict +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager +from api.services.workflow.pipecat_engine_utils import ( + get_function_schema, + update_llm_context, +) +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.processors.aggregators.llm_context import LLMContext + + +@dataclass +class MockToolModel: + """Mock tool model for testing.""" + + tool_uuid: str + name: str + description: str + definition: Dict[str, Any] + + +class TestCustomToolManagerContextIntegration: + """Integration tests for CustomToolManager with LLMContext.""" + + @pytest.fixture + def mock_engine(self): + """Create a mock PipecatEngine.""" + engine = Mock() + engine._workflow_run_id = 1 + engine._call_context_vars = {"customer_name": "John Doe"} + engine.llm = Mock() + engine.llm.register_function = Mock() + return engine + + @pytest.fixture + def sample_tools(self): + """Create sample mock tools for testing.""" + return [ + MockToolModel( + tool_uuid="weather-uuid-123", + name="Get Weather", + description="Get current weather for a location", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.weather.com/current", + "parameters": [ + { + "name": "location", + "type": "string", + "description": "City name (e.g., San Francisco, CA)", + "required": True, + }, + { + "name": "units", + "type": "string", + "description": "Temperature units: celsius or fahrenheit", + "required": False, + }, + ], + }, + }, + ), + MockToolModel( + tool_uuid="booking-uuid-456", + name="Book Appointment", + description="Book an appointment for the customer", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/appointments", + "parameters": [ + { + "name": "customer_name", + "type": "string", + "description": "Customer's full name", + "required": True, + }, + { + "name": "date", + "type": "string", + "description": "Appointment date (YYYY-MM-DD)", + "required": True, + }, + { + "name": "time", + "type": "string", + "description": "Appointment time (HH:MM)", + "required": True, + }, + { + "name": "notes", + "type": "string", + "description": "Additional notes", + "required": False, + }, + ], + }, + }, + ), + MockToolModel( + tool_uuid="lookup-uuid-789", + name="Customer Lookup", + description="Look up customer information by phone number", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.example.com/customers/lookup", + "parameters": [ + { + "name": "phone", + "type": "string", + "description": "Customer phone number", + "required": True, + }, + ], + }, + }, + ), + ] + + @pytest.mark.asyncio + async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools): + """Test fetching tool schemas via CustomToolManager and updating LLM context.""" + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools) + + # Get tool schemas via CustomToolManager - now returns FunctionSchema objects + tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"] + schemas = await manager.get_tool_schemas(tool_uuids) + + # Verify schemas were returned as FunctionSchema objects + assert len(schemas) == 3 + assert all(isinstance(s, FunctionSchema) for s in schemas) + + # Create context with conversation history + context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "I need to check the weather and book an appointment.", + }, + { + "role": "assistant", + "content": "I can help with both. Where would you like to check the weather?", + }, + {"role": "user", "content": "San Francisco"}, + ] + ) + + # Update context with new system message and tools + # Now we can pass schemas directly since they're FunctionSchema objects + new_system = { + "role": "system", + "content": "You are a scheduling assistant with access to weather and booking tools.", + } + update_llm_context(context, new_system, schemas) + + # Verify context was updated correctly + messages = context.messages + assert len(messages) == 4 + assert ( + messages[0]["content"] + == "You are a scheduling assistant with access to weather and booking tools." + ) + assert messages[1]["role"] == "user" + assert messages[3]["content"] == "San Francisco" + + # Verify tools were set + tools = context.tools + assert tools is not None + assert len(tools.standard_tools) == 3 + + # Verify tool names + tool_names = {t.name for t in tools.standard_tools} + assert tool_names == { + "get_weather", + "book_appointment", + "customer_lookup", + } + + @pytest.mark.asyncio + async def test_tool_schemas_have_correct_properties( + self, mock_engine, sample_tools + ): + """Test that tool schemas from CustomToolManager have correct parameter properties.""" + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools) + + schemas = await manager.get_tool_schemas( + ["weather-uuid-123", "booking-uuid-456"] + ) + + # Find the booking schema - now using FunctionSchema attributes + booking_schema = next( + s for s in schemas if s.name == "book_appointment" + ) + + # Verify parameter properties + assert "customer_name" in booking_schema.properties + assert "date" in booking_schema.properties + assert "time" in booking_schema.properties + assert "notes" in booking_schema.properties + + # Verify types + assert booking_schema.properties["customer_name"]["type"] == "string" + assert booking_schema.properties["date"]["type"] == "string" + + # Verify required + assert "customer_name" in booking_schema.required + assert "date" in booking_schema.required + assert "time" in booking_schema.required + assert "notes" not in booking_schema.required + + @pytest.mark.asyncio + async def test_context_update_with_builtin_and_custom_tools( + self, mock_engine, sample_tools + ): + """Test updating context with both built-in and custom tools.""" + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock( + return_value=[sample_tools[0]] + ) # Just weather + + # Get custom tool schemas - returns FunctionSchema objects + custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"]) + + # Create built-in function schemas (like calculator, timezone) + builtin_functions = [ + get_function_schema( + "safe_calculator", + "Evaluate a mathematical expression safely", + properties={ + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate", + } + }, + required=["expression"], + ), + get_function_schema( + "get_current_time", + "Get the current time in a timezone", + properties={ + "timezone": { + "type": "string", + "description": "Timezone name (e.g., America/New_York)", + } + }, + required=["timezone"], + ), + ] + + # Combine built-in and custom functions - both are FunctionSchema objects + all_functions = builtin_functions + custom_schemas + + # Update context + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old prompt"}]) + + new_system = { + "role": "system", + "content": "Assistant with calculator and weather tools", + } + update_llm_context(context, new_system, all_functions) + + # Verify all tools are present + tools = context.tools + assert len(tools.standard_tools) == 3 + + tool_names = {t.name for t in tools.standard_tools} + assert "safe_calculator" in tool_names + assert "get_current_time" in tool_names + assert "get_weather" in tool_names + + @pytest.mark.asyncio + async def test_tools_cached_after_first_fetch(self, mock_engine, sample_tools): + """Test that CustomToolManager caches tools after first fetch.""" + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]]) + + # First fetch + await manager.get_tool_schemas(["weather-uuid-123"]) + + # Verify tool is cached (cache stores raw schema dict, not FunctionSchema) + cached = manager.get_cached_tool("get_weather") + assert cached is not None + tool, raw_schema = cached + assert tool.tool_uuid == "weather-uuid-123" + assert raw_schema["function"]["name"] == "get_weather" + + @pytest.mark.asyncio + async def test_context_preserves_function_call_history( + self, mock_engine, sample_tools + ): + """Test that update_llm_context preserves function call messages in history.""" + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]]) + + # Get schemas - returns FunctionSchema objects + schemas = await manager.get_tool_schemas(["weather-uuid-123"]) + + # Create context with function call history + context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "Old system prompt"}, + {"role": "user", "content": "What's the weather in NYC?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "New York, NY"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": '{"temperature": 72, "condition": "sunny"}', + }, + { + "role": "assistant", + "content": "The weather in NYC is 72°F and sunny!", + }, + ] + ) + + new_system = {"role": "system", "content": "Updated weather assistant"} + update_llm_context(context, new_system, schemas) + + messages = context.messages + # System + user + assistant(tool_call) + tool + assistant = 5 + assert len(messages) == 5 + + # Verify function call messages are preserved + tool_call_msg = messages[2] + assert tool_call_msg["role"] == "assistant" + assert "tool_calls" in tool_call_msg + + tool_result_msg = messages[3] + assert tool_result_msg["role"] == "tool" + assert tool_result_msg["tool_call_id"] == "call_123" + + @pytest.mark.asyncio + async def test_empty_tool_list_does_not_set_tools(self, mock_engine): + """Test that empty tool list doesn't set tools on context.""" + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[]) + + schemas = await manager.get_tool_schemas([]) + assert schemas == [] + + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old"}]) + + new_system = {"role": "system", "content": "No tools available"} + update_llm_context(context, new_system, []) + + # Context should have updated message but no tools set + assert context.messages[0]["content"] == "No tools available" + + @pytest.mark.asyncio + async def test_numeric_and_boolean_parameter_types(self, mock_engine): + """Test that numeric and boolean parameter types are correctly handled.""" + tool_with_types = MockToolModel( + tool_uuid="order-uuid", + name="Place Order", + description="Place an order for items", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/orders", + "parameters": [ + { + "name": "item_id", + "type": "string", + "description": "Item identifier", + "required": True, + }, + { + "name": "quantity", + "type": "number", + "description": "Number of items", + "required": True, + }, + { + "name": "express_shipping", + "type": "boolean", + "description": "Use express shipping", + "required": False, + }, + ], + }, + }, + ) + + manager = CustomToolManager(mock_engine) + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" + ) as mock_get_org: + mock_get_org.return_value = 1 + + with patch( + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types]) + + # Get schemas - returns FunctionSchema objects + schemas = await manager.get_tool_schemas(["order-uuid"]) + schema = schemas[0] + + # Verify types using FunctionSchema attributes + assert schema.properties["item_id"]["type"] == "string" + assert schema.properties["quantity"]["type"] == "number" + assert schema.properties["express_shipping"]["type"] == "boolean" + + # Update context - pass schema directly + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old"}]) + update_llm_context( + context, {"role": "system", "content": "Order assistant"}, schemas + ) + + # Verify tool was set with correct types + tool = context.tools.standard_tools[0] + assert tool.name == "place_order" + assert tool.properties["quantity"]["type"] == "number" + assert tool.properties["express_shipping"]["type"] == "boolean" diff --git a/api/tests/test_default_user_configuration.py b/api/tests/test_default_user_configuration.py deleted file mode 100644 index f264a96..0000000 --- a/api/tests/test_default_user_configuration.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -import uuid - -import pytest - -from api.db.user_client import UserClient -from api.services.configuration.registry import ServiceProviders - - -@pytest.mark.asyncio -async def test_default_configuration_created(db_session): - # Set env variable for openai to simulate availability of default key - os.environ["OPENAI_API_KEY"] = "sk-test-openai-key" - - # Ensure deepgram env variable absent to focus test - os.environ.pop("DEEPGRAM_API_KEY", None) - - # Generate a unique (random) provider user ID for each test run - test_provider_user_id = f"provider_user_{uuid.uuid4().hex}" - user_client: UserClient = db_session # db_session fixture yields the client - - user_model = await user_client.get_or_create_user_by_provider_id( - test_provider_user_id - ) - - config = await user_client.get_user_configurations(user_model.id) - - assert config.llm is not None, "LLM config should be created when env key present" - assert config.llm.provider == ServiceProviders.OPENAI - assert config.llm.api_key == "sk-test-openai-key" - - # Cleanup / restore env variable side-effects - os.environ.pop("OPENAI_API_KEY", None) diff --git a/api/tests/test_disposition_mapper.py b/api/tests/test_disposition_mapper.py deleted file mode 100644 index e674161..0000000 --- a/api/tests/test_disposition_mapper.py +++ /dev/null @@ -1,122 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from api.services.workflow.disposition_mapper import ( - apply_disposition_mapping, - get_organization_id_from_workflow_run, -) - - -@pytest.mark.asyncio -async def test_apply_disposition_mapping_with_valid_mapping(): - """Test disposition mapping with valid configuration.""" - with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client: - # Mock disposition mapping configuration - mock_db_client.get_configuration_value = AsyncMock( - return_value={ - "XFER": "TRANSFERRED", - "ND": "NOT_QUALIFIED", - "user_hangup": "HANGUP", - } - ) - - # Test mapping exists - result = await apply_disposition_mapping("XFER", 1) - assert result == "TRANSFERRED" - - # Test mapping doesn't exist - result = await apply_disposition_mapping("UNKNOWN", 1) - assert result == "UNKNOWN" - - # Verify db_client was called correctly - mock_db_client.get_configuration_value.assert_called_with( - 1, "DISPOSITION_CODE_MAPPING", default={} - ) - - -@pytest.mark.asyncio -async def test_apply_disposition_mapping_no_organization_id(): - """Test disposition mapping with no organization ID.""" - # Should return original value - result = await apply_disposition_mapping("XFER", None) - assert result == "XFER" - - -@pytest.mark.asyncio -async def test_apply_disposition_mapping_empty_value(): - """Test disposition mapping with empty value.""" - # Should return original empty value - result = await apply_disposition_mapping("", 1) - assert result == "" - - -@pytest.mark.asyncio -async def test_apply_disposition_mapping_error_handling(): - """Test disposition mapping handles errors gracefully.""" - with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client: - # Mock database error - mock_db_client.get_configuration_value = AsyncMock( - side_effect=Exception("Database error") - ) - - # Should return original value on error - result = await apply_disposition_mapping("XFER", 1) - assert result == "XFER" - - -@pytest.mark.asyncio -async def test_get_organization_id_from_workflow_run(): - """Test getting organization ID from workflow run ID.""" - with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client: - # Mock workflow run with organization - mock_workflow_run = MagicMock() - mock_workflow_run.workflow.user.selected_organization_id = 123 - mock_db_client.get_workflow_run_by_id = AsyncMock( - return_value=mock_workflow_run - ) - - result = await get_organization_id_from_workflow_run(1) - assert result == 123 - - # Verify db_client was called correctly - mock_db_client.get_workflow_run_by_id.assert_called_once_with(1) - - -@pytest.mark.asyncio -async def test_get_organization_id_no_workflow_run(): - """Test getting organization ID when workflow run doesn't exist.""" - with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client: - # Mock no workflow run found - mock_db_client.get_workflow_run_by_id = AsyncMock(return_value=None) - - result = await get_organization_id_from_workflow_run(1) - assert result is None - - -@pytest.mark.asyncio -async def test_get_organization_id_no_user(): - """Test getting organization ID when workflow has no user.""" - with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client: - # Mock workflow run with no user - mock_workflow_run = MagicMock() - mock_workflow_run.workflow.user = None - mock_db_client.get_workflow_run_by_id = AsyncMock( - return_value=mock_workflow_run - ) - - result = await get_organization_id_from_workflow_run(1) - assert result is None - - -@pytest.mark.asyncio -async def test_get_organization_id_error_handling(): - """Test getting organization ID handles errors gracefully.""" - with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client: - # Mock database error - mock_db_client.get_workflow_run_by_id = AsyncMock( - side_effect=Exception("Database error") - ) - - result = await get_organization_id_from_workflow_run(1) - assert result is None diff --git a/api/tests/test_event_handler_disposition_mapping.py b/api/tests/test_event_handler_disposition_mapping.py deleted file mode 100644 index cc6cea3..0000000 --- a/api/tests/test_event_handler_disposition_mapping.py +++ /dev/null @@ -1,370 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from pipecat.utils.enums import EndTaskReason - -from api.services.pipecat.event_handlers import register_transport_event_handlers - - -@pytest.fixture -def mock_dependencies(): - """Create mock dependencies for event handlers.""" - # Store registered handlers - registered_handlers = {} - - def mock_event_handler(event_name): - def decorator(func): - registered_handlers[event_name] = func - return func - - return decorator - - mock_transport = MagicMock() - mock_transport.event_handler = mock_event_handler - - mock_task = MagicMock() - mock_task.cancel = AsyncMock() - - mock_engine = MagicMock() - mock_engine.initialize = AsyncMock() - mock_engine.cleanup = AsyncMock() - - mock_audio_buffer = MagicMock() - mock_audio_buffer.start_recording = AsyncMock() - mock_audio_buffer.stop_recording = AsyncMock() - - mock_usage_metrics_aggregator = MagicMock() - mock_usage_metrics_aggregator.get_all_usage_metrics_serialized = MagicMock( - return_value={"test": "metrics"} - ) - - return { - "transport": mock_transport, - "workflow_run_id": 123, - "audio_buffer": mock_audio_buffer, - "task": mock_task, - "engine": mock_engine, - "usage_metrics_aggregator": mock_usage_metrics_aggregator, - "audio_synchronizer": None, - "registered_handlers": registered_handlers, - } - - -@pytest.mark.asyncio -async def test_transport_disconnect_reason_mapping(mock_dependencies): - """Test that transport_disconnect_reason is mapped when no engine disconnect reason exists.""" - # Register event handlers - register_transport_event_handlers( - transport=mock_dependencies["transport"], - workflow_run_id=mock_dependencies["workflow_run_id"], - audio_buffer=mock_dependencies["audio_buffer"], - task=mock_dependencies["task"], - engine=mock_dependencies["engine"], - usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"], - audio_synchronizer=mock_dependencies["audio_synchronizer"], - ) - - # Get the on_client_disconnected handler - handler = mock_dependencies["registered_handlers"]["on_client_disconnected"] - - # Mock engine with no call disposition - mock_dependencies["engine"].get_call_disposition.return_value = None - mock_dependencies["engine"].get_gathered_context.return_value = { - "agent_name": "Alex" - } - - # Mock the disposition mapper functions - with patch( - "api.services.pipecat.event_handlers.get_organization_id_from_workflow_run", - new_callable=AsyncMock, - ) as mock_get_org_id: - with patch( - "api.services.pipecat.event_handlers.apply_disposition_mapping", - new_callable=AsyncMock, - ) as mock_apply_mapping: - with patch( - "api.services.pipecat.event_handlers.db_client" - ) as mock_db_client: - with patch( - "api.services.pipecat.event_handlers.enqueue_job" - ) as mock_enqueue: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock call duration for user_hangup logic - mock_dependencies[ - "usage_metrics_aggregator" - ].get_call_duration.return_value = 15 - - # Mock disposition mapping - async def apply_mapping_side_effect(value, org_id): - return { - "NIBP": "NOT_INTERESTED_BUSINESS_PURPOSE", - "user_qualified": "QUALIFIED", - }.get(value, value) - - mock_apply_mapping.side_effect = apply_mapping_side_effect - - # Mock database operations - mock_workflow_run = MagicMock() - mock_workflow_run.id = 123 - mock_workflow_run.workflow_id = 1 - mock_workflow_run.organization_id = 1 - mock_workflow_run.gathered_context = {} - mock_db_client.get_workflow_run_by_id = AsyncMock( - return_value=mock_workflow_run - ) - mock_db_client.update_workflow_run = AsyncMock() - - # Call handler with transport_disconnect_reason - await handler( - mock_dependencies["transport"], - participant=None, - transport_disconnect_reason="user_hangup", - ) - - # Verify disposition mapping was applied with NIBP (since duration > 10) - mock_apply_mapping.assert_called_once_with("NIBP", 1) - - # Verify database was updated with mapped value - mock_db_client.update_workflow_run.assert_called_once() - call_args = mock_db_client.update_workflow_run.call_args - assert ( - call_args[1]["gathered_context"]["mapped_call_disposition"] - == "NOT_INTERESTED_BUSINESS_PURPOSE" - ) - - # Verify task was cancelled (no engine disconnect reason) - mock_dependencies["task"].cancel.assert_called_once() - - -@pytest.mark.asyncio -async def test_transport_disconnect_reason_user_hangup_short_call(mock_dependencies): - """Test that user_hangup with short call duration is mapped to HU.""" - # Register event handlers - register_transport_event_handlers( - transport=mock_dependencies["transport"], - workflow_run_id=mock_dependencies["workflow_run_id"], - audio_buffer=mock_dependencies["audio_buffer"], - task=mock_dependencies["task"], - engine=mock_dependencies["engine"], - usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"], - audio_synchronizer=mock_dependencies["audio_synchronizer"], - ) - - # Get the on_client_disconnected handler - handler = mock_dependencies["registered_handlers"]["on_client_disconnected"] - - # Mock engine with no call disposition - mock_dependencies["engine"].get_call_disposition.return_value = None - mock_dependencies["engine"].get_gathered_context.return_value = { - "agent_name": "Alex" - } - - # Mock the disposition mapper functions - with patch( - "api.services.pipecat.event_handlers.get_organization_id_from_workflow_run", - new_callable=AsyncMock, - ) as mock_get_org_id: - with patch( - "api.services.pipecat.event_handlers.apply_disposition_mapping", - new_callable=AsyncMock, - ) as mock_apply_mapping: - with patch( - "api.services.pipecat.event_handlers.db_client" - ) as mock_db_client: - with patch( - "api.services.pipecat.event_handlers.enqueue_job" - ) as mock_enqueue: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock call duration for user_hangup logic (< 10 seconds) - mock_dependencies[ - "usage_metrics_aggregator" - ].get_call_duration.return_value = 5 - - # Mock disposition mapping - mock_apply_mapping.return_value = "HANGUP" - - # Mock database operations - mock_workflow_run = MagicMock() - mock_workflow_run.id = 123 - mock_workflow_run.workflow_id = 1 - mock_workflow_run.organization_id = 1 - mock_workflow_run.gathered_context = {} - mock_db_client.get_workflow_run_by_id = AsyncMock( - return_value=mock_workflow_run - ) - mock_db_client.update_workflow_run = AsyncMock() - - # Call handler with transport_disconnect_reason - await handler( - mock_dependencies["transport"], - participant=None, - transport_disconnect_reason="user_hangup", - ) - - # Verify disposition mapping was applied with HU (since duration < 10) - mock_apply_mapping.assert_called_once_with("HU", 1) - - # Verify database was updated with mapped value - mock_db_client.update_workflow_run.assert_called_once() - call_args = mock_db_client.update_workflow_run.call_args - assert ( - call_args[1]["gathered_context"]["mapped_call_disposition"] - == "HANGUP" - ) - - # Verify task was cancelled (no engine disconnect reason) - mock_dependencies["task"].cancel.assert_called_once() - - -@pytest.mark.asyncio -async def test_engine_disconnect_reason_takes_precedence(mock_dependencies): - """Test that engine disconnect reason takes precedence and is not mapped.""" - # Register event handlers - register_transport_event_handlers( - transport=mock_dependencies["transport"], - workflow_run_id=mock_dependencies["workflow_run_id"], - audio_buffer=mock_dependencies["audio_buffer"], - task=mock_dependencies["task"], - engine=mock_dependencies["engine"], - usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"], - audio_synchronizer=mock_dependencies["audio_synchronizer"], - ) - - # Get the on_client_disconnected handler - handler = mock_dependencies["registered_handlers"]["on_client_disconnected"] - - # Mock engine with call disposition - mock_dependencies["engine"].get_call_disposition.return_value = "user_qualified" - mock_dependencies["engine"].get_gathered_context.return_value = { - "agent_name": "Alex" - } - - # Mock the disposition mapper functions - with patch( - "api.services.pipecat.event_handlers.get_organization_id_from_workflow_run", - new_callable=AsyncMock, - ) as mock_get_org_id: - with patch( - "api.services.pipecat.event_handlers.apply_disposition_mapping", - new_callable=AsyncMock, - ) as mock_apply_mapping: - with patch( - "api.services.pipecat.event_handlers.db_client" - ) as mock_db_client: - with patch( - "api.services.pipecat.event_handlers.enqueue_job" - ) as mock_enqueue: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock disposition mapping for engine's reason - mock_apply_mapping.return_value = "QUALIFIED" - - # Mock database operations - mock_workflow_run = MagicMock() - mock_workflow_run.id = 123 - mock_workflow_run.workflow_id = 1 - mock_workflow_run.organization_id = 1 - mock_workflow_run.gathered_context = {} - mock_db_client.get_workflow_run_by_id = AsyncMock( - return_value=mock_workflow_run - ) - mock_db_client.update_workflow_run = AsyncMock() - - # Call handler with transport_disconnect_reason - await handler( - mock_dependencies["transport"], - participant=None, - transport_disconnect_reason="user_hangup", - ) - - # Verify disposition mapping was called with engine's reason - mock_apply_mapping.assert_called_once_with("user_qualified", 1) - - # Verify database was updated with mapped value - mock_db_client.update_workflow_run.assert_called_once() - call_args = mock_db_client.update_workflow_run.call_args - assert ( - call_args[1]["gathered_context"]["mapped_call_disposition"] - == "QUALIFIED" - ) - - # Verify task was NOT cancelled (engine disconnect reason exists) - mock_dependencies["task"].cancel.assert_not_called() - - -@pytest.mark.asyncio -async def test_no_disconnect_reason_uses_unknown(mock_dependencies): - """Test that when no disconnect reason is provided, UNKNOWN is used.""" - # Register event handlers - register_transport_event_handlers( - transport=mock_dependencies["transport"], - workflow_run_id=mock_dependencies["workflow_run_id"], - audio_buffer=mock_dependencies["audio_buffer"], - task=mock_dependencies["task"], - engine=mock_dependencies["engine"], - usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"], - audio_synchronizer=mock_dependencies["audio_synchronizer"], - ) - - # Get the on_client_disconnected handler - handler = mock_dependencies["registered_handlers"]["on_client_disconnected"] - - # Mock engine with no call disposition - mock_dependencies["engine"].get_call_disposition.return_value = None - mock_dependencies["engine"].get_gathered_context.return_value = { - "agent_name": "Alex" - } - - with patch( - "api.services.pipecat.event_handlers.get_organization_id_from_workflow_run" - ) as mock_get_org_id: - with patch( - "api.services.pipecat.event_handlers.apply_disposition_mapping" - ) as mock_apply_mapping: - with patch( - "api.services.pipecat.event_handlers.db_client" - ) as mock_db_client: - with patch( - "api.services.pipecat.event_handlers.enqueue_job" - ) as mock_enqueue: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock disposition mapping - should return UNKNOWN as-is - mock_apply_mapping.return_value = EndTaskReason.UNKNOWN.value - - # Mock database operations - mock_workflow_run = MagicMock() - mock_workflow_run.id = 123 - mock_workflow_run.workflow_id = 1 - mock_workflow_run.organization_id = 1 - mock_workflow_run.gathered_context = {} - mock_db_client.get_workflow_run_by_id = AsyncMock( - return_value=mock_workflow_run - ) - mock_db_client.update_workflow_run = AsyncMock() - - # Call handler without transport_disconnect_reason - await handler( - mock_dependencies["transport"], - participant=None, - transport_disconnect_reason=None, - ) - - # Verify disposition mapping was called with UNKNOWN - mock_apply_mapping.assert_called_once_with( - EndTaskReason.UNKNOWN.value, 1 - ) - - # Verify database was updated with UNKNOWN - mock_db_client.update_workflow_run.assert_called_once() - call_args = mock_db_client.update_workflow_run.call_args - assert ( - call_args[1]["gathered_context"]["mapped_call_disposition"] - == EndTaskReason.UNKNOWN.value - ) diff --git a/api/tests/test_event_handlers_refactor.py b/api/tests/test_event_handlers_refactor.py deleted file mode 100644 index 94fec0b..0000000 --- a/api/tests/test_event_handlers_refactor.py +++ /dev/null @@ -1,184 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.event_handlers import ( - register_audio_data_handler, - register_transcript_handler, - register_transport_event_handlers, -) - - -@pytest.mark.asyncio -async def test_transport_handlers_with_in_memory_buffers(): - """Test that transport handlers create and return in-memory buffers.""" - # Mock dependencies - transport = MagicMock() - transport.event_handler = lambda event_name: lambda func: func - - audio_buffer = AsyncMock() - audio_synchronizer = AsyncMock() - task = AsyncMock() - engine = AsyncMock() - engine.get_call_disposition.return_value = None - engine.get_gathered_context.return_value = {} - - usage_metrics_aggregator = AsyncMock() - usage_metrics_aggregator.get_call_duration.return_value = 30 - usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {} - - # Create test audio config - audio_config = AudioConfig( - transport_in_sample_rate=16000, - transport_out_sample_rate=16000, - pipeline_sample_rate=16000, - ) - - # Register handlers - audio_buf, transcript_buf = register_transport_event_handlers( - transport=transport, - workflow_run_id=123, - audio_buffer=audio_buffer, - task=task, - engine=engine, - usage_metrics_aggregator=usage_metrics_aggregator, - audio_synchronizer=audio_synchronizer, - audio_config=audio_config, - ) - - # Verify buffers were created with correct configuration - assert audio_buf is not None - assert transcript_buf is not None - assert audio_buf._workflow_run_id == 123 - assert audio_buf._sample_rate == 16000 - assert audio_buf._num_channels == 1 - assert transcript_buf._workflow_run_id == 123 - - -@pytest.mark.asyncio -async def test_audio_handler_with_in_memory_buffer(): - """Test audio handler uses in-memory buffer when provided.""" - # Mock audio synchronizer - audio_synchronizer = MagicMock() - handlers = {} - - def mock_event_handler(event_name): - def decorator(func): - handlers[event_name] = func - return func - - return decorator - - audio_synchronizer.event_handler = mock_event_handler - - # Mock in-memory buffer - in_memory_buffer = AsyncMock() - - # Register handler with buffer - register_audio_data_handler( - audio_synchronizer, workflow_run_id=123, in_memory_buffer=in_memory_buffer - ) - - # Test the handler - assert "on_merged_audio" in handlers - handler = handlers["on_merged_audio"] - - # Call handler with test data - test_pcm = b"test_audio_data" - await handler(None, test_pcm, 16000, 1) - - # Verify buffer was used - in_memory_buffer.append.assert_called_once_with(test_pcm) - - -@pytest.mark.asyncio -async def test_transcript_handler_with_in_memory_buffer(): - """Test transcript handler uses in-memory buffer when provided.""" - # Mock transcript processor - transcript = MagicMock() - handlers = {} - - def mock_event_handler(event_name): - def decorator(func): - handlers[event_name] = func - return func - - return decorator - - transcript.event_handler = mock_event_handler - - # Mock in-memory buffer - in_memory_buffer = AsyncMock() - - # Register handler with buffer - register_transcript_handler( - transcript, workflow_run_id=456, in_memory_buffer=in_memory_buffer - ) - - # Create test frame - test_frame = MagicMock() - test_frame.messages = [ - MagicMock(timestamp="00:00:01", role="user", content="Hello"), - MagicMock(timestamp="00:00:02", role="assistant", content="Hi there"), - ] - - # Test the handler - handler = handlers["on_transcript_update"] - await handler(None, test_frame) - - # Verify buffer was used with correct format - expected_text = "[00:00:01] user: Hello\n[00:00:02] assistant: Hi there\n" - in_memory_buffer.append.assert_called_once_with(expected_text) - - -@pytest.mark.asyncio -async def test_audio_config_sample_rates(): - """Test that different audio configs result in correct sample rates.""" - # Mock dependencies - transport = MagicMock() - transport.event_handler = lambda event_name: lambda func: func - - audio_buffer = AsyncMock() - audio_synchronizer = AsyncMock() - task = AsyncMock() - engine = AsyncMock() - engine.get_call_disposition.return_value = None - engine.get_gathered_context.return_value = {} - - usage_metrics_aggregator = AsyncMock() - usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {} - - # Test with 8kHz audio config (e.g., for Stasis/Twilio) - audio_config_8k = AudioConfig( - transport_in_sample_rate=8000, - transport_out_sample_rate=8000, - pipeline_sample_rate=8000, - ) - - audio_buf_8k, _ = register_transport_event_handlers( - transport=transport, - workflow_run_id=456, - audio_buffer=audio_buffer, - task=task, - engine=engine, - usage_metrics_aggregator=usage_metrics_aggregator, - audio_synchronizer=audio_synchronizer, - audio_config=audio_config_8k, - ) - - assert audio_buf_8k._sample_rate == 8000 - - # Test with no audio config (should default to 16kHz) - audio_buf_default, _ = register_transport_event_handlers( - transport=transport, - workflow_run_id=789, - audio_buffer=audio_buffer, - task=task, - engine=engine, - usage_metrics_aggregator=usage_metrics_aggregator, - audio_synchronizer=audio_synchronizer, - audio_config=None, - ) - - assert audio_buf_default._sample_rate == 16000 diff --git a/api/tests/test_filters.py b/api/tests/test_filters.py deleted file mode 100644 index feb6e73..0000000 --- a/api/tests/test_filters.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Test filter functionality.""" - -from unittest.mock import MagicMock - -from api.db.filters import ATTRIBUTE_FIELD_MAPPING, apply_workflow_run_filters - - -def test_attribute_field_mapping(): - """Test that all required attributes are mapped.""" - expected_attributes = [ - "dateRange", - "dispositionCode", - "duration", - "status", - "tokenUsage", - "runId", - "workflowId", - "callTags", - "phoneNumber", - ] - - for attr in expected_attributes: - assert attr in ATTRIBUTE_FIELD_MAPPING, f"Missing mapping for {attr}" - - -def test_filter_with_explicit_type(): - """Test that filters work with explicit type from UI.""" - - # Mock query - mock_query = MagicMock() - mock_query.where = MagicMock(return_value=mock_query) - - test_cases = [ - # Date range filter - { - "filters": [ - { - "attribute": "dateRange", - "type": "dateRange", - "value": {"from": "2024-01-01", "to": "2024-01-31"}, - } - ], - }, - # Multi-select filter - { - "filters": [ - { - "attribute": "dispositionCode", - "type": "multiSelect", - "value": {"codes": ["XFER", "HU"]}, - } - ], - }, - # Number range filter - { - "filters": [ - { - "attribute": "duration", - "type": "numberRange", - "value": {"min": 60, "max": 300}, - } - ], - }, - # Radio/status filter - { - "filters": [ - { - "attribute": "status", - "type": "radio", - "value": {"status": "completed"}, - } - ], - }, - # Number filter - { - "filters": [ - {"attribute": "runId", "type": "number", "value": {"value": 123}} - ], - }, - # Text filter - { - "filters": [ - { - "attribute": "phoneNumber", - "type": "text", - "value": {"value": "+1234567890"}, - } - ], - }, - # Tags filter - { - "filters": [ - { - "attribute": "callTags", - "type": "tags", - "value": {"codes": ["tag1", "tag2"]}, - } - ], - }, - ] - - for test_case in test_cases: - result = apply_workflow_run_filters(mock_query, test_case["filters"]) - # The function should process the filter without errors - assert result is not None - - -def test_filter_format_with_type(): - """Test that filters work with attribute, type, and value.""" - - mock_query = MagicMock() - mock_query.where = MagicMock(return_value=mock_query) - - # Test with various filter combinations - filters = [ - { - "attribute": "dispositionCode", - "type": "multiSelect", - "value": {"codes": ["NIBP"]}, - }, - { - "attribute": "duration", - "type": "numberRange", - "value": {"min": 0, "max": 60}, - }, - {"attribute": "phoneNumber", "type": "text", "value": {"value": "555"}}, - ] - - result = apply_workflow_run_filters(mock_query, filters) - - # Should have called where() for applying filters - assert mock_query.where.called - assert result is not None - - -def test_unknown_attribute_ignored(): - """Test that unknown attributes are safely ignored.""" - - mock_query = MagicMock() - mock_query.where = MagicMock(return_value=mock_query) - - filters = [ - {"attribute": "unknownAttribute", "value": {"value": "test"}}, - {"attribute": "dispositionCode", "value": {"codes": ["XFER"]}}, - ] - - result = apply_workflow_run_filters(mock_query, filters) - - # Should still process the valid filter - assert result is not None - - -def test_empty_filters(): - """Test that empty filters return the query unchanged.""" - - mock_query = MagicMock() - - result = apply_workflow_run_filters(mock_query, None) - assert result == mock_query - - result = apply_workflow_run_filters(mock_query, []) - assert result == mock_query diff --git a/api/tests/test_global_prompt.py b/api/tests/test_global_prompt.py deleted file mode 100644 index 5902f63..0000000 --- a/api/tests/test_global_prompt.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Tests for global prompt functionality in workflow engine.""" - -from unittest.mock import Mock - -import pytest -from pipecat.services.openai.llm import OpenAILLMContext - -from api.services.workflow.dto import ( - EdgeDataDTO, - NodeDataDTO, - NodeType, - ReactFlowDTO, - RFEdgeDTO, - RFNodeDTO, -) -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import WorkflowGraph - - -class TestGlobalPrompt: - """Test suite for global prompt feature.""" - - @pytest.fixture - def workflow_with_global_node(self): - """Create a workflow with a global node and test nodes.""" - nodes = [ - RFNodeDTO( - id="global", - type=NodeType.globalNode, - position={"x": 0, "y": 0}, - data=NodeDataDTO( - name="Global Node", - prompt="This is the global context: {{company_name}}", - is_static=False, - ), - ), - RFNodeDTO( - id="start", - type=NodeType.startNode, - position={"x": 100, "y": 100}, - data=NodeDataDTO( - name="Start Call", - prompt="Welcome to our service!", - is_static=False, - is_start=True, - add_global_prompt=True, # Enable global prompt - ), - ), - RFNodeDTO( - id="agent1", - type=NodeType.agentNode, - position={"x": 200, "y": 200}, - data=NodeDataDTO( - name="Agent 1", - prompt="How can I help you today?", - add_global_prompt=False, # Disable global prompt - ), - ), - RFNodeDTO( - id="agent2", - type=NodeType.agentNode, - position={"x": 300, "y": 300}, - data=NodeDataDTO( - name="Agent 2", - prompt="Please provide your details.", - add_global_prompt=True, # Enable global prompt - ), - ), - RFNodeDTO( - id="end", - type=NodeType.endNode, - position={"x": 400, "y": 400}, - data=NodeDataDTO( - name="End Call", - prompt="Thank you for calling!", - is_static=True, - is_end=True, - add_global_prompt=True, # Enable global prompt (but static) - ), - ), - ] - - edges = [ - RFEdgeDTO( - id="e1", - source="start", - target="agent1", - data=EdgeDataDTO(label="Next", condition="Continue to agent"), - ), - RFEdgeDTO( - id="e2", - source="agent1", - target="agent2", - data=EdgeDataDTO(label="Details", condition="Get user details"), - ), - RFEdgeDTO( - id="e3", - source="agent2", - target="end", - data=EdgeDataDTO(label="Finish", condition="End the call"), - ), - ] - - flow_dto = ReactFlowDTO(nodes=nodes, edges=edges) - return WorkflowGraph(flow_dto) - - @pytest.fixture - def mock_dependencies(self): - """Create mock dependencies for PipecatEngine initialization.""" - return { - "task": Mock(), - "llm": Mock(), - "context": Mock(spec=OpenAILLMContext), - "tts": Mock(), - "transport": Mock(), - "call_context_vars": {"company_name": "Dograh Inc"}, - } - - @pytest.fixture - def engine(self, mock_dependencies, workflow_with_global_node): - """Create a PipecatEngine instance with test workflow.""" - mock_dependencies["workflow"] = workflow_with_global_node - return PipecatEngine(**mock_dependencies) - - @pytest.mark.asyncio - async def test_global_prompt_enabled(self, engine): - """Test that global prompt is prepended when add_global_prompt is True.""" - # Test with start node (add_global_prompt=True) - start_node = engine.workflow.nodes["start"] - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(start_node) - - # Global prompt should be included - expected_content = ( - "This is the global context: Dograh Inc\n\nWelcome to our service!" - ) - assert system_message["content"] == expected_content - assert system_message["role"] == "system" - - @pytest.mark.asyncio - async def test_global_prompt_disabled(self, engine): - """Test that global prompt is not prepended when add_global_prompt is False.""" - # Test with agent1 node (add_global_prompt=False) - agent1_node = engine.workflow.nodes["agent1"] - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(agent1_node) - - # Global prompt should NOT be included - expected_content = "How can I help you today?" - assert system_message["content"] == expected_content - assert "global context" not in system_message["content"] - - @pytest.mark.asyncio - async def test_global_prompt_with_static_node(self, engine): - """Test that static nodes don't use global prompt in engine (even if enabled).""" - # Static nodes are handled differently - they use TTSSpeakFrame directly - # This test verifies the compose_system_message behavior for completeness - end_node = engine.workflow.nodes["end"] - - # Even though add_global_prompt=True, static nodes handle prompts differently - # The _compose_system_message_functions_for_node is still called for consistency - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(end_node) - - # For static nodes, the global prompt would still be composed if enabled - expected_content = ( - "This is the global context: Dograh Inc\n\nThank you for calling!" - ) - assert system_message["content"] == expected_content - - @pytest.mark.asyncio - async def test_global_prompt_variable_substitution(self, engine): - """Test that variables in global prompt are properly substituted.""" - agent2_node = engine.workflow.nodes["agent2"] - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(agent2_node) - - # Verify variable substitution in global prompt - assert "Dograh Inc" in system_message["content"] - assert "{{company_name}}" not in system_message["content"] - - # Full expected content - expected_content = ( - "This is the global context: Dograh Inc\n\nPlease provide your details." - ) - assert system_message["content"] == expected_content - - @pytest.mark.asyncio - async def test_no_global_node_scenario(self, engine): - """Test behavior when there's no global node in the workflow.""" - # Remove global node from workflow - engine.workflow.global_node_id = None - - start_node = engine.workflow.nodes["start"] - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(start_node) - - # Should only have the node's own prompt - assert system_message["content"] == "Welcome to our service!" - - @pytest.mark.asyncio - async def test_empty_global_prompt(self, engine): - """Test behavior when global prompt is empty.""" - # Set global prompt to empty string - engine.workflow.nodes["global"].prompt = "" - - start_node = engine.workflow.nodes["start"] - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(start_node) - - # Should only have the node's own prompt (empty global prompt is filtered out) - assert system_message["content"] == "Welcome to our service!" - - def test_default_add_global_prompt_value(self): - """Test that add_global_prompt defaults to True in NodeDataDTO.""" - node_data = NodeDataDTO(name="Test", prompt="Test prompt") - assert node_data.add_global_prompt is True - - @pytest.mark.asyncio - async def test_multiple_prompts_concatenation(self, engine): - """Test proper concatenation of global and node prompts.""" - # Test with agent2 node that has global prompt enabled - agent2_node = engine.workflow.nodes["agent2"] - - ( - system_message, - functions, - ) = await engine._compose_system_message_functions_for_node(agent2_node) - - # Should have global and node prompts concatenated with double newlines - # (extraction prompt is no longer included in system message) - expected_parts = [ - "This is the global context: Dograh Inc", - "Please provide your details.", - ] - expected_content = "\n\n".join(expected_parts) - assert system_message["content"] == expected_content diff --git a/api/tests/test_global_prompt_unit.py b/api/tests/test_global_prompt_unit.py deleted file mode 100644 index 76e153a..0000000 --- a/api/tests/test_global_prompt_unit.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Unit tests for global prompt functionality - no DB dependencies.""" - -import sys -from pathlib import Path - -# Add the api directory to the Python path -api_path = Path(__file__).parent.parent -sys.path.insert(0, str(api_path)) - -from services.workflow.dto import ( - EdgeDataDTO, - NodeDataDTO, - NodeType, - ReactFlowDTO, - RFEdgeDTO, - RFNodeDTO, -) -from services.workflow.workflow import WorkflowGraph - - -def test_node_data_dto_default_global_prompt(): - """Test that add_global_prompt defaults to True.""" - node_data = NodeDataDTO(name="Test Node", prompt="Test prompt") - assert node_data.add_global_prompt is True - print("✓ NodeDataDTO defaults add_global_prompt to True") - - -def test_node_data_dto_explicit_global_prompt(): - """Test explicit setting of add_global_prompt.""" - # Test with False - node_data_false = NodeDataDTO( - name="Test Node", prompt="Test prompt", add_global_prompt=False - ) - assert node_data_false.add_global_prompt is False - - # Test with True - node_data_true = NodeDataDTO( - name="Test Node", prompt="Test prompt", add_global_prompt=True - ) - assert node_data_true.add_global_prompt is True - print("✓ NodeDataDTO respects explicit add_global_prompt values") - - -def test_workflow_node_inherits_global_prompt_setting(): - """Test that workflow Node inherits add_global_prompt from NodeDataDTO.""" - nodes = [ - RFNodeDTO( - id="start", - type=NodeType.startNode, - position={"x": 0, "y": 0}, - data=NodeDataDTO( - name="Start", - prompt="Start prompt", - is_start=True, - add_global_prompt=True, - ), - ), - RFNodeDTO( - id="node1", - type=NodeType.agentNode, - position={"x": 100, "y": 0}, - data=NodeDataDTO( - name="Node with global", prompt="Test prompt", add_global_prompt=True - ), - ), - RFNodeDTO( - id="node2", - type=NodeType.agentNode, - position={"x": 200, "y": 0}, - data=NodeDataDTO( - name="Node without global", - prompt="Test prompt", - add_global_prompt=False, - ), - ), - RFNodeDTO( - id="end", - type=NodeType.endNode, - position={"x": 300, "y": 0}, - data=NodeDataDTO( - name="End", prompt="End prompt", is_end=True, add_global_prompt=True - ), - ), - ] - - edges = [ - RFEdgeDTO( - id="e1", - source="start", - target="node1", - data=EdgeDataDTO(label="Next", condition="Continue"), - ), - RFEdgeDTO( - id="e2", - source="node1", - target="node2", - data=EdgeDataDTO(label="Next", condition="Continue"), - ), - RFEdgeDTO( - id="e3", - source="node2", - target="end", - data=EdgeDataDTO(label="End", condition="Finish"), - ), - ] - - flow_dto = ReactFlowDTO(nodes=nodes, edges=edges) - workflow = WorkflowGraph(flow_dto) - - assert workflow.nodes["start"].add_global_prompt is True - assert workflow.nodes["node1"].add_global_prompt is True - assert workflow.nodes["node2"].add_global_prompt is False - assert workflow.nodes["end"].add_global_prompt is True - print("✓ Workflow nodes correctly inherit add_global_prompt setting") - - -def test_compose_system_message_respects_global_prompt_flag(): - """Test that system message composition respects add_global_prompt flag.""" - # This is a simplified version - in real tests we'd use the full engine - # But this demonstrates the logic - - class MockNode: - def __init__(self, add_global_prompt, prompt): - self.add_global_prompt = add_global_prompt - self.prompt = prompt - self.out_edges = [] - self.extraction_enabled = False - - # Simulate the logic from _compose_system_message_functions_for_node - def compose_message(node, global_prompt): - prompts = [] - - # Only add global prompt if node.add_global_prompt is True - if global_prompt and node.add_global_prompt: - prompts.append(global_prompt) - - prompts.append(node.prompt) - - return "\n\n".join(p for p in prompts if p) - - global_prompt = "This is the global context" - - # Test with add_global_prompt=True - node_with_global = MockNode(add_global_prompt=True, prompt="Node prompt") - message_with = compose_message(node_with_global, global_prompt) - assert message_with == "This is the global context\n\nNode prompt" - - # Test with add_global_prompt=False - node_without_global = MockNode(add_global_prompt=False, prompt="Node prompt") - message_without = compose_message(node_without_global, global_prompt) - assert message_without == "Node prompt" - - print("✓ System message composition respects add_global_prompt flag") - - -def test_static_nodes_with_global_prompt(): - """Test static nodes can have add_global_prompt setting.""" - static_node_data = NodeDataDTO( - name="Static Node", prompt="Static text", is_static=True, add_global_prompt=True - ) - - assert static_node_data.is_static is True - assert static_node_data.add_global_prompt is True - print("✓ Static nodes can have add_global_prompt setting") - - -if __name__ == "__main__": - # Run all tests - test_node_data_dto_default_global_prompt() - test_node_data_dto_explicit_global_prompt() - test_workflow_node_inherits_global_prompt_setting() - test_compose_system_message_respects_global_prompt_flag() - test_static_nodes_with_global_prompt() - - print("\n✅ All unit tests passed!") diff --git a/api/tests/test_leave_counter.py b/api/tests/test_leave_counter.py deleted file mode 100644 index 4ef08ac..0000000 --- a/api/tests/test_leave_counter.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -Test cases for _leave_counter mechanism in transport clients. - -This test suite verifies that the _leave_counter prevents premature disconnection -when both input and output transports are using the same client. -""" - -import asyncio -from unittest.mock import AsyncMock, Mock - -import pytest -from pipecat.frames.frames import EndFrame, StartFrame -from pipecat.transports.network.fastapi_websocket import ( - FastAPIWebsocketCallbacks, - FastAPIWebsocketClient, - FastAPIWebsocketParams, - FastAPIWebsocketTransport, -) -from pipecat.transports.network.small_webrtc import SmallWebRTCClient - -from api.services.telephony.stasis_rtp_client import StasisRTPClient - - -class TestLeaveCounterFastAPIWebsocket: - """Test the _leave_counter mechanism in FastAPIWebsocketClient.""" - - @pytest.mark.asyncio - async def test_leave_counter_prevents_early_disconnect(self): - """Test that disconnect only happens when both transports have disconnected.""" - # Create mock websocket - mock_websocket = Mock() - mock_websocket.close = AsyncMock() - # Set client_state directly to WebSocketState.CONNECTED value - from starlette.websockets import WebSocketState - - mock_websocket.client_state = WebSocketState.CONNECTED - - # Create callbacks - callbacks = FastAPIWebsocketCallbacks( - on_client_connected=AsyncMock(), - on_client_disconnected=AsyncMock(), - on_session_timeout=AsyncMock(), - ) - - # Create client - client = FastAPIWebsocketClient( - mock_websocket, is_binary=False, callbacks=callbacks - ) - - # Create StartFrame - start_frame = StartFrame() - - # Simulate both input and output transports calling setup - await client.setup(start_frame) # Input transport - assert client._leave_counter == 1 - - await client.setup(start_frame) # Output transport - assert client._leave_counter == 2 - - # First disconnect - should not actually disconnect - await client.disconnect() - assert client._leave_counter == 1 - mock_websocket.close.assert_not_called() - callbacks.on_client_disconnected.assert_not_called() - - # Second disconnect - should actually disconnect - await client.disconnect() - assert client._leave_counter == 0 - mock_websocket.close.assert_called_once() - callbacks.on_client_disconnected.assert_called_once() - - -class TestLeaveCounterStasisRTP: - """Test the _leave_counter mechanism in StasisRTPClient.""" - - @pytest.mark.asyncio - async def test_leave_counter_prevents_early_disconnect(self): - """Test that disconnect only happens when both transports have disconnected.""" - # Create mock connection - mock_connection = Mock() - mock_connection.is_connected.return_value = True - mock_connection.disconnect = AsyncMock() - mock_connection.notify_sockets_closed = AsyncMock() - - # Mock event_handler as a callable that acts as a decorator - def mock_event_handler(event_name): - def decorator(func): - return func - - return decorator - - mock_connection.event_handler = mock_event_handler - - # Create callbacks - from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks - - callbacks = StasisRTPCallbacks( - on_client_connected=AsyncMock(), - on_client_disconnected=AsyncMock(), - on_client_closed=AsyncMock(), - ) - - # Create client - client = StasisRTPClient(mock_connection, callbacks) - - # Create StartFrame - start_frame = StartFrame() - - # Simulate both input and output transports calling setup - await client.setup(start_frame) # Input transport - assert client._leave_counter == 1 - - await client.setup(start_frame) # Output transport - assert client._leave_counter == 2 - - # First disconnect - should not actually disconnect - await client.disconnect() - assert client._leave_counter == 1 - mock_connection.disconnect.assert_not_called() - - # Second disconnect - should actually disconnect - await client.disconnect() - assert client._leave_counter == 0 - mock_connection.disconnect.assert_called_once() - - -class TestLeaveCounterSmallWebRTC: - """Test the _leave_counter mechanism in SmallWebRTCClient.""" - - @pytest.mark.asyncio - async def test_leave_counter_prevents_early_disconnect(self): - """Test that disconnect only happens when both transports have disconnected.""" - # Create mock connection - mock_connection = Mock() - mock_connection.is_connected.return_value = True - mock_connection.disconnect = AsyncMock() - mock_connection.notify_sockets_closed = AsyncMock() - - # Mock event_handler as a callable that acts as a decorator - def mock_event_handler(event_name): - def decorator(func): - return func - - return decorator - - mock_connection.event_handler = mock_event_handler - - # Create callbacks - from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks - - callbacks = SmallWebRTCCallbacks( - on_app_message=AsyncMock(), - on_client_connected=AsyncMock(), - on_client_disconnected=AsyncMock(), - ) - - # Create client - client = SmallWebRTCClient(mock_connection, callbacks) - - # Create StartFrame with required attributes - start_frame = StartFrame() - - # Create mock transport params - from pipecat.transports.base_transport import TransportParams - - params = TransportParams( - audio_in_channels=1, audio_in_sample_rate=16000, audio_out_sample_rate=16000 - ) - - # Simulate both input and output transports calling setup - await client.setup(params, start_frame) # Input transport - assert client._leave_counter == 1 - - await client.setup(params, start_frame) # Output transport - assert client._leave_counter == 2 - - # First disconnect - should not actually disconnect - await client.disconnect() - assert client._leave_counter == 1 - mock_connection.disconnect.assert_not_called() - - # Second disconnect - should actually disconnect - await client.disconnect() - assert client._leave_counter == 0 - mock_connection.disconnect.assert_called_once() - - -@pytest.mark.skip(reason="Complex integration test - requires additional mocking") -@pytest.mark.asyncio -async def test_transport_lifecycle_with_leave_counter(): - """Test complete transport lifecycle with proper leave counter handling.""" - # Create mock websocket - mock_websocket = Mock() - mock_websocket.close = AsyncMock() - # Set client_state directly to WebSocketState.CONNECTED value - from starlette.websockets import WebSocketState - - mock_websocket.client_state = WebSocketState.CONNECTED - mock_websocket.iter_bytes = Mock(return_value=iter([])) - mock_websocket.send_bytes = AsyncMock() - - # Create transport - params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True) - transport = FastAPIWebsocketTransport(mock_websocket, params) - - # Get input and output transports - input_transport = transport.input() - output_transport = transport.output() - - # Setup the transport with required components - from pipecat.clocks.system_clock import SystemClock - from pipecat.processors.frame_processor import FrameProcessorSetup - from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams - - clock = SystemClock() - task_manager = TaskManager() - - # Setup task manager with event loop - loop = asyncio.get_event_loop() - task_manager_params = TaskManagerParams(loop=loop) - task_manager.setup(task_manager_params) - - setup = FrameProcessorSetup(clock=clock, task_manager=task_manager) - - # Setup both input and output transports - await input_transport.setup(setup) - await output_transport.setup(setup) - - # Start both transports - start_frame = StartFrame() - await input_transport.start(start_frame) - await output_transport.start(start_frame) - - # Verify leave counter is 2 - assert transport._client._leave_counter == 2 - - # Stop input transport - end_frame = EndFrame() - await input_transport.stop(end_frame) - - # Verify websocket not closed yet - mock_websocket.close.assert_not_called() - - # Stop output transport - await output_transport.stop(end_frame) - - # Now websocket should be closed - mock_websocket.close.assert_called_once() diff --git a/api/tests/test_llm_response_reorder.py b/api/tests/test_llm_response_reorder.py deleted file mode 100644 index e5727a1..0000000 --- a/api/tests/test_llm_response_reorder.py +++ /dev/null @@ -1,99 +0,0 @@ -import unittest - -from pipecat.frames.frames import ( - FunctionCallInProgressFrame, - LLMFullResponseStartFrame, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.google.llm import ( - GoogleAssistantContextAggregator, - GoogleLLMContext, -) -from pipecat.services.openai.llm import OpenAIAssistantContextAggregator - - -class TestReorderOpenAIAssistantContextAggregator(unittest.IsolatedAsyncioTestCase): - async def test_reorder_function_messages_openai(self): - """Ensure that after a text aggregation the function-call messages are moved - to appear immediately after the text response, maintaining chronological - order (assistant text -> function call -> tool response). - """ - - context = OpenAILLMContext() - aggregator = OpenAIAssistantContextAggregator(context) - - # Simulate the start of an LLM response so that the aggregator creates a - # response session ID that is later used for re-ordering. - await aggregator._handle_llm_start(LLMFullResponseStartFrame()) - - # Simulate the model emitting a function call which the aggregator will - # record for potential re-ordering. - await aggregator._handle_function_call_in_progress( - FunctionCallInProgressFrame( - function_name="get_weather", - tool_call_id="1", - arguments={}, - ) - ) - - # Now push the textual part of the assistant response. This should - # trigger the re-ordering so that the two function-related messages - # appear *after* this text. - await aggregator.handle_aggregation("Hello!") - - messages = context.get_messages() - - # We expect exactly three messages after re-ordering. - self.assertEqual(len(messages), 3) - - # 1. Assistant text - self.assertEqual(messages[0]["role"], "assistant") - self.assertEqual(messages[0]["content"], "Hello!") - - # 2. Assistant function-call message - self.assertEqual(messages[1]["role"], "assistant") - self.assertIn("tool_calls", messages[1]) - - # 3. Tool response - self.assertEqual(messages[2]["role"], "tool") - self.assertEqual(messages[2]["tool_call_id"], "1") - - -class TestReorderGoogleAssistantContextAggregator(unittest.IsolatedAsyncioTestCase): - async def test_reorder_function_messages_google(self): - context = GoogleLLMContext() - aggregator = GoogleAssistantContextAggregator(context) - - # Start an LLM response session. - await aggregator._handle_llm_start(LLMFullResponseStartFrame()) - - # Emit a function call. - await aggregator._handle_function_call_in_progress( - FunctionCallInProgressFrame( - function_name="get_weather", - tool_call_id="1", - arguments={}, - ) - ) - - # Push the textual content. - await aggregator.handle_aggregation("Hello!") - - messages = context.messages # Google context stores Content objects. - - self.assertEqual(len(messages), 3) - - # The first message should be the model text. - first_msg = messages[0].to_json_dict() - self.assertEqual(first_msg["role"], "model") - self.assertEqual(first_msg["parts"][0]["text"], "Hello!") - - # The second message contains the function call (also from the model). - second_msg = messages[1].to_json_dict() - self.assertEqual(second_msg["role"], "model") - self.assertIn("function_call", second_msg["parts"][0]) - - # The third message is the placeholder function response. - third_msg = messages[2].to_json_dict() - self.assertEqual(third_msg["role"], "user") - self.assertIn("function_response", third_msg["parts"][0]) diff --git a/api/tests/test_looptalk_routes.py b/api/tests/test_looptalk_routes.py deleted file mode 100644 index 23f19a4..0000000 --- a/api/tests/test_looptalk_routes.py +++ /dev/null @@ -1,506 +0,0 @@ -""" -Tests for LoopTalk API routes and orchestration. - -This module tests the LoopTalk testing functionality including test session creation, -pipeline orchestration, and agent-to-agent communication. -""" - -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -import pytest_asyncio -from fastapi import status - -from api.db.db_client import DBClient -from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator - - -@pytest.fixture -def actor_workflow_definition(): - """Sample actor workflow definition for testing.""" - return { - "nodes": [ - { - "id": "1", - "type": "startCall", - "position": {"x": 0, "y": 0}, - "data": { - "prompt": "Hello, I'm the actor agent.", - "is_static": True, - "name": "Start Call", - "is_start": True, - "allow_interrupt": False, - }, - }, - { - "id": "2", - "type": "agentNode", - "position": {"x": 100, "y": 0}, - "data": { - "prompt": "You are an actor agent testing the adversary. Ask probing questions.", - "name": "Actor Agent", - "allow_interrupt": True, - }, - }, - { - "id": "3", - "type": "endCall", - "position": {"x": 200, "y": 0}, - "data": { - "prompt": "Goodbye!", - "name": "End Call", - "is_end": True, - }, - }, - ], - "edges": [ - { - "id": "e1", - "source": "1", - "target": "2", - "data": {"label": "Continue", "condition": "Always"}, - }, - { - "id": "e2", - "source": "2", - "target": "3", - "data": {"label": "End", "condition": "Always"}, - }, - ], - "stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"}, - "llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"}, - "tts": { - "provider": "openai", - "api_key": "test-key", - "model": "tts-1", - "voice": "nova", - }, - } - - -@pytest.fixture -def adversary_workflow_definition(): - """Sample adversary workflow definition for testing.""" - return { - "nodes": [ - { - "id": "1", - "type": "startCall", - "position": {"x": 0, "y": 0}, - "data": { - "prompt": "Hello, I'm the adversary agent.", - "is_static": True, - "name": "Start Call", - "is_start": True, - "allow_interrupt": False, - }, - }, - { - "id": "2", - "type": "agentNode", - "position": {"x": 100, "y": 0}, - "data": { - "prompt": "You are an adversary agent being tested. Respond defensively.", - "name": "Adversary Agent", - "allow_interrupt": True, - }, - }, - { - "id": "3", - "type": "endCall", - "position": {"x": 200, "y": 0}, - "data": { - "prompt": "Goodbye!", - "name": "End Call", - "is_end": True, - }, - }, - ], - "edges": [ - { - "id": "e1", - "source": "1", - "target": "2", - "data": {"label": "Continue", "condition": "Always"}, - }, - { - "id": "e2", - "source": "2", - "target": "3", - "data": {"label": "End", "condition": "Always"}, - }, - ], - "stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"}, - "llm": { - "provider": "groq", - "api_key": "test-key", - "model": "llama-3.1-70b-versatile", - }, - "tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"}, - } - - -from pipecat.processors.frame_processor import FrameProcessor - - -class MockSTTService(FrameProcessor): - """Mock STT service for testing.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def run_stt(self, audio: bytes) -> str: - return "Mock transcription" - - -class MockLLMService(FrameProcessor): - """Mock LLM service for testing.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def run_llm(self, messages) -> str: - return "Mock LLM response" - - def create_context_aggregator(self, context): - """Mock context aggregator creation.""" - return MagicMock() - - -class MockTTSService(FrameProcessor): - """Mock TTS service for testing.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def run_tts(self, text: str) -> bytes: - return b"Mock audio data" - - -@pytest_asyncio.fixture -async def test_user_with_org(db_session): - """Create a test user with an organization set up.""" - user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user") - org, _ = await db_session.get_or_create_organization_by_provider_id( - "test_looptalk_org" - ) - - user_id = user.id - org_id = org.id - - await db_session.add_user_to_organization(user_id, org_id) - - # Update user's selected organization - async with db_session.async_session() as session: - from sqlalchemy import update - - from api.db.models import UserModel - - await session.execute( - update(UserModel) - .where(UserModel.id == user_id) - .values(selected_organization_id=org_id) - ) - await session.commit() - - # Return fresh user object - return await db_session.get_user_by_id(user_id) - - -@pytest.mark.asyncio -async def test_create_test_session( - test_client_factory, - db_session, - test_user_with_org, - actor_workflow_definition, - adversary_workflow_definition, -): - """Test creating a new LoopTalk test session.""" - async with test_client_factory(test_user_with_org) as test_client: - # First create two workflows - actor_workflow_response = await test_client.post( - "/api/v1/workflow/create", - json={ - "name": "Actor Workflow", - "workflow_definition": actor_workflow_definition, - }, - ) - assert actor_workflow_response.status_code == status.HTTP_200_OK - actor_workflow_id = actor_workflow_response.json()["id"] - - adversary_workflow_response = await test_client.post( - "/api/v1/workflow/create", - json={ - "name": "Adversary Workflow", - "workflow_definition": adversary_workflow_definition, - }, - ) - assert adversary_workflow_response.status_code == status.HTTP_200_OK - adversary_workflow_id = adversary_workflow_response.json()["id"] - - # Create test session - response = await test_client.post( - "/api/v1/looptalk/test-sessions", - json={ - "name": "Test Session 1", - "actor_workflow_id": actor_workflow_id, - "adversary_workflow_id": adversary_workflow_id, - "config": {"test_duration": 60}, - }, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Test Session 1" - assert data["status"] == "pending" - assert data["actor_workflow_id"] == actor_workflow_id - assert data["adversary_workflow_id"] == adversary_workflow_id - assert data["config"]["test_duration"] == 60 - - -@pytest.mark.asyncio -async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org): - """Test listing LoopTalk test sessions.""" - async with test_client_factory(test_user_with_org) as test_client: - response = await test_client.get( - "/api/v1/looptalk/test-sessions", - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert isinstance(data, list) - - -@pytest.mark.asyncio -async def test_looptalk_orchestrator_plumbing( - db_session: DBClient, actor_workflow_definition, adversary_workflow_definition -): - """Test the LoopTalk orchestrator plumbing with mocked services.""" - - # Create test user and organization - user = await db_session.get_or_create_user_by_provider_id( - provider_id="test-user-123" - ) - org, _ = await db_session.get_or_create_organization_by_provider_id( - org_provider_id="test-org-123" - ) - - # Get IDs before session closes - user_id = user.id - org_id = org.id - - await db_session.add_user_to_organization(user_id, org_id) - - # Update user's selected organization manually - async with db_session.async_session() as session: - from sqlalchemy import update - - from api.db.models import UserModel - - await session.execute( - update(UserModel) - .where(UserModel.id == user_id) - .values(selected_organization_id=org_id) - ) - await session.commit() - - actor_workflow = await db_session.create_workflow( - name="Actor Workflow", - workflow_definition=actor_workflow_definition, - user_id=user_id, - ) - - adversary_workflow = await db_session.create_workflow( - name="Adversary Workflow", - workflow_definition=adversary_workflow_definition, - user_id=user_id, - ) - - # Create test session - test_session = await db_session.create_test_session( - organization_id=org_id, - name="Test Session", - actor_workflow_id=actor_workflow.id, - adversary_workflow_id=adversary_workflow.id, - config={"test_duration": 10}, - ) - - # Mock the service factories - patch at the actual import location in pipeline_builder - with ( - patch( - "api.services.looptalk.core.pipeline_builder.create_stt_service" - ) as mock_stt_factory, - patch( - "api.services.looptalk.core.pipeline_builder.create_llm_service" - ) as mock_llm_factory, - patch( - "api.services.looptalk.core.pipeline_builder.create_tts_service" - ) as mock_tts_factory, - patch( - "api.services.workflow.pipecat_engine.PipecatEngine" - ) as mock_engine_class, - patch( - "api.services.pipecat.pipeline_builder.build_pipeline" - ) as mock_build_pipeline, - patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class, - ): - # Configure mocks - mock_stt_factory.return_value = MockSTTService() - mock_llm_factory.return_value = MockLLMService() - mock_tts_factory.return_value = MockTTSService() - - mock_engine = MagicMock() - mock_engine.initialize = AsyncMock() - mock_engine.get_callback_processor = MagicMock(return_value=MagicMock()) - mock_engine_class.return_value = mock_engine - - # Mock pipeline and task - mock_pipeline = MagicMock() - mock_task = MagicMock() - mock_task.run = AsyncMock() - mock_task.cancel = AsyncMock() # Make cancel async - mock_build_pipeline.return_value = mock_pipeline - mock_task_class.return_value = mock_task - - # Create orchestrator - orchestrator = LoopTalkTestOrchestrator(db_client=db_session) - - # Start test session (in a separate task to avoid blocking) - start_task = asyncio.create_task( - orchestrator.start_test_session( - test_session_id=test_session.id, organization_id=org_id - ) - ) - - # Give it a moment to start - await asyncio.sleep(0.5) - - # Verify the session is running through session manager - session_info = orchestrator.session_manager.get_session(test_session.id) - assert session_info is not None - assert session_info["test_session"].id == test_session.id - assert "actor_task" in session_info - assert "adversary_task" in session_info - - # Verify service factories were called - assert mock_stt_factory.call_count == 2 # Once for each agent - assert mock_llm_factory.call_count == 2 - assert mock_tts_factory.call_count == 2 - - # Verify pipelines were created with PipelineTask - assert mock_task_class.call_count == 2 - - # Stop the test session - await orchestrator.stop_test_session(test_session_id=test_session.id) - - # Verify session was cleaned up - assert orchestrator.session_manager.get_session(test_session.id) is None - - # Cancel the start task - start_task.cancel() - try: - await start_task - except asyncio.CancelledError: - pass - - -@pytest.mark.asyncio -async def test_load_test_creation( - test_client_factory, - db_session, - test_user_with_org, - actor_workflow_definition, - adversary_workflow_definition, -): - """Test creating a load test with multiple sessions.""" - async with test_client_factory(test_user_with_org) as test_client: - # First create two workflows - actor_workflow_response = await test_client.post( - "/api/v1/workflow/create", - json={ - "name": "Actor Workflow", - "workflow_definition": actor_workflow_definition, - }, - ) - actor_workflow_id = actor_workflow_response.json()["id"] - - adversary_workflow_response = await test_client.post( - "/api/v1/workflow/create", - json={ - "name": "Adversary Workflow", - "workflow_definition": adversary_workflow_definition, - }, - ) - adversary_workflow_id = adversary_workflow_response.json()["id"] - - # Create load test - response = await test_client.post( - "/api/v1/looptalk/load-tests", - json={ - "name_prefix": "Load Test", - "actor_workflow_id": actor_workflow_id, - "adversary_workflow_id": adversary_workflow_id, - "test_count": 3, - "config": {"test_duration": 30}, - }, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["total"] == 3 - assert "load_test_group_id" in data - assert len(data["test_session_ids"]) == 3 - - -@pytest.mark.asyncio -async def test_invalid_workflow_ids( - test_client_factory, db_session, test_user_with_org -): - """Test creating test session with invalid workflow IDs.""" - async with test_client_factory(test_user_with_org) as test_client: - response = await test_client.post( - "/api/v1/looptalk/test-sessions", - json={ - "name": "Invalid Test", - "actor_workflow_id": 99999, - "adversary_workflow_id": 99999, - "config": {}, - }, - ) - - assert response.status_code == status.HTTP_404_NOT_FOUND - assert "workflow not found" in response.json()["detail"].lower() - - -@pytest.mark.asyncio -async def test_transport_manager(): - """Test the internal transport manager functionality.""" - from pipecat.transports import InternalTransportManager, TransportParams - - manager = InternalTransportManager() - - # Create transport pair - params = TransportParams( - audio_out_enabled=True, - audio_in_enabled=True, - audio_out_sample_rate=16000, - audio_in_sample_rate=16000, - ) - - actor_transport, adversary_transport = manager.create_transport_pair( - test_session_id="test-123", actor_params=params, adversary_params=params - ) - - # Verify transports are connected - assert actor_transport._output._partner == adversary_transport._input - assert adversary_transport._output._partner == actor_transport._input - - # Verify transport pair is tracked - assert manager.get_active_test_count() == 1 - assert manager.get_transport_pair("test-123") is not None - - # Remove transport pair - manager.remove_transport_pair("test-123") - assert manager.get_active_test_count() == 0 - assert manager.get_transport_pair("test-123") is None diff --git a/api/tests/test_mock_llm_service.py b/api/tests/test_mock_llm_service.py deleted file mode 100644 index a0599b9..0000000 --- a/api/tests/test_mock_llm_service.py +++ /dev/null @@ -1,142 +0,0 @@ -### - The test gets stuck. Need to figure out a way to run the test - -# import asyncio -# import unittest - -# from loguru import logger - -# from pipecat.frames.frames import ( -# FunctionCallFromLLM, -# FunctionCallInProgressFrame, -# FunctionCallResultFrame, -# FunctionCallsStartedFrame, -# LLMFullResponseEndFrame, -# LLMFullResponseStartFrame, -# LLMTextFrame, -# ) -# from pipecat.processors.aggregators.openai_llm_context import ( -# OpenAILLMContext, -# OpenAILLMContextFrame, -# ) -# from pipecat.processors.frame_processor import FrameDirection -# from pipecat.services.llm_service import ( -# FunctionCallParams, -# FunctionCallResultProperties, -# LLMService, -# ) -# from pipecat.tests.utils import run_test - - -# class MockLLMService(LLMService): -# """A very small mocked LLM service that, upon receiving an -# ``OpenAILLMContextFrame``, streams a text completion followed by the -# execution of the supplied tools (function calls). -# """ - -# def __init__(self, *, content: str, tools: list[dict[str, dict]], **kwargs): -# # Run function calls sequentially so that frame ordering is deterministic. -# super().__init__(run_in_parallel=False, **kwargs) -# self._content = content -# self._tools = tools - -# async def process_frame(self, frame, direction: FrameDirection): -# await super().process_frame(frame, direction) - -# if isinstance(frame, OpenAILLMContextFrame) and direction == FrameDirection.DOWNSTREAM: -# # Simulate the start of a streamed completion. -# await self.push_frame(LLMFullResponseStartFrame()) -# await self.push_frame(LLMTextFrame(self._content)) - -# # Convert tool specs into FunctionCallFromLLM objects. -# function_calls = [] -# for idx, tool in enumerate(self._tools): -# function_calls.append( -# FunctionCallFromLLM( -# function_name=tool["function_name"], -# tool_call_id=f"tool_{idx}", -# arguments=tool.get("arguments", {}), -# context=frame.context, -# ) -# ) - -# # Ask the LLM service base class to execute the calls. -# await self.run_function_calls(function_calls) - -# # Finish the streamed response. -# await self.push_frame(LLMFullResponseEndFrame()) - -# async def _run_function_call(self, runner_item): # type: ignore[override] – narrow signature -# # Ensure run_llm=True so that downstream processors know they can -# # immediately trigger another LLM call after the result is committed. -# runner_item.run_llm = True -# await super()._run_function_call(runner_item) - - -# class TestMockLLMPipeline(unittest.IsolatedAsyncioTestCase): -# async def test_mock_llm_pipeline_with_tools(self): -# # ------------------------------------------------------------------ -# # 1. Create mocked LLM service with completion text and tools -# # ------------------------------------------------------------------ -# completion_text = "Hello from mocked LLM!" -# tools = [ -# {"function_name": "tool_one", "arguments": {"a": 1}}, -# {"function_name": "tool_two", "arguments": {"b": 2}}, -# ] -# llm = MockLLMService(content=completion_text, tools=tools) - -# # ------------------------------------------------------------------ -# # 2. Register the tool functions – they simply log & sleep briefly. -# # Each of them marks that it has run so that we can assert later. -# # ------------------------------------------------------------------ -# executed: dict[str, bool] = {t["function_name"]: False for t in tools} - -# def make_handler(name: str): -# async def _handler(params: FunctionCallParams): -# logger.debug(f"Executing {name} with args {params.arguments}") -# executed[name] = True -# await asyncio.sleep(0.01) -# await params.result_callback( -# {"status": "ok"}, -# properties=FunctionCallResultProperties(run_llm=True), -# ) - -# return _handler - -# for t in tools: -# llm.register_function(t["function_name"], make_handler(t["function_name"])) - -# # ------------------------------------------------------------------ -# # 3. Build the pipeline and send the initial context frame that -# # triggers the completion. -# # ------------------------------------------------------------------ -# context = OpenAILLMContext() -# context.add_message({"role": "user", "content": "Hi!"}) -# frames_to_send = [OpenAILLMContextFrame(context)] - -# expected_down_frames = [ -# LLMFullResponseStartFrame, -# LLMTextFrame, -# FunctionCallsStartedFrame, -# FunctionCallInProgressFrame, -# FunctionCallResultFrame, -# FunctionCallInProgressFrame, -# FunctionCallResultFrame, -# LLMFullResponseEndFrame, -# ] - -# # Run the test pipeline. -# received_down_frames, _ = await run_test( -# llm, -# frames_to_send=frames_to_send, -# expected_down_frames=expected_down_frames, -# ) - -# # ------------------------------------------------------------------ -# # 4. Verify that both tool functions executed and that run_llm=True -# # in all FunctionCallResultFrame instances. -# # ------------------------------------------------------------------ -# self.assertTrue(all(executed.values())) - -# for frame in received_down_frames: -# if isinstance(frame, FunctionCallResultFrame): -# self.assertTrue(frame.run_llm) diff --git a/api/tests/test_pipecat_disposition_mapping.py b/api/tests/test_pipecat_disposition_mapping.py deleted file mode 100644 index aa0c289..0000000 --- a/api/tests/test_pipecat_disposition_mapping.py +++ /dev/null @@ -1,236 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine - - -def create_disposition_mapping_side_effect(mapping_dict): - """Helper to create a side effect function for disposition mapping.""" - - async def side_effect(value, org_id): - return mapping_dict.get(value, value) - - return side_effect - - -@pytest.fixture -def mock_dependencies(): - """Create mock dependencies for PipecatEngine.""" - mock_task = MagicMock() - mock_task.queue_frame = AsyncMock() - - mock_llm = MagicMock() - mock_context = MagicMock() - mock_workflow = MagicMock() - - return { - "task": mock_task, - "llm": mock_llm, - "context": mock_context, - "workflow": mock_workflow, - "call_context_vars": {}, - "workflow_run_id": 123, - } - - -@pytest.mark.asyncio -async def test_apply_disposition_mapping_with_call_disposition(mock_dependencies): - """Test disposition mapping when call_disposition is present.""" - engine = PipecatEngine(**mock_dependencies) - - # Setup gathered context - engine._gathered_context = { - "call_disposition": "XFER", - "agent_name": "Alex", - "total_debt": "$15000", - } - - # Mock the disposition mapper functions - with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run" - ) as mock_get_org_id: - with patch( - "api.services.workflow.pipecat_engine.apply_disposition_mapping" - ) as mock_apply_mapping: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock disposition mapping - mock_apply_mapping.side_effect = create_disposition_mapping_side_effect( - { - "XFER": "TRANSFERRED", - "ND": "NOT_QUALIFIED", - } - ) - - # Call send_end_task_frame - await engine.send_end_task_frame(reason="user_qualified") - - # Verify the frame was queued with mapped values - mock_dependencies["task"].queue_frame.assert_called_once() - frame = mock_dependencies["task"].queue_frame.call_args[0][0] - - # Check metadata contains mapped values - assert frame.metadata["reason"] == "user_qualified" # No mapping for this - assert ( - frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED" - ) - - # Check gathered context was updated - assert engine._gathered_context["call_disposition"] == "TRANSFERRED" - - -@pytest.mark.asyncio -async def test_apply_disposition_mapping_with_disconnect_reason(mock_dependencies): - """Test disposition mapping for disconnect_reason when no call_disposition exists.""" - engine = PipecatEngine(**mock_dependencies) - - # Setup gathered context without call_disposition - engine._gathered_context = { - "agent_name": "Alex", - } - - # Mock the disposition mapper functions - with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run" - ) as mock_get_org_id: - with patch( - "api.services.workflow.pipecat_engine.apply_disposition_mapping" - ) as mock_apply_mapping: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock disposition mapping - mock_apply_mapping.side_effect = create_disposition_mapping_side_effect( - { - "user_qualified": "QUALIFIED", - "user_disqualified": "NOT_QUALIFIED", - "user_hangup": "HANGUP", - } - ) - - # Call send_end_task_frame with a mappable reason - await engine.send_end_task_frame(reason="user_qualified") - - # Verify the frame was queued with mapped disposition - mock_dependencies["task"].queue_frame.assert_called_once() - frame = mock_dependencies["task"].queue_frame.call_args[0][0] - - # Check metadata contains original reason - assert frame.metadata["reason"] == "user_qualified" - - # Check call_transfer_context has mapped disconnect_reason as disposition - assert frame.metadata["call_transfer_context"]["disposition"] == "QUALIFIED" - - # Check gathered context was updated with mapped call_disposition - assert engine._gathered_context["call_disposition"] == "QUALIFIED" - - # Check internal call_disposition stores mapped value - assert engine._call_disposition == "QUALIFIED" - - -@pytest.mark.asyncio -async def test_call_disposition_takes_precedence(mock_dependencies): - """Test that call_disposition is used when both call_disposition and reason could be mapped.""" - engine = PipecatEngine(**mock_dependencies) - - # Setup gathered context with call_disposition - engine._gathered_context = { - "call_disposition": "XFER", - "agent_name": "Alex", - } - - # Mock the disposition mapper functions - with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run" - ) as mock_get_org_id: - with patch( - "api.services.workflow.pipecat_engine.apply_disposition_mapping" - ) as mock_apply_mapping: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock disposition mapping - mock_apply_mapping.side_effect = create_disposition_mapping_side_effect( - { - "XFER": "TRANSFERRED", - "user_qualified": "QUALIFIED", - } - ) - - # Call send_end_task_frame with a reason that could also be mapped - await engine.send_end_task_frame(reason="user_qualified") - - # Verify the frame was queued - mock_dependencies["task"].queue_frame.assert_called_once() - frame = mock_dependencies["task"].queue_frame.call_args[0][0] - - # Check that call_disposition mapping was used, not reason mapping - assert ( - frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED" - ) - - # Check only call_disposition was updated in gathered context - assert engine._gathered_context["call_disposition"] == "TRANSFERRED" - assert "disconnect_reason" not in engine._gathered_context - - -@pytest.mark.asyncio -async def test_disposition_mapping_no_organization_id(mock_dependencies): - """Test when organization_id cannot be retrieved.""" - # Set workflow_run_id to None - mock_dependencies["workflow_run_id"] = None - engine = PipecatEngine(**mock_dependencies) - - engine._gathered_context = { - "call_disposition": "XFER", - } - - # Call send_end_task_frame - await engine.send_end_task_frame(reason="user_qualified") - - # Verify the frame was queued with original values (no mapping) - mock_dependencies["task"].queue_frame.assert_called_once() - frame = mock_dependencies["task"].queue_frame.call_args[0][0] - - # Check values remain unchanged - assert frame.metadata["reason"] == "user_qualified" - assert frame.metadata["call_transfer_context"]["disposition"] == "XFER" - - # Gathered context should remain unchanged - assert engine._gathered_context["call_disposition"] == "XFER" - - -@pytest.mark.asyncio -async def test_disposition_mapping_no_configuration(mock_dependencies): - """Test when no disposition mapping is configured.""" - engine = PipecatEngine(**mock_dependencies) - - engine._gathered_context = { - "call_disposition": "XFER", - } - - # Mock the disposition mapper functions - with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run" - ) as mock_get_org_id: - with patch( - "api.services.workflow.pipecat_engine.apply_disposition_mapping" - ) as mock_apply_mapping: - # Mock organization ID - mock_get_org_id.return_value = 1 - - # Mock no disposition mapping (return original value) - mock_apply_mapping.side_effect = lambda value, org_id: value - - # Call send_end_task_frame - await engine.send_end_task_frame(reason="user_qualified") - - # Verify the frame was queued with original values - mock_dependencies["task"].queue_frame.assert_called_once() - frame = mock_dependencies["task"].queue_frame.call_args[0][0] - - # Check values remain unchanged - assert frame.metadata["reason"] == "user_qualified" - assert frame.metadata["call_transfer_context"]["disposition"] == "XFER" diff --git a/api/tests/test_pipecat_engine.py b/api/tests/test_pipecat_engine.py deleted file mode 100644 index 8e68067..0000000 --- a/api/tests/test_pipecat_engine.py +++ /dev/null @@ -1,206 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import WorkflowGraph - - -class TestPipecatEngine: - @pytest.fixture - def mock_dependencies(self): - """Create mock dependencies for PipecatEngine initialization.""" - return { - "task": Mock(), - "llm": Mock(), - "context": Mock(), - "tts": Mock(), - "transport": Mock(), - "workflow": Mock(spec=WorkflowGraph), - "call_context_vars": {}, - } - - @pytest.fixture - def engine_with_context(self, mock_dependencies): - """Create a PipecatEngine instance with test context variables.""" - context_vars = { - "first_name": "John", - "last_name": "Doe", - "age": 25, - "email": "john.doe@example.com", - "empty_var": "", - "zero_var": 0, - "false_var": False, - } - mock_dependencies["call_context_vars"] = context_vars - return PipecatEngine(**mock_dependencies) - - @pytest.fixture - def engine_empty_context(self, mock_dependencies): - """Create a PipecatEngine instance with empty context variables.""" - mock_dependencies["call_context_vars"] = {} - return PipecatEngine(**mock_dependencies) - - def test_format_prompt_simple_variable_replacement(self, engine_with_context): - """Test simple variable replacement without filters.""" - prompt = "Hello {{ first_name }}, welcome!" - result = engine_with_context._format_prompt(prompt) - assert result == "Hello John, welcome!" - - def test_format_prompt_multiple_variables(self, engine_with_context): - """Test multiple variable replacements in a single prompt.""" - prompt = "Hello {{ first_name }} {{ last_name }}, you are {{ age }} years old." - result = engine_with_context._format_prompt(prompt) - assert result == "Hello John Doe, you are 25 years old." - - def test_format_prompt_with_fallback_existing_value(self, engine_with_context): - """Test fallback filter when value exists.""" - prompt = "Hello {{ first_name | fallback }}, nice to meet you!" - result = engine_with_context._format_prompt(prompt) - assert result == "Hello John, nice to meet you!" - - def test_format_prompt_with_fallback_missing_value(self, engine_empty_context): - """Test fallback filter when value is missing.""" - prompt = "Hello {{ first_name | fallback }}, nice to meet you!" - result = engine_empty_context._format_prompt(prompt) - assert result == "Hello First_Name, nice to meet you!" - - def test_format_prompt_with_custom_fallback_missing_value( - self, engine_empty_context - ): - """Test fallback filter with custom fallback value when variable is missing.""" - prompt = "Hello {{ first_name | fallback:Guest }}, welcome!" - result = engine_empty_context._format_prompt(prompt) - assert result == "Hello Guest, welcome!" - - def test_format_prompt_with_custom_fallback_existing_value( - self, engine_with_context - ): - """Test fallback filter with custom fallback value when variable exists.""" - prompt = "Hello {{ first_name | fallback:Guest }}, welcome!" - result = engine_with_context._format_prompt(prompt) - assert result == "Hello John, welcome!" - - def test_format_prompt_empty_string_variable(self, engine_with_context): - """Test variable with empty string value.""" - prompt = "Value: '{{ empty_var | fallback:No Value }}'" - result = engine_with_context._format_prompt(prompt) - assert result == "Value: 'No Value'" - - def test_format_prompt_zero_value(self, engine_with_context): - """Test variable with zero value (should not trigger fallback).""" - prompt = "Count: {{ zero_var | fallback:None }}" - result = engine_with_context._format_prompt(prompt) - assert result == "Count: 0" - - def test_format_prompt_false_value(self, engine_with_context): - """Test variable with False value (should not trigger fallback).""" - prompt = "Status: {{ false_var | fallback:Unknown }}" - result = engine_with_context._format_prompt(prompt) - assert result == "Status: False" - - def test_format_prompt_missing_variable_no_fallback(self, engine_empty_context): - """Test missing variable without fallback filter.""" - prompt = "Hello {{ missing_var }}, welcome!" - result = engine_empty_context._format_prompt(prompt) - assert result == "Hello , welcome!" - - def test_format_prompt_complex_mixed_scenario(self, engine_with_context): - """Test complex scenario with multiple variables, some with fallbacks.""" - prompt = ( - "Dear {{ first_name | fallback:Customer }}, " - "your email {{ email }} is confirmed. " - "{{ missing_info | fallback:Additional information }} will be sent later. " - "You are {{ age }} years old." - ) - result = engine_with_context._format_prompt(prompt) - expected = ( - "Dear John, " - "your email john.doe@example.com is confirmed. " - "Additional information will be sent later. " - "You are 25 years old." - ) - assert result == expected - - def test_format_prompt_whitespace_handling(self, engine_with_context): - """Test handling of whitespace in template variables.""" - prompt = "Hello {{ first_name | fallback : Default }}, welcome!" - result = engine_with_context._format_prompt(prompt) - assert result == "Hello John, welcome!" - - def test_format_prompt_no_variables(self, engine_with_context): - """Test prompt with no template variables.""" - prompt = "This is a regular prompt with no variables." - result = engine_with_context._format_prompt(prompt) - assert result == "This is a regular prompt with no variables." - - def test_format_prompt_empty_prompt(self, engine_with_context): - """Test empty prompt.""" - prompt = "" - result = engine_with_context._format_prompt(prompt) - assert result == "" - - def test_format_prompt_none_prompt(self, engine_with_context): - """Test None prompt.""" - prompt = None - result = engine_with_context._format_prompt(prompt) - assert result is None - - def test_format_prompt_nested_braces(self, engine_with_context): - """Test handling of nested or malformed braces.""" - prompt = "Hello {{ first_name }}, this {is not a template} variable." - result = engine_with_context._format_prompt(prompt) - assert result == "Hello John, this {is not a template} variable." - - def test_format_prompt_special_characters_in_value(self): - """Test variables containing special characters.""" - mock_deps = { - "task": Mock(), - "llm": Mock(), - "context": Mock(), - "tts": Mock(), - "transport": Mock(), - "workflow": Mock(spec=WorkflowGraph), - "call_context_vars": { - "special_name": "John & Jane's Company", - "email": "test@domain.com", - }, - } - engine = PipecatEngine(**mock_deps) - - prompt = "Company: {{ special_name }}, Contact: {{ email }}" - result = engine._format_prompt(prompt) - assert result == "Company: John & Jane's Company, Contact: test@domain.com" - - def test_format_prompt_numeric_and_boolean_conversion(self): - """Test conversion of different data types to strings.""" - mock_deps = { - "task": Mock(), - "llm": Mock(), - "context": Mock(), - "tts": Mock(), - "transport": Mock(), - "workflow": Mock(spec=WorkflowGraph), - "call_context_vars": { - "count": 42, - "price": 99.99, - "is_active": True, - "items": ["apple", "banana"], - }, - } - engine = PipecatEngine(**mock_deps) - - prompt = "Count: {{ count }}, Price: ${{ price }}, Active: {{ is_active }}, Items: {{ items }}" - result = engine._format_prompt(prompt) - assert ( - result - == "Count: 42, Price: $99.99, Active: True, Items: ['apple', 'banana']" - ) - - def test_format_prompt_case_sensitivity(self, engine_with_context): - """Test that variable names are case sensitive.""" - prompt = ( - "Hello {{ First_Name | fallback }}, welcome!" # Note the capitalization - ) - result = engine_with_context._format_prompt(prompt) - assert result == "Hello First_Name, welcome!" # Should use fallback diff --git a/api/tests/test_provider_switching.py b/api/tests/test_provider_switching.py deleted file mode 100644 index 9f11db7..0000000 --- a/api/tests/test_provider_switching.py +++ /dev/null @@ -1,295 +0,0 @@ -""" -Test scenarios for provider switching and billing integrity. -This test suite validates that the multi-provider telephony system -handles provider switches correctly without losing billing data. -""" - -import asyncio - -# Test scenarios to validate - - -async def test_scenario_1_mid_call_provider_switch(): - """ - Test: What happens if provider is switched while a call is active? - - Expected behavior: - - Active call continues with original provider - - Call is billed to original provider - - New calls use new provider - """ - print("Test 1: Mid-call provider switching") - - # Simulate workflow run with Twilio - twilio_run = { - "id": 1, - "mode": "twilio", - "cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"}, - "is_completed": False, - } - - # Provider switch happens here (in real scenario, user changes config) - # But the call continues... - - # When cost calculation runs, it should: - # 1. Use the provider stored in cost_info - # 2. Fetch cost from Twilio using twilio_call_sid - # 3. Store cost with provider attribution - - result = { - "test": "mid_call_switch", - "status": "PASS", - "reason": "Call continues with original provider, billing intact", - } - print(f" ✓ {result['reason']}") - return result - - -async def test_scenario_2_pending_cost_calculation(): - """ - Test: Calls that ended but cost not yet calculated when provider switches. - - Expected behavior: - - Background job should use the provider info stored in cost_info - - Cost should be fetched from correct provider - """ - print("\nTest 2: Pending cost calculation during switch") - - # Workflow runs that ended but cost job hasn't run yet - pending_runs = [ - { - "id": 2, - "mode": "twilio", - "cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"}, - "is_completed": True, - }, - { - "id": 3, - "mode": "vonage", - "cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"}, - "is_completed": True, - }, - ] - - # Provider switch happens here - # Cost calculation jobs run after switch - - # Each job should: - # 1. Check the provider field in cost_info - # 2. Use appropriate provider API to fetch cost - # 3. Handle gracefully if credentials changed - - result = { - "test": "pending_cost_calculation", - "status": "PASS", - "reason": "Cost jobs use stored provider info correctly", - } - print(f" ✓ {result['reason']}") - return result - - -async def test_scenario_3_mixed_provider_history(): - """ - Test: Organization has calls from both Twilio and Vonage. - - Expected behavior: - - Historical costs remain intact - - Reports show correct attribution - - Total costs aggregate correctly - """ - print("\nTest 3: Mixed provider history") - - historical_runs = [ - {"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"}, - {"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"}, - {"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"}, - {"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"}, - ] - - # Calculate totals - total_cost = sum(run["cost_usd"] for run in historical_runs) - twilio_cost = sum( - run["cost_usd"] for run in historical_runs if run["provider"] == "twilio" - ) - vonage_cost = sum( - run["cost_usd"] for run in historical_runs if run["provider"] == "vonage" - ) - - result = { - "test": "mixed_provider_history", - "status": "PASS", - "total_cost": total_cost, - "twilio_cost": twilio_cost, - "vonage_cost": vonage_cost, - "reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})", - } - print(f" ✓ {result['reason']}") - return result - - -async def test_scenario_4_cost_api_failure(): - """ - Test: Provider API fails when fetching cost. - - Expected behavior: - - Error logged but system continues - - Call record preserved - - Cost marked as 0 or unknown - """ - print("\nTest 4: Cost API failure handling") - - # Simulate API failure scenarios - failure_scenarios = [ - { - "provider": "twilio", - "error": "401 Unauthorized - credentials changed", - "expected": "Cost set to 0, error logged", - }, - { - "provider": "vonage", - "error": "404 Not Found - call record deleted", - "expected": "Cost set to 0, error logged", - }, - { - "provider": "twilio", - "error": "500 Internal Server Error", - "expected": "Cost set to 0, retry possible", - }, - ] - - for scenario in failure_scenarios: - print(f" - {scenario['provider']}: {scenario['error']}") - print(f" Expected: {scenario['expected']}") - - result = { - "test": "cost_api_failure", - "status": "PASS", - "reason": "All failure scenarios handled gracefully", - } - print(f" ✓ {result['reason']}") - return result - - -async def test_scenario_5_configuration_migration(): - """ - Test: Database migration from single to multi-provider format. - - Expected behavior: - - Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION - - Single provider config wrapped in multi-provider structure - - Existing cost_info gets provider field added - """ - print("\nTest 5: Configuration migration") - - # Old format - old_config = { - "account_sid": "AC123", - "auth_token": "token123", - "from_numbers": ["+1234567890"], - "provider": "twilio", - } - - # New format after migration - new_config = { - "active_provider": "twilio", - "providers": { - "twilio": { - "account_sid": "AC123", - "auth_token": "token123", - "from_numbers": ["+1234567890"], - } - }, - } - - # Validate migration - assert new_config["active_provider"] == "twilio" - assert "providers" in new_config - assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"] - - result = { - "test": "configuration_migration", - "status": "PASS", - "reason": "Configuration migrated to multi-provider format correctly", - } - print(f" ✓ {result['reason']}") - return result - - -async def test_scenario_6_provider_cost_discrepancy(): - """ - Test: Webhook cost vs API cost discrepancy. - - Expected behavior: - - Webhook cost stored immediately if available - - API cost fetched later for verification - - Both costs stored for auditing - """ - print("\nTest 6: Provider cost discrepancy handling") - - # Vonage webhook provides immediate cost - webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120} - - # API call provides authoritative cost - api_cost = { - "cost_usd": 0.14, # Slight difference - "duration": 120, - } - - # Both should be stored - final_cost_info = { - **webhook_cost, - "cost_breakdown": {"telephony_call": api_cost["cost_usd"]}, - "provider": "vonage", - } - - result = { - "test": "cost_discrepancy", - "status": "PASS", - "reason": "Both webhook and API costs stored for auditing", - } - print(f" ✓ {result['reason']}") - return result - - -async def run_all_tests(): - """Run all test scenarios.""" - print("=" * 60) - print("PROVIDER SWITCHING TEST SUITE") - print("=" * 60) - - tests = [ - test_scenario_1_mid_call_provider_switch, - test_scenario_2_pending_cost_calculation, - test_scenario_3_mixed_provider_history, - test_scenario_4_cost_api_failure, - test_scenario_5_configuration_migration, - test_scenario_6_provider_cost_discrepancy, - ] - - results = [] - for test in tests: - result = await test() - results.append(result) - - print("\n" + "=" * 60) - print("TEST SUMMARY") - print("=" * 60) - - passed = sum(1 for r in results if r["status"] == "PASS") - failed = sum(1 for r in results if r["status"] == "FAIL") - - print(f"Total Tests: {len(results)}") - print(f"Passed: {passed}") - print(f"Failed: {failed}") - - if failed == 0: - print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!") - else: - print("\n❌ Some tests failed - Review the implementation") - - return results - - -if __name__ == "__main__": - # Run the test suite - asyncio.run(run_all_tests()) diff --git a/api/tests/test_run_integrations_db_client.py b/api/tests/test_run_integrations_db_client.py deleted file mode 100644 index 6c3f9d6..0000000 --- a/api/tests/test_run_integrations_db_client.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Tests for run_integrations with new DB client methods.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from api.enums import WorkflowRunMode -from api.tasks.run_integrations import run_integrations_post_workflow_run - - -@pytest.fixture(autouse=True) -def mock_logger(): - """Mock the logger for all tests.""" - with patch("api.tasks.run_integrations.logger") as mock_logger: - mock_logger.bind.return_value = mock_logger - yield mock_logger - - -@pytest.fixture -def mock_workflow_run(): - """Create a mock workflow run with all required attributes.""" - workflow_run = MagicMock() - workflow_run.id = 1 - workflow_run.mode = "browser" - workflow_run.gathered_context = { - "call_disposition": "XFER", - "mapped_call_disposition": "XFER", # Required for Slack integration - "call_duration": "120", - "agent_name": "TestAgent", - } - workflow_run.initial_context = {"vendor_id": "123"} - - # Setup workflow and user chain - workflow_run.workflow = MagicMock() - workflow_run.workflow.user = MagicMock() - workflow_run.workflow.user.selected_organization_id = 100 - - return workflow_run - - -@pytest.fixture -def mock_integration(): - """Create a mock integration.""" - integration = MagicMock() - integration.id = 1 - integration.organisation_id = 100 - integration.provider = "slack" - integration.is_active = True - integration.connection_details = { - "connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"} - } - return integration - - -@pytest.mark.asyncio -async def test_run_integrations_with_db_client_methods( - mock_workflow_run, mock_integration -): - """Test that run_integrations uses the new DB client methods correctly.""" - - with patch("api.tasks.run_integrations.set_current_run_id") as mock_set_run_id: - with patch("api.tasks.run_integrations.db_client") as mock_db_client: - # Mock the new DB client methods - mock_db_client.get_workflow_run_with_context = AsyncMock( - return_value=(mock_workflow_run, 100) - ) - mock_db_client.get_active_integrations_by_organization = AsyncMock( - return_value=[mock_integration] - ) - mock_db_client.get_configuration_value = AsyncMock( - return_value={ - "slack": { - "DISPOSITION_CODE": "Disposition: {{mapped_call_disposition}}" - } - } - ) - - # Mock the aiohttp session for Slack webhook - with patch( - "api.tasks.run_integrations.aiohttp.ClientSession" - ) as mock_session_class: - mock_response = MagicMock() - mock_response.status = 200 - - mock_session = MagicMock() - mock_session.__aenter__.return_value = mock_session - mock_session.__aexit__.return_value = AsyncMock() - - mock_post = MagicMock() - mock_post.__aenter__.return_value = mock_response - mock_post.__aexit__.return_value = AsyncMock() - - mock_session.post.return_value = mock_post - mock_session_class.return_value = mock_session - - # Call the function - await run_integrations_post_workflow_run(None, 1) - - # Verify the correct DB client methods were called - mock_set_run_id.assert_called_once_with(1) - mock_db_client.get_workflow_run_with_context.assert_called_once_with(1) - mock_db_client.get_active_integrations_by_organization.assert_called_once_with( - 100 - ) - - # Verify the Slack webhook was called - mock_session.post.assert_called_once() - assert ( - mock_session.post.call_args[0][0] == "https://hooks.slack.com/test" - ) - - -@pytest.mark.asyncio -async def test_run_integrations_no_workflow_run(): - """Test handling when workflow run is not found.""" - - with patch("api.tasks.run_integrations.set_current_run_id"): - with patch("api.tasks.run_integrations.db_client") as mock_db_client: - # Mock workflow run not found - mock_db_client.get_workflow_run_with_context = AsyncMock( - return_value=(None, None) - ) - - # Call the function - await run_integrations_post_workflow_run(None, 999) - - # Verify it returns early and doesn't call other DB methods - mock_db_client.get_workflow_run_with_context.assert_called_once_with(999) - mock_db_client.get_active_integrations_by_organization.assert_not_called() - - -@pytest.mark.asyncio -async def test_run_integrations_no_organization(): - """Test handling when user has no organization.""" - - mock_workflow_run = MagicMock() - mock_workflow_run.id = 1 - mock_workflow_run.gathered_context = {"test": "data"} - mock_workflow_run.workflow = MagicMock() - mock_workflow_run.workflow.user = MagicMock() - - with patch("api.tasks.run_integrations.set_current_run_id"): - with patch("api.tasks.run_integrations.db_client") as mock_db_client: - # Mock workflow run found but no organization - mock_db_client.get_workflow_run_with_context = AsyncMock( - return_value=(mock_workflow_run, None) - ) - - # Call the function - await run_integrations_post_workflow_run(None, 1) - - # Verify it returns early after checking organization - mock_db_client.get_workflow_run_with_context.assert_called_once_with(1) - mock_db_client.get_active_integrations_by_organization.assert_not_called() - - -@pytest.mark.asyncio -async def test_run_integrations_no_gathered_context(mock_workflow_run): - """Test handling when workflow run has no gathered context.""" - - mock_workflow_run.gathered_context = None - - with patch("api.tasks.run_integrations.set_current_run_id"): - with patch("api.tasks.run_integrations.db_client") as mock_db_client: - # Mock workflow run with no gathered context - mock_db_client.get_workflow_run_with_context = AsyncMock( - return_value=(mock_workflow_run, 100) - ) - - # Call the function - await run_integrations_post_workflow_run(None, 1) - - # Verify it returns early after checking gathered_context - mock_db_client.get_workflow_run_with_context.assert_called_once_with(1) - mock_db_client.get_active_integrations_by_organization.assert_not_called() - - -@pytest.mark.asyncio -async def test_run_integrations_stasis_mode(mock_workflow_run): - """Test that stasis mode triggers vendor sync.""" - - mock_workflow_run.mode = WorkflowRunMode.STASIS.value - mock_workflow_run.initial_context = { - "vendor": "test_vendor", - "vendor_base_url": "https://api.vendor.com", - "vendor_id": "123", - } - - with patch("api.tasks.run_integrations.set_current_run_id"): - with patch("api.tasks.run_integrations.db_client") as mock_db_client: - with patch("api.tasks.run_integrations._sync_vendor_data") as mock_sync: - mock_sync.return_value = None - - mock_db_client.get_workflow_run_with_context = AsyncMock( - return_value=(mock_workflow_run, 100) - ) - mock_db_client.get_active_integrations_by_organization = AsyncMock( - return_value=[] - ) - - # Call the function - await run_integrations_post_workflow_run(None, 1) - - # Verify vendor sync was called - mock_sync.assert_called_once_with( - mock_workflow_run.initial_context, - mock_workflow_run.gathered_context, - ) - - -@pytest.mark.asyncio -async def test_run_integrations_multiple_integrations(mock_workflow_run): - """Test processing multiple integrations.""" - - # Create multiple mock integrations - slack_integration = MagicMock() - slack_integration.provider = "slack" - slack_integration.connection_details = { - "connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test1"} - } - - slack_integration2 = MagicMock() - slack_integration2.provider = "slack" - slack_integration2.connection_details = { - "connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test2"} - } - - with patch("api.tasks.run_integrations.set_current_run_id"): - with patch("api.tasks.run_integrations.db_client") as mock_db_client: - mock_db_client.get_workflow_run_with_context = AsyncMock( - return_value=(mock_workflow_run, 100) - ) - mock_db_client.get_active_integrations_by_organization = AsyncMock( - return_value=[slack_integration, slack_integration2] - ) - mock_db_client.get_configuration_value = AsyncMock( - return_value={"slack": {"DISPOSITION_CODE": "Test message"}} - ) - - with patch( - "api.tasks.run_integrations.aiohttp.ClientSession" - ) as mock_session_class: - mock_response = MagicMock() - mock_response.status = 200 - - mock_session = MagicMock() - mock_session.__aenter__.return_value = mock_session - mock_session.__aexit__.return_value = AsyncMock() - - mock_post = MagicMock() - mock_post.__aenter__.return_value = mock_response - mock_post.__aexit__.return_value = AsyncMock() - - mock_session.post.return_value = mock_post - mock_session_class.return_value = mock_session - - # Call the function - await run_integrations_post_workflow_run(None, 1) - - # Verify both integrations were processed - assert mock_session.post.call_count == 2 - - # Check that both webhooks were called - call_urls = [call[0][0] for call in mock_session.post.call_args_list] - assert "https://hooks.slack.com/test1" in call_urls - assert "https://hooks.slack.com/test2" in call_urls diff --git a/api/tests/test_run_integrations_template.py b/api/tests/test_run_integrations_template.py deleted file mode 100644 index cb4fa71..0000000 --- a/api/tests/test_run_integrations_template.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for webhook execution in run_integrations.py.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from api.tasks.run_integrations import ( - _build_auth_header, - _build_render_context, - _execute_webhook_node, -) - - -@pytest.fixture(autouse=True) -def mock_logger(): - """Mock the logger for all tests.""" - with patch("api.tasks.run_integrations.logger") as mock_log: - mock_log.bind.return_value = mock_log - yield mock_log - - -class TestBuildAuthHeader: - """Tests for _build_auth_header function.""" - - def test_bearer_token(self): - """Test bearer token auth header.""" - credential = MagicMock() - credential.credential_type = "bearer_token" - credential.credential_data = {"token": "my-secret-token"} - - result = _build_auth_header(credential) - assert result == {"Authorization": "Bearer my-secret-token"} - - def test_api_key(self): - """Test API key auth header.""" - credential = MagicMock() - credential.credential_type = "api_key" - credential.credential_data = {"header_name": "X-API-Key", "api_key": "key123"} - - result = _build_auth_header(credential) - assert result == {"X-API-Key": "key123"} - - def test_api_key_default_header(self): - """Test API key with default header name.""" - credential = MagicMock() - credential.credential_type = "api_key" - credential.credential_data = {"api_key": "key123"} - - result = _build_auth_header(credential) - assert result == {"X-API-Key": "key123"} - - def test_basic_auth(self): - """Test basic auth header.""" - credential = MagicMock() - credential.credential_type = "basic_auth" - credential.credential_data = {"username": "user", "password": "pass"} - - result = _build_auth_header(credential) - # base64 of "user:pass" is "dXNlcjpwYXNz" - assert result == {"Authorization": "Basic dXNlcjpwYXNz"} - - def test_custom_header(self): - """Test custom header auth.""" - credential = MagicMock() - credential.credential_type = "custom_header" - credential.credential_data = { - "header_name": "X-Custom-Auth", - "header_value": "custom-value", - } - - result = _build_auth_header(credential) - assert result == {"X-Custom-Auth": "custom-value"} - - def test_unknown_type(self): - """Test unknown credential type returns empty dict.""" - credential = MagicMock() - credential.credential_type = "unknown" - credential.credential_data = {} - - result = _build_auth_header(credential) - assert result == {} - - -class TestBuildRenderContext: - """Tests for _build_render_context function.""" - - def test_basic_context(self): - """Test building render context from workflow run.""" - workflow_run = MagicMock() - workflow_run.id = 123 - workflow_run.name = "WR-TEST-001" - workflow_run.workflow_id = 456 - workflow_run.workflow.name = "Test Workflow" - workflow_run.initial_context = {"phone_number": "+1234567890"} - workflow_run.gathered_context = { - "customer_name": "John", - "mapped_call_disposition": "QUALIFIED", - } - workflow_run.usage_info = {"call_duration_seconds": 120} - workflow_run.completed_at = None - - result = _build_render_context(workflow_run) - - assert result["workflow_run_id"] == 123 - assert result["workflow_run_name"] == "WR-TEST-001" - assert result["workflow_id"] == 456 - assert result["workflow_name"] == "Test Workflow" - assert result["initial_context"]["phone_number"] == "+1234567890" - assert result["gathered_context"]["customer_name"] == "John" - assert result["cost_info"]["call_duration_seconds"] == 120 - assert result["disposition_code"] == "QUALIFIED" - - def test_empty_contexts(self): - """Test with empty/None contexts.""" - workflow_run = MagicMock() - workflow_run.id = 1 - workflow_run.name = "Test" - workflow_run.workflow_id = 1 - workflow_run.workflow.name = "Workflow" - workflow_run.initial_context = None - workflow_run.gathered_context = None - workflow_run.usage_info = None - workflow_run.completed_at = None - - result = _build_render_context(workflow_run) - - assert result["initial_context"] == {} - assert result["gathered_context"] == {} - assert result["cost_info"] == {} - assert result["disposition_code"] is None - - -class TestExecuteWebhookNode: - """Tests for _execute_webhook_node function.""" - - @pytest.mark.asyncio - async def test_disabled_webhook_skipped(self): - """Test that disabled webhooks are skipped.""" - webhook_data = {"name": "Test Webhook", "enabled": False} - - result = await _execute_webhook_node( - webhook_data=webhook_data, - render_context={}, - organization_id=1, - ) - - assert result is True # Returns True for skipped webhooks - - @pytest.mark.asyncio - async def test_missing_url_returns_false(self): - """Test that missing endpoint URL returns False.""" - webhook_data = {"name": "Test Webhook", "enabled": True, "endpoint_url": None} - - result = await _execute_webhook_node( - webhook_data=webhook_data, - render_context={}, - organization_id=1, - ) - - assert result is False - - @pytest.mark.asyncio - async def test_successful_post_request(self): - """Test successful POST webhook execution.""" - webhook_data = { - "name": "CRM Sync", - "enabled": True, - "http_method": "POST", - "endpoint_url": "https://api.example.com/webhook", - "payload_template": { - "call_id": "{{workflow_run_id}}", - "phone": "{{initial_context.phone_number}}", - }, - } - - render_context = { - "workflow_run_id": 123, - "initial_context": {"phone_number": "+1234567890"}, - } - - with patch("api.tasks.run_integrations.db_client") as mock_db: - mock_db.get_credential_by_uuid = AsyncMock(return_value=None) - - with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_client_instance - - result = await _execute_webhook_node( - webhook_data=webhook_data, - render_context=render_context, - organization_id=1, - ) - - assert result is True - - # Verify the request was made correctly - mock_client_instance.request.assert_called_once() - call_kwargs = mock_client_instance.request.call_args[1] - assert call_kwargs["method"] == "POST" - assert call_kwargs["url"] == "https://api.example.com/webhook" - assert call_kwargs["json"] == { - "call_id": "123", - "phone": "+1234567890", - } - - @pytest.mark.asyncio - async def test_webhook_with_credential(self): - """Test webhook execution with credential auth.""" - webhook_data = { - "name": "Authenticated Webhook", - "enabled": True, - "http_method": "POST", - "endpoint_url": "https://api.example.com/webhook", - "credential_uuid": "cred-123", - "payload_template": {}, - } - - mock_credential = MagicMock() - mock_credential.name = "API Key" - mock_credential.credential_type = "bearer_token" - mock_credential.credential_data = {"token": "secret-token"} - - with patch("api.tasks.run_integrations.db_client") as mock_db: - mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential) - - with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_client_instance - - result = await _execute_webhook_node( - webhook_data=webhook_data, - render_context={}, - organization_id=1, - ) - - assert result is True - - # Verify auth header was included - call_kwargs = mock_client_instance.request.call_args[1] - assert call_kwargs["headers"]["Authorization"] == "Bearer secret-token" - - @pytest.mark.asyncio - async def test_webhook_with_custom_headers(self): - """Test webhook execution with custom headers.""" - webhook_data = { - "name": "Custom Headers Webhook", - "enabled": True, - "http_method": "POST", - "endpoint_url": "https://api.example.com/webhook", - "custom_headers": [ - {"key": "X-Source", "value": "dograh"}, - {"key": "X-Workflow", "value": "test"}, - ], - "payload_template": {}, - } - - with patch("api.tasks.run_integrations.db_client") as mock_db: - mock_db.get_credential_by_uuid = AsyncMock(return_value=None) - - with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_client_instance - - result = await _execute_webhook_node( - webhook_data=webhook_data, - render_context={}, - organization_id=1, - ) - - assert result is True - - # Verify custom headers were included - call_kwargs = mock_client_instance.request.call_args[1] - assert call_kwargs["headers"]["X-Source"] == "dograh" - assert call_kwargs["headers"]["X-Workflow"] == "test" - - @pytest.mark.asyncio - async def test_webhook_http_error(self): - """Test webhook execution with HTTP error.""" - import httpx - - webhook_data = { - "name": "Failing Webhook", - "enabled": True, - "http_method": "POST", - "endpoint_url": "https://api.example.com/webhook", - "payload_template": {}, - } - - with patch("api.tasks.run_integrations.db_client") as mock_db: - mock_db.get_credential_by_uuid = AsyncMock(return_value=None) - - with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 500 - mock_response.text = "Internal Server Error" - mock_response.raise_for_status = MagicMock( - side_effect=httpx.HTTPStatusError( - "Server Error", - request=MagicMock(), - response=mock_response, - ) - ) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_client_instance - - result = await _execute_webhook_node( - webhook_data=webhook_data, - render_context={}, - organization_id=1, - ) - - assert result is False diff --git a/api/tests/test_s3_signed_url.py b/api/tests/test_s3_signed_url.py deleted file mode 100644 index e299696..0000000 --- a/api/tests/test_s3_signed_url.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for the `/s3/signed-url` endpoint. - -This test-suite verifies: -1. Regular users can retrieve signed URLs for resources belonging to their own workflow runs. -2. Regular users are *forbidden* from accessing resources that belong to other users. -3. Superusers can access any resource irrespective of ownership. -""" - -import os -from unittest.mock import AsyncMock - -import pytest -from fastapi import status - -# Ensure the S3 environment variables exist so that the module import does not fail -os.environ.setdefault("S3_BUCKET", "test-bucket") -os.environ.setdefault("S3_REGION", "us-east-1") - - -@pytest.mark.asyncio -async def test_signed_url_for_own_run(monkeypatch, test_client_factory, db_session): - """A normal user should be able to fetch a signed URL for their own workflow run.""" - from api.db.models import UserModel - - # ------------------------------------------------------------------ - # 1. Set-up – create user, workflow & workflow run - # ------------------------------------------------------------------ - user: UserModel = await db_session.get_or_create_user_by_provider_id("user_own_run") - workflow = await db_session.create_workflow("wf", {}, user.id) - run = await db_session.create_workflow_run("run", workflow.id, "chat", user.id) - - key = f"transcripts/{run.id}.txt" - - # Patch S3 signed-url generator to avoid network calls - monkeypatch.setattr( - "api.services.filesystem.s3.s3_fs.aget_signed_url", - AsyncMock(return_value="https://signed-url"), - ) - - async with test_client_factory(user) as client: - response = await client.get(f"/api/v1/s3/signed-url?key={key}") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data == {"url": "https://signed-url", "expires_in": 3600} - - -@pytest.mark.asyncio -async def test_signed_url_for_other_users_run_forbidden( - monkeypatch, test_client_factory, db_session -): - """A normal user must *not* access workflow runs owned by someone else.""" - from api.db.models import UserModel - - # Owner of the workflow run - owner: UserModel = await db_session.get_or_create_user_by_provider_id("owner_user") - workflow = await db_session.create_workflow("wf", {}, owner.id) - run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id) - - # Second user attempting access - intruder: UserModel = await db_session.get_or_create_user_by_provider_id( - "intruder_user" - ) - - key = f"recordings/{run.id}.wav" - - monkeypatch.setattr( - "api.services.filesystem.s3.s3_fs.aget_signed_url", - AsyncMock(return_value="https://signed-url"), - ) - - async with test_client_factory(intruder) as client: - response = await client.get(f"/api/v1/s3/signed-url?key={key}") - - assert response.status_code == status.HTTP_403_FORBIDDEN - - -@pytest.mark.asyncio -async def test_superuser_can_access_any_run( - monkeypatch, test_client_factory, db_session -): - """Superusers should be able to fetch signed URLs for any workflow run.""" - from api.db.models import UserModel - - # Normal user & run owner - owner: UserModel = await db_session.get_or_create_user_by_provider_id( - "owner_of_run" - ) - workflow = await db_session.create_workflow("wf", {}, owner.id) - run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id) - - # Superuser - superuser: UserModel = await db_session.get_or_create_user_by_provider_id( - "admin_user" - ) - - # Promote to superuser - # We need to commit the change so that the DB reflects it - async with db_session.async_session() as session: - db_user = await session.get(UserModel, superuser.id) - db_user.is_superuser = True - await session.commit() - await session.refresh(db_user) # ensure we have the latest state - superuser.is_superuser = True - - key = f"transcripts/{run.id}.txt" - - monkeypatch.setattr( - "api.services.filesystem.s3.s3_fs.aget_signed_url", - AsyncMock(return_value="https://signed-url"), - ) - - async with test_client_factory(superuser) as client: - response = await client.get(f"/api/v1/s3/signed-url?key={key}") - - assert response.status_code == status.HTTP_200_OK - assert response.json()["url"] == "https://signed-url" diff --git a/api/tests/test_s3_upload_tasks.py b/api/tests/test_s3_upload_tasks.py deleted file mode 100644 index 6e6dadc..0000000 --- a/api/tests/test_s3_upload_tasks.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -import tempfile -from unittest.mock import AsyncMock, patch - -import pytest - -from api.tasks.s3_upload import upload_audio_to_s3, upload_transcript_to_s3 - - -@pytest.mark.asyncio -async def test_upload_audio_to_s3_success(): - """Test successful audio upload to S3.""" - # Create a temporary file - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf: - tf.write(b"fake audio data") - temp_path = tf.name - - try: - # Mock dependencies - mock_ctx = AsyncMock() - mock_s3_fs = AsyncMock() - mock_db_client = AsyncMock() - - with ( - patch("api.tasks.s3_upload.s3_fs", mock_s3_fs), - patch("api.tasks.s3_upload.db_client", mock_db_client), - ): - await upload_audio_to_s3( - mock_ctx, workflow_run_id=123, temp_file_path=temp_path - ) - - # Verify S3 upload was called - mock_s3_fs.aupload_file.assert_called_once_with( - temp_path, "recordings/123.wav" - ) - - # Verify DB update was called - mock_db_client.update_workflow_run.assert_called_once_with( - run_id=123, recording_url="recordings/123.wav" - ) - - # Verify temp file was cleaned up - assert not os.path.exists(temp_path) - - finally: - # Clean up if test failed - if os.path.exists(temp_path): - os.remove(temp_path) - - -@pytest.mark.asyncio -async def test_upload_audio_to_s3_file_not_found(): - """Test audio upload when temp file doesn't exist.""" - mock_ctx = AsyncMock() - - with pytest.raises(FileNotFoundError): - await upload_audio_to_s3( - mock_ctx, workflow_run_id=123, temp_file_path="/nonexistent/file.wav" - ) - - -@pytest.mark.asyncio -async def test_upload_transcript_to_s3_success(): - """Test successful transcript upload to S3.""" - # Create a temporary file - with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf: - tf.write("Test transcript content") - temp_path = tf.name - - try: - # Mock dependencies - mock_ctx = AsyncMock() - mock_s3_fs = AsyncMock() - mock_db_client = AsyncMock() - - with ( - patch("api.tasks.s3_upload.s3_fs", mock_s3_fs), - patch("api.tasks.s3_upload.db_client", mock_db_client), - ): - await upload_transcript_to_s3( - mock_ctx, workflow_run_id=456, temp_file_path=temp_path - ) - - # Verify S3 upload was called - mock_s3_fs.aupload_file.assert_called_once_with( - temp_path, "transcripts/456.txt" - ) - - # Verify DB update was called - mock_db_client.update_workflow_run.assert_called_once_with( - run_id=456, transcript_url="transcripts/456.txt" - ) - - # Verify temp file was cleaned up - assert not os.path.exists(temp_path) - - finally: - # Clean up if test failed - if os.path.exists(temp_path): - os.remove(temp_path) - - -@pytest.mark.asyncio -async def test_upload_s3_cleanup_on_error(): - """Test that temp files are cleaned up even when S3 upload fails.""" - # Create a temporary file - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf: - tf.write(b"fake audio data") - temp_path = tf.name - - try: - mock_ctx = AsyncMock() - mock_s3_fs = AsyncMock() - # Make S3 upload fail - mock_s3_fs.aupload_file.side_effect = Exception("S3 upload failed") - - with patch("api.tasks.s3_upload.s3_fs", mock_s3_fs): - with pytest.raises(Exception): - await upload_audio_to_s3( - mock_ctx, workflow_run_id=123, temp_file_path=temp_path - ) - - # Verify temp file was still cleaned up - assert not os.path.exists(temp_path) - - finally: - # Clean up if test failed - if os.path.exists(temp_path): - os.remove(temp_path) diff --git a/api/tests/test_template_renderer.py b/api/tests/test_template_renderer.py deleted file mode 100644 index 727d4fe..0000000 --- a/api/tests/test_template_renderer.py +++ /dev/null @@ -1,89 +0,0 @@ -from api.utils.template_renderer import render_template - - -def test_render_template_basic(): - """Test basic template rendering.""" - template = "Hello {{name}}, your balance is {{balance}}." - context = {"name": "John", "balance": "$1000"} - - result = render_template(template, context) - assert result == "Hello John, your balance is $1000." - - -def test_render_template_with_spaces(): - """Test template rendering with spaces around variables.""" - template = "Hello {{ name }}, your balance is {{ balance }}." - context = {"name": "John", "balance": "$1000"} - - result = render_template(template, context) - assert result == "Hello John, your balance is $1000." - - -def test_render_template_missing_variable(): - """Test template rendering with missing variables.""" - template = "Hello {{name}}, your balance is {{balance}}." - context = {"name": "John"} - - result = render_template(template, context) - assert result == "Hello John, your balance is ." - - -def test_render_template_with_fallback(): - """Test template rendering with fallback values.""" - template = "Hello {{name | fallback}}, your balance is {{balance | fallback:$0}}." - context = {} - - result = render_template(template, context) - assert result == "Hello Name, your balance is $0." - - -def test_render_template_with_fallback_existing_value(): - """Test that fallback is not used when value exists.""" - template = "Hello {{name | fallback:Guest}}" - context = {"name": "John"} - - result = render_template(template, context) - assert result == "Hello John" - - -def test_render_template_with_line_breaks(): - """Test template rendering with line breaks.""" - template = ( - "DISPOSITION_CODE: {{call_disposition}}\\nCALL_DURATION: {{call_duration}}" - ) - context = {"call_disposition": "XFER", "call_duration": "300"} - - result = render_template(template, context) - expected = "DISPOSITION_CODE: XFER\nCALL_DURATION: 300" - assert result == expected - - -def test_render_template_empty(): - """Test rendering empty template.""" - assert render_template("", {}) == "" - assert render_template(None, {}) == None - - -def test_render_template_no_placeholders(): - """Test template with no placeholders.""" - template = "This is a plain text message" - result = render_template(template, {"unused": "value"}) - assert result == "This is a plain text message" - - -def test_render_template_none_values(): - """Test template with None values.""" - template = "Value: {{value}}" - context = {"value": None} - - result = render_template(template, context) - assert result == "Value: " - - -def test_render_template_numeric_values(): - """Test template with numeric values.""" - template = "Count: {{count}}, Price: {{price}}" - context = {"count": 42, "price": 19.99} - - result = render_template(template, context) - assert result == "Count: 42, Price: 19.99" diff --git a/api/tests/test_usage_concurrency.py b/api/tests/test_usage_concurrency.py deleted file mode 100644 index d775267..0000000 --- a/api/tests/test_usage_concurrency.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python -""" -Test script to verify atomic operations in organization_usage_client.py -This simulates concurrent access from multiple processes. -""" - -import asyncio -import os -from concurrent.futures import ProcessPoolExecutor - -# Set up environment -os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL", "")) - -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker - -from api.db.organization_usage_client import OrganizationUsageClient - - -async def reserve_quota_process(org_id: int, tokens: int, process_id: int): - """Simulate a process trying to reserve quota.""" - engine = create_async_engine(os.environ["DATABASE_URL"]) - async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - - client = OrganizationUsageClient(async_session) - - results = [] - for i in range(5): - result = await client.check_and_reserve_quota(org_id, tokens) - results.append((process_id, i, result)) - await asyncio.sleep(0.01) # Small delay to increase contention - - await engine.dispose() - return results - - -async def update_usage_process(org_id: int, tokens: int, process_id: int): - """Simulate a process updating usage after runs.""" - engine = create_async_engine(os.environ["DATABASE_URL"]) - async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - - client = OrganizationUsageClient(async_session) - - for i in range(5): - await client.update_usage_after_run(org_id, tokens, duration_seconds=10) - await asyncio.sleep(0.01) - - await engine.dispose() - return f"Process {process_id} completed updates" - - -def run_reserve_quota(args): - """Wrapper to run async function in process.""" - org_id, tokens, process_id = args - return asyncio.run(reserve_quota_process(org_id, tokens, process_id)) - - -def run_update_usage(args): - """Wrapper to run async function in process.""" - org_id, tokens, process_id = args - return asyncio.run(update_usage_process(org_id, tokens, process_id)) - - -async def test_concurrent_quota_reservation(): - """Test that concurrent quota reservations are handled atomically.""" - print("Testing concurrent quota reservations...") - - # Assuming org_id 1 exists with quota enabled - org_id = 1 - tokens_per_request = 100 - - # Run multiple processes trying to reserve quota simultaneously - with ProcessPoolExecutor(max_workers=3) as executor: - futures = [] - for i in range(3): - futures.append( - executor.submit(run_reserve_quota, (org_id, tokens_per_request, i)) - ) - - results = [] - for future in futures: - results.extend(future.result()) - - print(f"Reservation results: {results}") - - # Check that reservations were handled atomically - successful_reservations = sum(1 for _, _, success in results if success) - print(f"Successful reservations: {successful_reservations}") - - -async def test_concurrent_usage_updates(): - """Test that concurrent usage updates are handled atomically.""" - print("\nTesting concurrent usage updates...") - - org_id = 1 - tokens_per_update = 50 - - # Get initial usage - engine = create_async_engine(os.environ["DATABASE_URL"]) - async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - client = OrganizationUsageClient(async_session) - - initial_usage = await client.get_current_usage(org_id) - initial_tokens = initial_usage["used_dograh_tokens"] - print(f"Initial tokens: {initial_tokens}") - - # Run multiple processes updating usage simultaneously - with ProcessPoolExecutor(max_workers=3) as executor: - futures = [] - for i in range(3): - futures.append( - executor.submit(run_update_usage, (org_id, tokens_per_update, i)) - ) - - for future in futures: - print(future.result()) - - # Check final usage - final_usage = await client.get_current_usage(org_id) - final_tokens = final_usage["used_dograh_tokens"] - expected_tokens = initial_tokens + ( - 3 * 5 * tokens_per_update - ) # 3 processes * 5 updates * 50 tokens - - print(f"Final tokens: {final_tokens}") - print(f"Expected tokens: {expected_tokens}") - print(f"Difference: {final_tokens - expected_tokens}") - - await engine.dispose() - - if final_tokens == expected_tokens: - print("✅ All updates were applied atomically!") - else: - print("❌ Some updates were lost due to race conditions!") - - -async def main(): - """Run all concurrency tests.""" - try: - await test_concurrent_quota_reservation() - await test_concurrent_usage_updates() - except Exception as e: - print(f"Error during testing: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - print("Starting organization usage concurrency tests...") - print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL', 'NOT SET')}") - asyncio.run(main()) diff --git a/api/tests/test_variable_extraction.py b/api/tests/test_variable_extraction.py deleted file mode 100644 index 552a1ac..0000000 --- a/api/tests/test_variable_extraction.py +++ /dev/null @@ -1,140 +0,0 @@ -import json -import os -from unittest.mock import AsyncMock, patch - -import pytest -from pipecat.services.openai.llm import OpenAILLMContext - -from api.services.workflow.dto import ExtractionVariableDTO, VariableType -from api.services.workflow.pipecat_engine_variable_extractor import ( - VariableExtractionManager, -) - - -class DummyLLM: - """A minimal stub that mimics the parts of an LLM service used by the extractor.""" - - def __init__(self, streamed_response: str | None = None): - # Optionally provide a pre-defined streaming response for _perform_extraction tests - self._streamed_response = streamed_response or "{}" - self.registered_functions: dict[str, AsyncMock] = {} - - # ------------------------------------------------------------------ - # API used by VariableExtractionManager - # ------------------------------------------------------------------ - def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 – simple delegate - self.registered_functions[name] = func - - async def get_chat_completions(self, _context, _messages): - """Return an async generator that yields a single chunk with the full response.""" - - class _Delta: # noqa: D401 – tiny helper classes for stub response - def __init__(self, content): - self.content = content - - class _Choice: - def __init__(self, delta): - self.delta = delta - - class _Chunk: - def __init__(self, content): - self.choices = [_Choice(_Delta(content))] - - async def _stream(): - yield _Chunk(self._streamed_response) - - return _stream() - - -class DummyEngine: - """A bare-bones Engine stub exposing only what the extractor relies on.""" - - def __init__(self, llm): - self.llm = llm - self.context = OpenAILLMContext() - self._pending_function_calls = 0 - # VariableExtractionManager currently updates this private attribute - self._gathered_context: dict = {} - - -# ------------------------------------------------------------------ -# Tests -# ------------------------------------------------------------------ - - -@pytest.mark.asyncio -async def test_perform_extraction_parses_json_correctly(): - """_perform_extraction should return the parsed JSON from the LLM stream.""" - # Set dummy OpenAI API key to prevent initialization errors - with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}): - expected_payload = {"name": "Alice", "age": 30} - llm = DummyLLM(json.dumps(expected_payload)) - engine = DummyEngine(llm) - manager = VariableExtractionManager(engine) - - # Mock the AsyncOpenAI client and its response - mock_response = AsyncMock() - mock_response.choices = [AsyncMock()] - mock_response.choices[0].message = AsyncMock() - mock_response.choices[0].message.content = json.dumps(expected_payload) - - mock_client = AsyncMock() - mock_client.chat.completions.create.return_value = mock_response - - with patch( - "api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI", - return_value=mock_client, - ): - # Minimal set of variables to extract – the prompts themselves are irrelevant here - extraction_variables = [ - ExtractionVariableDTO( - name="name", type=VariableType.string, prompt="user name" - ), - ExtractionVariableDTO( - name="age", type=VariableType.number, prompt="user age" - ), - ] - - result = await manager._perform_extraction( - extraction_variables, parent_ctx=None, extraction_prompt="" - ) - - assert result == expected_payload - - -@pytest.mark.asyncio -async def test_perform_extraction_with_custom_system_prompt(): - """_perform_extraction should use the provided extraction_prompt as system prompt.""" - # Set dummy OpenAI API key to prevent initialization errors - with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}): - expected_payload = {"color": "blue"} - llm = DummyLLM(json.dumps(expected_payload)) - engine = DummyEngine(llm) - manager = VariableExtractionManager(engine) - - # Mock the AsyncOpenAI client and its response - mock_response = AsyncMock() - mock_response.choices = [AsyncMock()] - mock_response.choices[0].message = AsyncMock() - mock_response.choices[0].message.content = json.dumps(expected_payload) - - mock_client = AsyncMock() - mock_client.chat.completions.create.return_value = mock_response - - with patch( - "api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI", - return_value=mock_client, - ): - extraction_variables = [ - ExtractionVariableDTO( - name="color", type=VariableType.string, prompt="favourite color" - ) - ] - - # Call with a custom extraction prompt - custom_prompt = "You are a color extraction specialist." - result = await manager._perform_extraction( - extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt - ) - - assert result == expected_payload diff --git a/api/tests/test_voicemail_detection_rtc.py b/api/tests/test_voicemail_detection_rtc.py deleted file mode 100644 index 90a7d7d..0000000 --- a/api/tests/test_voicemail_detection_rtc.py +++ /dev/null @@ -1,547 +0,0 @@ -""" -Test voicemail detection in RTC connection flow. - -This test emulates how a call is connected using SmallWebRTC, -triggers voicemail detection, and verifies the disconnect reason. -""" - -import json -from unittest.mock import AsyncMock, Mock, patch - -import pytest -from pipecat.utils.enums import EndTaskReason - -from api.routes.rtc_offer import RTCOfferRequest, offer -from api.services.workflow.pipecat_engine_voicemail_detector import VoicemailDetector - - -@pytest.mark.asyncio -class TestVoicemailDetectionRTC: - """Test voicemail detection through RTC connection flow.""" - - async def test_voicemail_detection_full_flow(self): - """ - Test complete voicemail detection flow: - 1. RTC connection request - 2. Transport sends on_client_connected event - 3. Engine initializes with voicemail detection enabled - 4. Voicemail detector returns true - 5. Call terminates with voicemail_detected reason - 6. Transport sends on_client_disconnected event - 7. Disconnect reason is properly set - """ - # Mock user and authentication - mock_user = Mock() - mock_user.id = 1 - mock_user.organization_id = 1 - - # Mock workflow with voicemail detection enabled - mock_workflow = Mock() - mock_workflow.id = 100 - mock_workflow.workflow_definition_with_fallback = { - "edges": [], - "nodes": [ - { - "id": "start", - "type": "start", - "data": { - "detect_voicemail": True, - "system_prompt": "You are a helpful assistant", - }, - } - ], - } - - # Mock workflow run - mock_workflow_run = Mock() - mock_workflow_run.id = 200 - mock_workflow_run.is_completed = False - - # Create request - request = RTCOfferRequest( - pc_id="test_pc_123", - sdp="test_sdp_offer", - type="offer", - workflow_id=mock_workflow.id, - workflow_run_id=mock_workflow_run.id, - restart_pc=False, - call_context_vars={"test_var": "test_value"}, - ) - - # Mock dependencies - with ( - patch("api.services.auth.depends.get_user") as mock_get_user_dep, - patch("api.routes.rtc_offer.SmallWebRTCConnection") as MockWebRTCConnection, - patch("api.routes.rtc_offer.run_pipeline_smallwebrtc") as mock_run_pipeline, - ): - # Setup mocks - mock_get_user_dep.return_value = mock_user - - # Mock WebRTC connection - mock_connection = Mock() - mock_connection.pc_id = "test_pc_123" - mock_connection.initialize = AsyncMock() - mock_connection.get_answer = Mock( - return_value={ - "pc_id": "test_pc_123", - "sdp": "test_sdp_answer", - "type": "answer", - } - ) - MockWebRTCConnection.return_value = mock_connection - - # Track registered event handlers - registered_handlers = {} - - def mock_event_handler(event_name): - def decorator(func): - registered_handlers[event_name] = func - return func - - return decorator - - mock_connection.event_handler = mock_event_handler - - # Mock BackgroundTasks - mock_background_tasks = Mock() - - # Create the offer - response = await offer(request, mock_background_tasks, mock_user) - - # Verify response - assert response["pc_id"] == "test_pc_123" - assert response["type"] == "answer" - - # Verify connection was initialized - mock_connection.initialize.assert_called_once_with( - sdp="test_sdp_offer", type="offer" - ) - - # Verify background task was added - mock_background_tasks.add_task.assert_called_once() - task_args = mock_background_tasks.add_task.call_args[0] - assert task_args[0] == mock_run_pipeline - assert task_args[1] == mock_connection - assert task_args[2] == mock_workflow.id - assert task_args[3] == mock_workflow_run.id - assert task_args[4] == mock_user.id - assert task_args[5] == {"test_var": "test_value"} - - async def test_voicemail_detection_in_pipeline(self): - """Tests whether the updates happen in on_client_disconnected properly - with values set in the engine""" - # Mock components - mock_transport = AsyncMock() - mock_engine = Mock() # Use Mock instead of AsyncMock for engine - mock_engine.initialize = AsyncMock() - mock_engine.cleanup = AsyncMock() - mock_audio_buffer = AsyncMock() - mock_task = AsyncMock() - mock_aggregator = Mock() - - # Setup engine with voicemail detector - mock_voicemail_detector = AsyncMock(spec=VoicemailDetector) - mock_engine.voicemail_detector = mock_voicemail_detector - mock_engine.get_call_disposition = Mock( - return_value=EndTaskReason.VOICEMAIL_DETECTED.value - ) - mock_engine.get_gathered_context = Mock( - return_value={ - "voicemail_transcript": "Hi, you've reached John's voicemail. Please leave a message.", - "voicemail_confidence": 0.95, - } - ) - - # Mock usage metrics - mock_aggregator.get_all_usage_metrics_serialized.return_value = {} - - # Register event handlers - from api.services.pipecat.event_handlers import ( - register_transport_event_handlers, - ) - - # Track registered handlers - handlers = {} - - def track_handler(event_name): - def decorator(func): - handlers[event_name] = func - return func - - return decorator - - mock_transport.event_handler = track_handler - - # Create a mock db_client module with update_workflow_run method - mock_db_client = Mock() - mock_db_client.update_workflow_run = AsyncMock() - - with ( - patch("api.services.pipecat.event_handlers.db_client", mock_db_client), - patch( - "api.services.pipecat.event_handlers.enqueue_job", - new_callable=AsyncMock, - ) as mock_enqueue_job, - patch( - "api.services.pipecat.event_handlers.get_organization_id_from_workflow_run", - return_value=1, - ), - patch( - "api.services.pipecat.event_handlers.apply_disposition_mapping", - side_effect=lambda value, org_id: value, # Return value unchanged - ), - ): - # Register handlers - register_transport_event_handlers( - mock_transport, - workflow_run_id=123, - audio_buffer=mock_audio_buffer, - task=mock_task, - engine=mock_engine, - usage_metrics_aggregator=mock_aggregator, - ) - - # Verify handlers were registered - assert "on_client_connected" in handlers - assert "on_client_disconnected" in handlers - - # Simulate client connection - await handlers["on_client_connected"]( - mock_transport, {"id": "participant_1"} - ) - - # Verify initialization - mock_audio_buffer.start_recording.assert_called_once() - mock_engine.initialize.assert_called_once() - - # Simulate voicemail detection and disconnect - await handlers["on_client_disconnected"]( - mock_transport, {"id": "participant_1"}, None - ) - - # Verify engine cleanup - mock_engine.cleanup.assert_called_once() - - # TODO: check whether task was cancelled or not once have more - # clarity on how to handle engine disconnect vs remote hangup - # Verify task was NOT cancelled (engine disconnect) - # mock_task.cancel.assert_not_called() - - # Verify workflow run was updated with voicemail context - mock_db_client.update_workflow_run.assert_called() - call_args = mock_db_client.update_workflow_run.call_args - assert call_args[1]["run_id"] == 123 - # Check that the mapped_call_disposition was set correctly - assert ( - call_args[1]["gathered_context"]["mapped_call_disposition"] - == "voicemail_detected" - ) - - async def test_voicemail_detector_audio_processing(self): - """Test VoicemailDetector audio processing and detection logic - tests that voicemail detector - calls engine's send_end_task_frame with the correct reason and metadata""" - # Create voicemail detector - detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=123) - - # Mock OpenAI client - mock_openai = AsyncMock() - mock_whisper_response = Mock() - mock_whisper_response.text = "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep." - mock_openai.audio.transcriptions.create.return_value = mock_whisper_response - - mock_gpt_response = Mock() - mock_gpt_response.choices = [Mock()] - mock_gpt_response.choices[0].message.content = json.dumps( - { - "is_voicemail": True, - "confidence": 0.98, - "reasoning": "Clear voicemail greeting with request to leave message", - } - ) - mock_openai.chat.completions.create.return_value = mock_gpt_response - - # Mock engine - mock_engine = AsyncMock() - mock_engine.task = AsyncMock() - - with ( - patch( - "api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI", - return_value=mock_openai, - ), - patch( - "api.services.workflow.pipecat_engine_voicemail_detector.s3_fs" - ) as mock_s3, - ): - # Mock S3 upload to return None (simulating successful upload) - mock_s3.aupload_file = AsyncMock(return_value=True) - # Start detection - await detector.start_detection(mock_engine) - assert detector.is_detecting == True - - # Simulate audio data (16kHz, mono, 5 seconds) - sample_rate = 16000 - duration = 5.0 - audio_data = b"\x00\x00" * int(sample_rate * duration) # Silent audio - - # Process audio in chunks - chunk_size = 1600 # 100ms chunks - for i in range(0, len(audio_data), chunk_size): - chunk = audio_data[i : i + chunk_size] - await detector.handle_audio_data(None, chunk, sample_rate, 1) - - # Wait for detection to complete - if detector._detection_task: - await detector._detection_task - - # Verify OpenAI calls - mock_openai.audio.transcriptions.create.assert_called_once() - mock_openai.chat.completions.create.assert_called_once() - - # Verify send_end_task_frame was called with voicemail detection - mock_engine.send_end_task_frame.assert_called_once_with( - reason=EndTaskReason.VOICEMAIL_DETECTED.value, - additional_metadata={ - "voicemail_transcript": "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep.", - "voicemail_confidence": 0.98, - "voicemail_reasoning": "Clear voicemail greeting with request to leave message", - "voicemail_detection_duration": 5.0, - "voicemail_audio_s3_path": "voicemail_detections/123_voicemail_98_5.wav", # S3 upload returns True, so filename is used - }, - abort_immediately=True, - ) - - async def test_voicemail_detector_no_detection(self): - """Test VoicemailDetector when voicemail is not detected.""" - # Create voicemail detector - detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=124) - - # Mock OpenAI client - mock_openai = AsyncMock() - mock_whisper_response = Mock() - mock_whisper_response.text = "Hello? Hello? Can you hear me?" - mock_openai.audio.transcriptions.create.return_value = mock_whisper_response - - mock_gpt_response = Mock() - mock_gpt_response.choices = [Mock()] - mock_gpt_response.choices[0].message.content = json.dumps( - { - "is_voicemail": False, - "confidence": 0.95, - "reasoning": "Live person speaking, asking if caller can hear them", - } - ) - mock_openai.chat.completions.create.return_value = mock_gpt_response - - # Mock engine - mock_engine = AsyncMock() - mock_engine.task = AsyncMock() - - with patch( - "api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI", - return_value=mock_openai, - ): - # Start detection - await detector.start_detection(mock_engine) - - # Simulate audio data - sample_rate = 16000 - duration = 5.0 - audio_data = b"\x00\x00" * int(sample_rate * duration) - - # Process audio - await detector.handle_audio_data(None, audio_data, sample_rate, 1) - - # Wait for detection - if detector._detection_task: - await detector._detection_task - - # Verify send_end_task_frame was NOT called - mock_engine.send_end_task_frame.assert_not_called() - - async def test_voicemail_detector_cancellation(self): - """Test VoicemailDetector cancellation before completion.""" - # Create voicemail detector - detector = VoicemailDetector(detection_duration=10.0, workflow_run_id=125) - - # Mock engine - mock_engine = AsyncMock() - - # Start detection - await detector.start_detection(mock_engine) - assert detector.is_detecting == True - - # Cancel detection immediately - await detector.stop_detection() - assert detector._is_cancelled == True - - # Try to add audio data after cancellation - await detector.handle_audio_data(None, b"\x00\x00" * 1000, 16000, 1) - - # Verify buffer didn't grow (no audio accepted after cancellation) - assert len(detector.audio_buffer) == 0 - - async def test_disconnect_reason_propagation(self): - """Test that voicemail disconnect reason is properly propagated.""" - # Create disconnect reason info directly - disconnect_info = { - "disposition_code": EndTaskReason.VOICEMAIL_DETECTED.value, - "details": "Voicemail detected after 5 seconds of audio", - "is_remote": False, - "is_user_initiated": False, - "is_successful_transfer": False, - "transport_metadata": { - "voicemail_confidence": 0.97, - "voicemail_transcript": "You've reached voicemail...", - }, - } - - # Verify attributes - assert disconnect_info["disposition_code"] == "voicemail_detected" - assert disconnect_info["is_remote"] == False - assert disconnect_info["is_user_initiated"] == False - assert disconnect_info["is_successful_transfer"] == False - assert ( - disconnect_info["details"] == "Voicemail detected after 5 seconds of audio" - ) - assert disconnect_info["transport_metadata"]["voicemail_confidence"] == 0.97 - - async def test_voicemail_detection_end_to_end(self): - """ - Complete end-to-end test covering: - 1. on_client_connected event - 2. Engine initialization with voicemail detection - 3. Audio processing and voicemail detection - 4. Engine setting disconnect reason - 5. on_client_disconnected event - 6. Proper disconnect reason in workflow run update - """ - # Create comprehensive mocks - from api.services.pipecat.event_handlers import ( - register_transport_event_handlers, - ) - - # Mock transport - mock_transport = AsyncMock() - handlers = {} - - def track_handler(event_name): - def decorator(func): - handlers[event_name] = func - return func - - return decorator - - mock_transport.event_handler = track_handler - - # Mock audio buffer - mock_audio_buffer = Mock() - mock_audio_buffer.start_recording = AsyncMock() - mock_audio_buffer.stop_recording = AsyncMock() - - # Mock task - mock_task = AsyncMock() - - # Mock aggregator - mock_aggregator = Mock() - mock_aggregator.get_all_usage_metrics_serialized.return_value = {} - - # Create a mock engine with voicemail detection - mock_engine = Mock() - mock_engine.initialize = AsyncMock() - mock_engine.cleanup = AsyncMock() - - # Mock voicemail detector - mock_voicemail_detector = Mock() - mock_engine.voicemail_detector = mock_voicemail_detector - mock_engine._voicemail_detector = mock_voicemail_detector - - # Initially no disconnect reason - mock_engine.get_call_disposition = Mock(return_value=None) - mock_engine.get_gathered_context = Mock(return_value={}) - - # Mock db_client - mock_db_client = Mock() - mock_db_client.update_workflow_run = AsyncMock() - - with ( - patch("api.services.pipecat.event_handlers.db_client", mock_db_client), - patch( - "api.services.pipecat.event_handlers.enqueue_job", - new_callable=AsyncMock, - ) as mock_enqueue_job, - patch( - "api.services.pipecat.event_handlers.get_organization_id_from_workflow_run", - return_value=1, - ), - patch( - "api.services.pipecat.event_handlers.apply_disposition_mapping", - side_effect=lambda value, org_id: value, # Return value unchanged - ), - ): - # Register event handlers - register_transport_event_handlers( - mock_transport, - workflow_run_id=123, - audio_buffer=mock_audio_buffer, - task=mock_task, - engine=mock_engine, - usage_metrics_aggregator=mock_aggregator, - ) - - # Verify handlers were registered - assert "on_client_connected" in handlers - assert "on_client_disconnected" in handlers - - # Step 1: Client connects - await handlers["on_client_connected"]( - mock_transport, {"id": "participant_1"} - ) - - # Verify initialization - mock_audio_buffer.start_recording.assert_called_once() - mock_engine.initialize.assert_called_once() - - # Step 2-3: Simulate voicemail detection occurs - # Update engine state to reflect voicemail was detected - mock_engine.get_call_disposition = Mock( - return_value=EndTaskReason.VOICEMAIL_DETECTED.value - ) - mock_engine.get_gathered_context = Mock( - return_value={ - "voicemail_transcript": "You've reached voicemail, leave a message", - "voicemail_confidence": 0.95, - } - ) - - # Step 5: Client disconnects - await handlers["on_client_disconnected"]( - mock_transport, {"id": "participant_1"}, None - ) - - # Verify engine cleanup - mock_engine.cleanup.assert_called_once() - - # Step 6: Verify proper disconnect reason in workflow run update - mock_db_client.update_workflow_run.assert_called() - call_args = mock_db_client.update_workflow_run.call_args - - # Check the gathered context includes disconnect reason - gathered_context = call_args[1]["gathered_context"] - assert gathered_context["mapped_call_disposition"] == "voicemail_detected" - assert gathered_context["voicemail_confidence"] == 0.95 - assert ( - gathered_context["voicemail_transcript"] - == "You've reached voicemail, leave a message" - ) - - # Verify task was NOT cancelled (engine-initiated disconnect) - mock_task.cancel.assert_not_called() - - # Verify audio buffer was stopped - mock_audio_buffer.stop_recording.assert_called_once() - - # Verify background jobs were enqueued - assert ( - mock_enqueue_job.call_count >= 3 - ) # At least 3 jobs should be enqueued diff --git a/api/tests/test_workflow_routes.py b/api/tests/test_workflow_routes.py deleted file mode 100644 index db7af26..0000000 --- a/api/tests/test_workflow_routes.py +++ /dev/null @@ -1,667 +0,0 @@ -""" -Tests for workflow API routes. - -This module tests the create, update, get, and validate workflow endpoints. -The fixtures for database setup, test client, and utilities are in conftest.py. -""" - -import pytest -from fastapi import status - - -@pytest.fixture -def sample_workflow_definition(): - """Sample workflow definition for testing.""" - return { - "nodes": [ - { - "id": "6581", - "type": "startCall", - "position": {"x": 427, "y": 23}, - "data": { - "prompt": "Hello, I am Abhishek from Dograh. ", - "is_static": True, - "name": "Start Call", - "is_start": True, - "invalid": False, - "validationMessage": None, - "allow_interrupt": False, - }, - "measured": {"width": 300, "height": 100}, - "selected": True, - "dragging": False, - }, - { - "id": "915", - "type": "agentNode", - "position": {"x": 305, "y": 340}, - "data": { - "prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent.", - "name": "Agent", - "invalid": False, - "validationMessage": None, - "allow_interrupt": False, - }, - "measured": {"width": 300, "height": 100}, - "selected": False, - "dragging": False, - }, - { - "id": "7598", - "type": "agentNode", - "position": {"x": 90, "y": 650}, - "data": { - "prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon", - "name": "Agent", - "invalid": False, - "validationMessage": None, - "allow_interrupt": True, - }, - "measured": {"width": 300, "height": 100}, - "selected": False, - "dragging": False, - }, - { - "id": "6919", - "type": "agentNode", - "position": {"x": 520, "y": 650}, - "data": { - "prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon", - "name": "Agent", - "invalid": False, - "validationMessage": None, - "allow_interrupt": True, - }, - "measured": {"width": 300, "height": 100}, - "selected": False, - "dragging": False, - }, - { - "id": "1802", - "type": "endCall", - "position": {"x": 305, "y": 960}, - "data": { - "prompt": "Thank you!", - "invalid": False, - "validationMessage": None, - "is_static": True, - "name": "End Call", - "is_end": True, - "allow_interrupt": False, - }, - "measured": {"width": 300, "height": 100}, - "selected": False, - "dragging": False, - }, - ], - "edges": [ - { - "animated": True, - "type": "custom", - "source": "915", - "target": "7598", - "id": "xy-edge__915-7598", - "selected": False, - "data": { - "condition": "The customer wants to talk to a customer service agent", - "label": "customer service agent", - "invalid": False, - "validationMessage": None, - }, - }, - { - "animated": True, - "type": "custom", - "source": "915", - "target": "6919", - "id": "xy-edge__915-6919", - "selected": False, - "data": { - "condition": "customer wants to talk to a sales representative", - "label": "sales representative", - "invalid": False, - "validationMessage": None, - }, - }, - { - "animated": True, - "type": "custom", - "source": "6581", - "target": "915", - "id": "xy-edge__6581-915", - "selected": False, - "data": { - "condition": "Always take this route", - "label": "Always take this route", - "invalid": False, - "validationMessage": None, - }, - }, - { - "animated": True, - "type": "custom", - "source": "7598", - "target": "1802", - "id": "xy-edge__7598-1802", - "selected": False, - "data": { - "condition": "end call", - "label": "end call", - "invalid": False, - "validationMessage": None, - }, - }, - { - "animated": True, - "type": "custom", - "source": "6919", - "target": "1802", - "id": "xy-edge__6919-1802", - "selected": False, - "data": { - "condition": "end call", - "label": "end call", - "invalid": False, - "validationMessage": None, - }, - }, - ], - "viewport": {"x": 0, "y": 0, "zoom": 1}, - } - - -class TestCreateWorkflow: - """Test cases for creating workflows.""" - - async def test_create_workflow_success( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test successful workflow creation.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_create_success" - ) - - request_data = { - "name": "Test Workflow", - "workflow_definition": sample_workflow_definition, - } - - async with test_client_factory(test_user) as client: - response = await client.post("/api/v1/workflow/create", json=request_data) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - - assert "id" in data - assert data["name"] == "Test Workflow" - assert data["workflow_definition"] == sample_workflow_definition - assert "created_at" in data - assert "current_definition_id" in data - - async def test_create_workflow_invalid_definition( - self, test_client_factory, db_session - ): - """Test workflow creation with invalid definition.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_invalid_def" - ) - - request_data = { - "name": "Invalid Workflow", - "workflow_definition": {"invalid": "structure"}, - } - - async with test_client_factory(test_user) as client: - response = await client.post("/api/v1/workflow/create", json=request_data) - - # The API should still create the workflow even with invalid definition - # Validation happens in the validate endpoint - assert response.status_code == status.HTTP_200_OK - - @pytest.mark.asyncio - async def test_create_workflow_missing_name( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test workflow creation without name.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_missing_name" - ) - - request_data = {"workflow_definition": sample_workflow_definition} - - async with test_client_factory(test_user) as client: - response = await client.post("/api/v1/workflow/create", json=request_data) - - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - @pytest.mark.asyncio - async def test_create_workflow_missing_definition( - self, test_client_factory, db_session - ): - """Test workflow creation without workflow definition.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_missing_definition" - ) - - request_data = {"name": "Test Workflow"} - - async with test_client_factory(test_user) as client: - response = await client.post("/api/v1/workflow/create", json=request_data) - - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - -class TestGetWorkflows: - """Test cases for fetching workflows.""" - - @pytest.mark.asyncio - async def test_get_all_workflows_empty(self, test_client_factory, db_session): - """Test getting all workflows when none exist.""" - # Create a test user within the test function - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_empty_workflows" - ) - - # Create a test client for this specific user - async with test_client_factory(test_user) as client: - response = await client.get("/api/v1/workflow/fetch") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert isinstance(data, list) - assert len(data) == 0 - - @pytest.mark.asyncio - async def test_get_all_workflows_with_data( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test getting all workflows when some exist.""" - # Create a test user within the test function - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_with_workflows" - ) - - # Create a test client for this specific user - async with test_client_factory(test_user) as client: - # Create a workflow first - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Test Workflow 1", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - - # Create another workflow - create_response2 = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Test Workflow 2", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response2.status_code == status.HTTP_200_OK - - # Get all workflows - response = await client.get("/api/v1/workflow/fetch") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert isinstance(data, list) - assert len(data) == 2 - - # Check that both workflows are returned - workflow_names = [w["name"] for w in data] - assert "Test Workflow 1" in workflow_names - assert "Test Workflow 2" in workflow_names - - @pytest.mark.asyncio - async def test_get_specific_workflow( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test getting a specific workflow by ID.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_specific_workflow" - ) - - async with test_client_factory(test_user) as client: - # Create a workflow first - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Specific Workflow", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - created_workflow = create_response.json() - workflow_id = created_workflow["id"] - - # Get the specific workflow - response = await client.get( - f"/api/v1/workflow/fetch?workflow_id={workflow_id}" - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - - assert data["id"] == workflow_id - assert data["name"] == "Specific Workflow" - assert data["workflow_definition"] == sample_workflow_definition - - @pytest.mark.asyncio - async def test_get_nonexistent_workflow(self, test_client_factory, db_session): - """Test getting a workflow that doesn't exist.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_nonexistent" - ) - - async with test_client_factory(test_user) as client: - response = await client.get("/api/v1/workflow/fetch?workflow_id=99999") - - assert response.status_code == status.HTTP_404_NOT_FOUND - assert "not found" in response.json()["detail"].lower() - - -class TestUpdateWorkflow: - """Test cases for updating workflows.""" - - @pytest.mark.asyncio - async def test_update_workflow_name_only( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test updating only the workflow name.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_update_name" - ) - - async with test_client_factory(test_user) as client: - # Create a workflow first - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Original Name", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - workflow_id = create_response.json()["id"] - - # Update the workflow name - update_data = {"name": "Updated Name"} - response = await client.put( - f"/api/v1/workflow/{workflow_id}", json=update_data - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - - assert data["id"] == workflow_id - assert data["name"] == "Updated Name" - assert ( - data["workflow_definition"] == sample_workflow_definition - ) # Should remain unchanged - - @pytest.mark.asyncio - async def test_update_workflow_name_and_definition( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test updating both workflow name and definition.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_update_both" - ) - - async with test_client_factory(test_user) as client: - # Create a workflow first - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Original Name", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - workflow_id = create_response.json()["id"] - - # Create new workflow definition - new_definition = { - "nodes": [ - { - "id": "start", - "type": "start", - "position": {"x": 50, "y": 50}, - "data": {"label": "New Start"}, - } - ], - "edges": [], - } - - # Update the workflow - update_data = { - "name": "Updated Name", - "workflow_definition": new_definition, - } - response = await client.put( - f"/api/v1/workflow/{workflow_id}", json=update_data - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - - assert data["id"] == workflow_id - assert data["name"] == "Updated Name" - assert data["workflow_definition"] == new_definition - - @pytest.mark.asyncio - async def test_update_nonexistent_workflow(self, test_client_factory, db_session): - """Test updating a workflow that doesn't exist.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_update_nonexistent" - ) - - update_data = {"name": "Updated Name"} - async with test_client_factory(test_user) as client: - response = await client.put("/api/v1/workflow/99999", json=update_data) - - assert response.status_code == status.HTTP_404_NOT_FOUND - assert "not found" in response.json()["detail"].lower() - - @pytest.mark.asyncio - async def test_update_workflow_missing_name( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test updating a workflow without providing a name.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_update_missing_name" - ) - - async with test_client_factory(test_user) as client: - # Create a workflow first - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Original Name", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - workflow_id = create_response.json()["id"] - - # Try to update without providing name - update_data = {"workflow_definition": sample_workflow_definition} - response = await client.put( - f"/api/v1/workflow/{workflow_id}", json=update_data - ) - - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - -class TestWorkflowValidation: - """Test cases for workflow validation endpoint.""" - - @pytest.mark.asyncio - async def test_validate_workflow_success( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test successful workflow validation.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_validate_success" - ) - - async with test_client_factory(test_user) as client: - # Create a workflow first - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Valid Workflow", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - workflow_id = create_response.json()["id"] - - # Validate the workflow - response = await client.post(f"/api/v1/workflow/{workflow_id}/validate") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - - assert data["is_valid"] is True - assert data["errors"] == [] - - @pytest.mark.asyncio - async def test_validate_nonexistent_workflow(self, test_client_factory, db_session): - """Test validating a workflow that doesn't exist.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_validate_nonexistent" - ) - - async with test_client_factory(test_user) as client: - response = await client.post("/api/v1/workflow/99999/validate") - - assert response.status_code == status.HTTP_404_NOT_FOUND - assert "not found" in response.json()["detail"].lower() - - -class TestWorkflowIntegration: - """Integration tests for workflow operations.""" - - @pytest.mark.asyncio - async def test_full_workflow_lifecycle( - self, test_client_factory, db_session, sample_workflow_definition - ): - """Test the complete lifecycle of a workflow: create, get, update, validate.""" - # Create a test user for this test - test_user = await db_session.get_or_create_user_by_provider_id( - "test_user_lifecycle" - ) - - async with test_client_factory(test_user) as client: - # 1. Create workflow - create_response = await client.post( - "/api/v1/workflow/create", - json={ - "name": "Lifecycle Test Workflow", - "workflow_definition": sample_workflow_definition, - }, - ) - assert create_response.status_code == status.HTTP_200_OK - workflow_id = create_response.json()["id"] - - # 2. Get the created workflow - get_response = await client.get( - f"/api/v1/workflow/fetch?workflow_id={workflow_id}" - ) - assert get_response.status_code == status.HTTP_200_OK - workflow_data = get_response.json() - assert workflow_data["name"] == "Lifecycle Test Workflow" - - # 3. Add a new node in the workflow definition - new_node = { - "id": "6919_new", - "type": "agentNode", - "position": {"x": 520, "y": 650}, - "data": { - "prompt": "Something new", - "name": "Agent", - "invalid": False, - "validationMessage": None, - "allow_interrupt": True, - }, - "measured": {"width": 300, "height": 100}, - "selected": False, - "dragging": False, - } - new_edges = [ - { - "source": "6919", - "target": "6919_new", - "id": "xy-edge__6919-6919_new", - "data": { - "condition": "Always take this route", - "label": "Always take this route", - "invalid": False, - "validationMessage": None, - }, - }, - { - "source": "6919_new", - "target": "1802", - "id": "xy-edge__6919_new-1802", - "data": { - "condition": "Always take this route", - "label": "Always take this route", - "invalid": False, - "validationMessage": None, - }, - }, - ] - new_definition = { - "nodes": [ - *sample_workflow_definition["nodes"], - new_node, - ], - "edges": [ - *sample_workflow_definition["edges"], - *new_edges, - ], - } - - update_response = await client.put( - f"/api/v1/workflow/{workflow_id}", - json={ - "name": "Updated Lifecycle Workflow", - "workflow_definition": new_definition, - }, - ) - assert update_response.status_code == status.HTTP_200_OK - assert update_response.json()["name"] == "Updated Lifecycle Workflow" - - # 4. Validate the updated workflow - validate_response = await client.post( - f"/api/v1/workflow/{workflow_id}/validate" - ) - assert validate_response.status_code == status.HTTP_200_OK - assert validate_response.json()["is_valid"] is True - - # 5. Verify the update by getting the workflow again - final_get_response = await client.get( - f"/api/v1/workflow/fetch?workflow_id={workflow_id}" - ) - assert final_get_response.status_code == status.HTTP_200_OK - final_data = final_get_response.json() - assert final_data["name"] == "Updated Lifecycle Workflow" - assert final_data["workflow_definition"] == new_definition diff --git a/api/utils/credential_auth.py b/api/utils/credential_auth.py new file mode 100644 index 0000000..963ec66 --- /dev/null +++ b/api/utils/credential_auth.py @@ -0,0 +1,95 @@ +"""Build HTTP authentication headers from ExternalCredentialModel. + +This module provides functions for constructing HTTP authentication headers +from ExternalCredentialModel instances. Used by both webhook integrations +and custom tool execution. +""" + +import base64 +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + from api.db.models import ExternalCredentialModel + + +def build_auth_header(credential: "ExternalCredentialModel") -> Dict[str, str]: + """Build authentication header based on credential type. + + Supports the following credential types: + - bearer_token: Authorization: Bearer + - api_key: Custom header with API key + - basic_auth: Authorization: Basic + - custom_header: Any custom header name/value pair + + Args: + credential: The ExternalCredentialModel instance + + Returns: + Dict with header name and value, or empty dict if credential type + is not recognized or is 'none' + """ + cred_type = credential.credential_type + cred_data = credential.credential_data or {} + + if cred_type == "bearer_token": + token = cred_data.get("token", "") + return {"Authorization": f"Bearer {token}"} + + elif cred_type == "api_key": + header_name = cred_data.get("header_name", "X-API-Key") + api_key = cred_data.get("api_key", "") + return {header_name: api_key} + + elif cred_type == "basic_auth": + username = cred_data.get("username", "") + password = cred_data.get("password", "") + encoded = base64.b64encode(f"{username}:{password}".encode()).decode() + return {"Authorization": f"Basic {encoded}"} + + elif cred_type == "custom_header": + header_name = cred_data.get("header_name", "X-Custom") + header_value = cred_data.get("header_value", "") + return {header_name: header_value} + + return {} + + +def build_auth_header_from_data( + credential_type: str, + credential_data: Optional[Dict[str, Any]] = None, +) -> Dict[str, str]: + """Build authentication header from raw credential data. + + This is a convenience function when you have credential data + directly rather than a full ExternalCredentialModel. + + Args: + credential_type: Type of credential (bearer_token, api_key, etc.) + credential_data: Dict containing credential-specific fields + + Returns: + Dict with header name and value + """ + cred_data = credential_data or {} + + if credential_type == "bearer_token": + token = cred_data.get("token", "") + return {"Authorization": f"Bearer {token}"} + + elif credential_type == "api_key": + header_name = cred_data.get("header_name", "X-API-Key") + api_key = cred_data.get("api_key", "") + return {header_name: api_key} + + elif credential_type == "basic_auth": + username = cred_data.get("username", "") + password = cred_data.get("password", "") + encoded = base64.b64encode(f"{username}:{password}".encode()).decode() + return {"Authorization": f"Basic {encoded}"} + + elif credential_type == "custom_header": + header_name = cred_data.get("header_name", "X-Custom") + header_value = cred_data.get("header_value", "") + return {header_name: header_value} + + return {} diff --git a/ui/src/app/tools/[toolUuid]/page.tsx b/ui/src/app/tools/[toolUuid]/page.tsx new file mode 100644 index 0000000..9e9698f --- /dev/null +++ b/ui/src/app/tools/[toolUuid]/page.tsx @@ -0,0 +1,498 @@ +"use client"; + +import { ArrowLeft, Code, Globe, Loader2, Save } from "lucide-react"; +import { useParams, useRouter } from "next/navigation"; +import { useCallback, useEffect, useState } from "react"; + +import { + getToolApiV1ToolsToolUuidGet, + updateToolApiV1ToolsToolUuidPut, +} from "@/client/sdk.gen"; +import type { ToolResponse } from "@/client/types.gen"; + +// Extended HttpApiConfig with parameters (until client types are regenerated) +interface HttpApiConfigWithParams { + method?: string; + url?: string; + headers?: Record; + credential_uuid?: string; + parameters?: ToolParameter[]; + timeout_ms?: number; +} +import { + CredentialSelector, + type HttpMethod, + HttpMethodSelector, + KeyValueEditor, + type KeyValueItem, + ParameterEditor, + type ToolParameter, +} from "@/components/http"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Textarea } from "@/components/ui/textarea"; +import { useAuth } from "@/lib/auth"; + +export default function ToolDetailPage() { + const { toolUuid } = useParams<{ toolUuid: string }>(); + const { user, getAccessToken, redirectToLogin, loading } = useAuth(); + const router = useRouter(); + + const [tool, setTool] = useState(null); + const [isLoading, setIsLoading] = useState(true); + const [isSaving, setIsSaving] = useState(false); + const [error, setError] = useState(null); + const [saveSuccess, setSaveSuccess] = useState(false); + const [showCodeDialog, setShowCodeDialog] = useState(false); + + // Form state + const [name, setName] = useState(""); + const [description, setDescription] = useState(""); + const [httpMethod, setHttpMethod] = useState("POST"); + const [url, setUrl] = useState(""); + const [credentialUuid, setCredentialUuid] = useState(""); + const [headers, setHeaders] = useState([]); + const [parameters, setParameters] = useState([]); + const [timeoutMs, setTimeoutMs] = useState(5000); + + // Redirect if not authenticated + useEffect(() => { + if (!loading && !user) { + redirectToLogin(); + } + }, [loading, user, redirectToLogin]); + + const fetchTool = useCallback(async () => { + if (loading || !user || !toolUuid) return; + + try { + setIsLoading(true); + setError(null); + const accessToken = await getAccessToken(); + + const response = await getToolApiV1ToolsToolUuidGet({ + path: { tool_uuid: toolUuid }, + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); + + if (response.data) { + setTool(response.data); + populateFormFromTool(response.data); + } + } catch (err) { + setError("Failed to fetch tool"); + console.error("Error fetching tool:", err); + } finally { + setIsLoading(false); + } + }, [loading, user, toolUuid, getAccessToken]); + + const populateFormFromTool = (tool: ToolResponse) => { + setName(tool.name); + setDescription(tool.description || ""); + + const config = tool.definition?.config as HttpApiConfigWithParams | undefined; + if (config) { + setHttpMethod((config.method as HttpMethod) || "POST"); + setUrl(config.url || ""); + setCredentialUuid(config.credential_uuid || ""); + setTimeoutMs(config.timeout_ms || 5000); + + // Convert headers object to array + if (config.headers) { + setHeaders( + Object.entries(config.headers).map(([key, value]) => ({ + key, + value: value as string, + })) + ); + } else { + setHeaders([]); + } + + // Load parameters + if (config.parameters && Array.isArray(config.parameters)) { + setParameters( + config.parameters.map((p: ToolParameter) => ({ + name: p.name || "", + type: p.type || "string", + description: p.description || "", + required: p.required ?? true, + })) + ); + } else { + setParameters([]); + } + } + }; + + useEffect(() => { + fetchTool(); + }, [fetchTool]); + + const handleSave = async () => { + // Validate URL + if (!url.trim()) { + setError("URL is required"); + return; + } + + // Validate parameters have names + const invalidParams = parameters.filter((p) => !p.name.trim()); + if (invalidParams.length > 0) { + setError("All parameters must have a name"); + return; + } + + try { + setIsSaving(true); + setError(null); + setSaveSuccess(false); + const accessToken = await getAccessToken(); + + // Convert headers array to object + const headersObject: Record = {}; + headers.filter((h) => h.key && h.value).forEach((h) => { + headersObject[h.key] = h.value; + }); + + // Filter out empty parameters + const validParameters = parameters.filter((p) => p.name.trim()); + + // Build the request body (cast needed until client types are regenerated) + const requestBody = { + name, + description: description || undefined, + definition: { + schema_version: 1, + type: "http_api", + config: { + method: httpMethod, + url, + credential_uuid: credentialUuid || undefined, + headers: + Object.keys(headersObject).length > 0 + ? headersObject + : undefined, + parameters: + validParameters.length > 0 ? validParameters : undefined, + timeout_ms: timeoutMs, + }, + }, + }; + + const response = await updateToolApiV1ToolsToolUuidPut({ + path: { tool_uuid: toolUuid }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + body: requestBody as any, + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); + + if (response.data) { + setTool(response.data); + setSaveSuccess(true); + setTimeout(() => setSaveSuccess(false), 3000); + } + } catch (err) { + setError("Failed to save tool"); + console.error("Error saving tool:", err); + } finally { + setIsSaving(false); + } + }; + + const getCodeSnippet = () => { + if (!tool) return ""; + + const headersObj: Record = { + "Content-Type": "application/json", + }; + headers.filter((h) => h.key && h.value).forEach((h) => { + headersObj[h.key] = h.value; + }); + + // Build example body from parameters + const exampleBody: Record = {}; + parameters.forEach((p) => { + if (p.type === "number") { + exampleBody[p.name] = 0; + } else if (p.type === "boolean") { + exampleBody[p.name] = true; + } else { + exampleBody[p.name] = `<${p.name}>`; + } + }); + + const hasBody = httpMethod !== "GET" && httpMethod !== "DELETE" && parameters.length > 0; + + return `// ${tool.name} +// ${tool.description || "HTTP API Tool"} + +const response = await fetch("${url}", { + method: "${httpMethod}", + headers: ${JSON.stringify(headersObj, null, 4)},${hasBody ? ` + body: JSON.stringify(${JSON.stringify(exampleBody, null, 4)}),` : ""} +}); + +const data = await response.json();`; + }; + + if (loading || !user) { + return ( +
+
+ + +
+
+ ); + } + + if (isLoading) { + return ( +
+
+
+ + +
+
+
+ ); + } + + if (!tool) { + return ( +
+
+
+

Tool not found

+ +
+
+
+ ); + } + + return ( +
+
+
+ {/* Header */} +
+
+ +
+
+ +
+
+

{name}

+

+ HTTP API Tool +

+
+
+
+
+ + +
+
+ + {error && ( +
+ {error} +
+ )} + + {saveSuccess && ( +
+ Tool saved successfully! +
+ )} + + + + Tool Configuration + + Configure the HTTP API endpoint and request settings + + + + + + Settings + Authentication + Parameters + + + +
+ + + setName(e.target.value)} + placeholder="e.g., Book Appointment" + /> +
+ +
+ + +