diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml new file mode 100644 index 0000000..6172450 --- /dev/null +++ b/.github/workflows/api-tests.yml @@ -0,0 +1,104 @@ +name: API tests + +on: + pull_request: + branches: [main] + paths: + - "api/**" + - "pipecat/**" + - "scripts/**" + - ".github/workflows/api-tests.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: ubuntu-latest + timeout-minutes: 30 + + services: + postgres: + # pgvector image: api/alembic/versions/dc33eef8dabe_add_document_tables.py + # runs CREATE EXTENSION vector, which plain postgres doesn't ship. + image: pgvector/pgvector:pg17 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + redis: + image: redis:7 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + # api/conftest.py creates test_db via the postgres admin DB. + DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db + REDIS_URL: redis://localhost:6379/0 + ENABLE_AWS_S3: "false" + # MINIO is not actually contacted by tests, but storage.py constructs + # the MinioFileSystem at import time and requires MINIO_PUBLIC_ENDPOINT. 
+ MINIO_PUBLIC_ENDPOINT: http://localhost:9000 + DEPLOYMENT_MODE: oss + ENVIRONMENT: test + LOG_LEVEL: DEBUG + PYTHONPATH: ${{ github.workspace }} + + steps: + - name: Checkout repo (with pipecat submodule) + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: | + api/requirements.txt + pipecat/pyproject.toml + + - name: Set up Node 22 (test_ts_bridge.py shells out to node) + uses: actions/setup-node@v4 + with: + node-version: "22" + + - name: Install api dependencies + run: | + pip install -r api/requirements.txt + pip install -r api/requirements.dev.txt + pip install './pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' + + - name: Install ts_validator npm deps + working-directory: api/mcp_server/ts_validator + run: npm install + + - name: Run pytest + working-directory: api + run: pytest tests/ -xvs + + - name: Send Slack notification - Failure + if: failure() + uses: slackapi/slack-github-action@v1.26.0 + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + with: + payload: | + { + "text": "❌ Dograh API tests failed on ${{ github.ref_name }} by ${{ github.actor }} - <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Logs>" + } diff --git a/.github/workflows/docs-openapi-drift.yml b/.github/workflows/docs-openapi-drift.yml index 40387b7..bd2a5ef 100644 --- a/.github/workflows/docs-openapi-drift.yml +++ b/.github/workflows/docs-openapi-drift.yml @@ -43,6 +43,7 @@ jobs: DATABASE_URL: postgresql+asyncpg://dummy:dummy@localhost/dummy REDIS_URL: redis://localhost:6379/0 ENABLE_AWS_S3: "false" + MINIO_PUBLIC_ENDPOINT: http://localhost:9000 DEPLOYMENT_MODE: oss run: python -u -m scripts.dump_docs_openapi @@ -54,3 +55,14 @@ jobs: exit 1 fi echo "OpenAPI spec is in sync." 
+ + - name: Send Slack notification - Failure + if: failure() + uses: slackapi/slack-github-action@v1.26.0 + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + with: + payload: | + { + "text": "❌ Dograh Docs OpenAPI drift check failed on ${{ github.ref_name }} by ${{ github.actor }} - <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Logs>" + } diff --git a/api/alembic.ini b/api/alembic.ini index b2cca73..37ac375 100644 --- a/api/alembic.ini +++ b/api/alembic.ini @@ -15,6 +15,13 @@ script_location = %(here)s/alembic # defaults to the current working directory. prepend_sys_path = . +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +path_separator = os + # timezone to use when rendering the date within the migration file # as well as the filename. # If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. diff --git a/api/alembic/env.py b/api/alembic/env.py index 697aa50..215894a 100644 --- a/api/alembic/env.py +++ b/api/alembic/env.py @@ -87,6 +87,7 @@ def do_run_migrations(connection): render_item=render_item, compare_type=True, compare_server_default=True, + transaction_per_migration=True, ) with context.begin_transaction(): context.run_migrations() diff --git a/api/conftest.py b/api/conftest.py index 33781ec..ab47760 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -77,44 +77,17 @@ from sqlalchemy.pool import NullPool def get_test_database_url() -> str: - """ - Get the test database URL by appending _test to the database name. 
- - Example: - postgresql+asyncpg://user:pass@host/mydb - -> postgresql+asyncpg://user:pass@host/mydb_test - """ - original_url = os.environ.get("DATABASE_URL") - if not original_url: + """Get the test database URL from the DATABASE_URL env var.""" + test_url = os.environ.get("DATABASE_URL") + if not test_url: raise ValueError("DATABASE_URL environment variable is not set") - - parsed = urlparse(original_url) - # Append _test to the database name (path without leading slash) - original_db_name = parsed.path.lstrip("/") - test_db_name = f"{original_db_name}_test" - - # Reconstruct the URL with the new database name - test_url = urlunparse( - ( - parsed.scheme, - parsed.netloc, - f"/{test_db_name}", - parsed.params, - parsed.query, - parsed.fragment, - ) - ) return test_url def get_base_database_url() -> str: - """ - Get base database URL (postgres) for creating/dropping test database. - """ - original_url = os.environ.get("DATABASE_URL") - parsed = urlparse(original_url) - # Connect to 'postgres' database for admin operations - base_url = urlunparse( + """Get base database URL (postgres) for creating/dropping test database.""" + parsed = urlparse(get_test_database_url()) + return urlunparse( ( parsed.scheme, parsed.netloc, @@ -124,15 +97,12 @@ def get_base_database_url() -> str: parsed.fragment, ) ) - return base_url def get_test_db_name() -> str: - """Extract the test database name.""" - original_url = os.environ.get("DATABASE_URL") - parsed = urlparse(original_url) - original_db_name = parsed.path.lstrip("/") - return f"{original_db_name}_test" + """Extract the test database name from DATABASE_URL.""" + parsed = urlparse(get_test_database_url()) + return parsed.path.lstrip("/") @pytest.fixture(scope="session") diff --git a/api/db/workflow_run_client.py b/api/db/workflow_run_client.py index 7e4bcf2..4a91dfb 100644 --- a/api/db/workflow_run_client.py +++ b/api/db/workflow_run_client.py @@ -238,6 +238,22 @@ class WorkflowRunClient(BaseDBClient): ) return 
result.scalars().first() + async def get_organization_id_by_workflow_run_id( + self, run_id: int | None + ) -> int | None: + """Resolve organization_id from a workflow run via workflow.user.""" + if not run_id: + return None + async with self.async_session() as session: + result = await session.execute( + select(WorkflowModel.organization_id) + .join( + WorkflowRunModel, WorkflowRunModel.workflow_id == WorkflowModel.id + ) + .where(WorkflowRunModel.id == run_id) + ) + return result.scalar_one_or_none() + async def get_workflow_runs_by_workflow_id( self, workflow_id: int, diff --git a/api/logging_config.py b/api/logging_config.py index bd4649f..3e4f443 100644 --- a/api/logging_config.py +++ b/api/logging_config.py @@ -3,6 +3,7 @@ import os import sys import loguru +from pipecat.utils.run_context import run_id_var from api.constants import ( ENVIRONMENT, @@ -15,7 +16,6 @@ from api.constants import ( ) from api.enums import Environment from api.utils.worker import get_worker_id, is_worker_process -from pipecat.utils.run_context import run_id_var # Track if logging has been initialized _logging_initialized = False diff --git a/api/routes/agent_stream.py b/api/routes/agent_stream.py index 2debe7a..b593a31 100644 --- a/api/routes/agent_stream.py +++ b/api/routes/agent_stream.py @@ -17,13 +17,13 @@ from typing import Optional from fastapi import APIRouter, WebSocket from loguru import logger +from pipecat.utils.run_context import set_current_org_id, set_current_run_id from starlette.websockets import WebSocketDisconnect from api.db import db_client from api.enums import CallType, WorkflowRunState from api.services.quota_service import check_dograh_quota_by_user_id from api.services.telephony import registry as telephony_registry -from pipecat.utils.run_context import set_current_org_id, set_current_run_id router = APIRouter(prefix="/agent-stream") diff --git a/api/routes/telephony.py b/api/routes/telephony.py index f72dcd5..a81694d 100644 --- a/api/routes/telephony.py +++ 
b/api/routes/telephony.py @@ -15,6 +15,7 @@ from fastapi import ( WebSocket, ) from loguru import logger +from pipecat.utils.run_context import set_current_run_id from pydantic import BaseModel, field_validator from starlette.websockets import WebSocketDisconnect @@ -44,7 +45,6 @@ from api.utils.telephony_helper import ( numbers_match, parse_webhook_request, ) -from pipecat.utils.run_context import set_current_run_id router = APIRouter(prefix="/telephony") diff --git a/api/routes/webrtc_signaling.py b/api/routes/webrtc_signaling.py index 04eee4b..4246b0b 100644 --- a/api/routes/webrtc_signaling.py +++ b/api/routes/webrtc_signaling.py @@ -24,6 +24,8 @@ from aiortc import RTCIceServer from aiortc.sdp import candidate_from_sdp from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect from loguru import logger +from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection +from pipecat.utils.run_context import set_current_org_id, set_current_run_id from starlette.websockets import WebSocketState from api.constants import ENVIRONMENT @@ -43,8 +45,6 @@ from api.services.pipecat.ws_sender_registry import ( unregister_ws_sender, ) from api.services.quota_service import check_dograh_quota -from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection -from pipecat.utils.run_context import set_current_org_id, set_current_run_id router = APIRouter(prefix="/ws") diff --git a/api/services/filesystem/__init__.py b/api/services/filesystem/__init__.py index fb716f3..170bfe1 100644 --- a/api/services/filesystem/__init__.py +++ b/api/services/filesystem/__init__.py @@ -1,9 +1,11 @@ from .base import BaseFileSystem from .minio import MinioFileSystem +from .null import NullFileSystem from .s3 import S3FileSystem __all__ = [ "BaseFileSystem", "S3FileSystem", "MinioFileSystem", + "NullFileSystem", ] diff --git a/api/services/filesystem/null.py b/api/services/filesystem/null.py new file mode 100644 index 0000000..e01c72b --- 
/dev/null +++ b/api/services/filesystem/null.py @@ -0,0 +1,50 @@ +from typing import Any, BinaryIO, Dict, NoReturn, Optional + +from .base import BaseFileSystem + + +class NullFileSystem(BaseFileSystem): + """No-op filesystem used when storage is not configured (e.g. tests). + + Every operation raises so that any test that exercises storage fails + loudly instead of silently succeeding against a stub. + """ + + def _fail(self, op: str) -> NoReturn: + raise RuntimeError( + f"NullFileSystem.{op} called — storage is not configured. " + "Set ENVIRONMENT to a non-test value or inject a real filesystem fixture." + ) + + async def acreate_file(self, file_path: str, content: BinaryIO) -> bool: + self._fail("acreate_file") + + async def aupload_file(self, local_path: str, destination_path: str) -> bool: + self._fail("aupload_file") + + async def aget_signed_url( + self, + file_path: str, + expiration: int = 3600, + force_inline: bool = False, + use_internal_endpoint: bool = False, + ) -> Optional[str]: + self._fail("aget_signed_url") + + async def aget_file_metadata(self, file_path: str) -> Optional[Dict[str, Any]]: + self._fail("aget_file_metadata") + + async def aget_presigned_put_url( + self, + file_path: str, + expiration: int = 900, + content_type: str = "text/csv", + max_size: int = 10_485_760, + ) -> Optional[str]: + self._fail("aget_presigned_put_url") + + async def adownload_file(self, source_path: str, local_path: str) -> bool: + self._fail("adownload_file") + + async def acopy_file(self, source_path: str, destination_path: str) -> bool: + self._fail("acopy_file") diff --git a/api/services/looptalk/audio_streamer.py b/api/services/looptalk/audio_streamer.py index 8221c4d..0acdb22 100644 --- a/api/services/looptalk/audio_streamer.py +++ b/api/services/looptalk/audio_streamer.py @@ -9,7 +9,6 @@ import asyncio from typing import Dict, Set from loguru import logger - from pipecat.audio.utils import mix_audio from pipecat.frames.frames import ( Frame, diff --git 
a/api/services/looptalk/core/pipeline_builder.py b/api/services/looptalk/core/pipeline_builder.py index ee11613..77e634f 100644 --- a/api/services/looptalk/core/pipeline_builder.py +++ b/api/services/looptalk/core/pipeline_builder.py @@ -3,6 +3,10 @@ from typing import Any, Dict from loguru import logger +from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, +) from api.db.db_client import DBClient from api.services.looptalk.audio_streamer import get_or_create_audio_streamer @@ -23,10 +27,6 @@ from api.services.pipecat.service_factory import ( from api.services.workflow.dto import ReactFlowDTO from api.services.workflow.pipecat_engine import PipecatEngine from api.services.workflow.workflow import WorkflowGraph -from pipecat.pipeline.pipeline import Pipeline -from pipecat.processors.aggregators.llm_response_universal import ( - LLMContextAggregatorPair, -) class LoopTalkPipelineBuilder: diff --git a/api/services/looptalk/internal_serializer.py b/api/services/looptalk/internal_serializer.py index 75d89ea..6a94a40 100644 --- a/api/services/looptalk/internal_serializer.py +++ b/api/services/looptalk/internal_serializer.py @@ -7,7 +7,6 @@ """Internal frame serializer for agent-to-agent communication.""" from loguru import logger - from pipecat.frames.frames import ( Frame, InputAudioRawFrame, diff --git a/api/services/looptalk/internal_transport.py b/api/services/looptalk/internal_transport.py index 756b499..00bfc93 100644 --- a/api/services/looptalk/internal_transport.py +++ b/api/services/looptalk/internal_transport.py @@ -11,8 +11,6 @@ import time from typing import Dict, Optional, Tuple from loguru import logger - -from api.services.looptalk.internal_serializer import InternalFrameSerializer from pipecat.frames.frames import ( CancelFrame, EndFrame, @@ -29,6 +27,8 @@ from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import 
BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams +from api.services.looptalk.internal_serializer import InternalFrameSerializer + class InternalInputTransport(BaseInputTransport): """Input side of internal transport for agent-to-agent communication.""" diff --git a/api/services/looptalk/orchestrator.py b/api/services/looptalk/orchestrator.py index d4c969a..d51e9da 100644 --- a/api/services/looptalk/orchestrator.py +++ b/api/services/looptalk/orchestrator.py @@ -6,6 +6,8 @@ from pathlib import Path from typing import Any, Dict, Optional from loguru import logger +from pipecat.pipeline.task import PipelineTask +from pipecat.utils.run_context import set_current_run_id from api.db.db_client import DBClient from api.services.looptalk.internal_transport import ( @@ -13,8 +15,6 @@ from api.services.looptalk.internal_transport import ( InternalTransportManager, ) from api.services.pipecat.transport_setup import create_internal_transport -from pipecat.pipeline.task import PipelineTask -from pipecat.utils.run_context import set_current_run_id from .core.pipeline_builder import LoopTalkPipelineBuilder from .core.recording_manager import RecordingManager diff --git a/api/services/pipecat/event_handlers.py b/api/services/pipecat/event_handlers.py index fe6c769..0ce639a 100644 --- a/api/services/pipecat/event_handlers.py +++ b/api/services/pipecat/event_handlers.py @@ -188,7 +188,12 @@ def register_event_handlers( await engine.llm.queue_frame(LLMContextFrame(engine.context)) else: logger.debug("Playing text greeting via TTS") - await task.queue_frame(TTSSpeakFrame(greeting_value)) + # append_to_context=True so the assistant aggregator commits + # the greeting to the LLM context once TTS finishes; without + # it the LLM would re-greet on its first generation. 
+ await task.queue_frame( + TTSSpeakFrame(greeting_value, append_to_context=True) + ) else: logger.debug( "Both pipeline_started and client_connected received - triggering initial LLM generation" diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 660efe3..d5618be 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -429,7 +429,6 @@ async def _run_pipeline( engine.set_audio_config(audio_config) assistant_params = LLMAssistantAggregatorParams( - expect_stripped_words=True, correct_aggregation_callback=engine.create_aggregation_correction_callback(), ) diff --git a/api/services/smart_turn/app.py b/api/services/smart_turn/app.py index 66ccf5a..6bbd0ab 100644 --- a/api/services/smart_turn/app.py +++ b/api/services/smart_turn/app.py @@ -21,9 +21,8 @@ from fastapi import ( status, ) from fastapi.websockets import WebSocketState -from scipy.io import wavfile - from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2 +from scipy.io import wavfile LOG_LEVEL = ( logging.DEBUG diff --git a/api/services/smart_turn/websocket_smart_turn.py b/api/services/smart_turn/websocket_smart_turn.py index 4220d3c..82a7e6f 100644 --- a/api/services/smart_turn/websocket_smart_turn.py +++ b/api/services/smart_turn/websocket_smart_turn.py @@ -20,7 +20,6 @@ from typing import Any, Dict, Optional import numpy as np import websockets from loguru import logger - from pipecat.audio.turn.smart_turn.base_smart_turn import ( BaseSmartTurn, SmartTurnTimeoutException, diff --git a/api/services/storage.py b/api/services/storage.py index 0f05073..b24310b 100644 --- a/api/services/storage.py +++ b/api/services/storage.py @@ -2,6 +2,7 @@ from loguru import logger from api.constants import ( ENABLE_AWS_S3, + ENVIRONMENT, MINIO_ACCESS_KEY, MINIO_BUCKET, MINIO_ENDPOINT, @@ -11,9 +12,9 @@ from api.constants import ( S3_BUCKET, S3_REGION, ) -from api.enums import StorageBackend +from api.enums 
import Environment, StorageBackend -from .filesystem import BaseFileSystem, MinioFileSystem, S3FileSystem +from .filesystem import BaseFileSystem, MinioFileSystem, NullFileSystem, S3FileSystem def get_storage_for_backend(backend: str) -> BaseFileSystem: @@ -73,12 +74,18 @@ def get_current_storage_backend() -> StorageBackend: return StorageBackend.get_current_backend() -# Create a single storage instance at module load time -_backend = StorageBackend.get_current_backend() -logger.info( - f"Initializing storage backend: {_backend.name} (value: {_backend.value}, ENABLE_AWS_S3={ENABLE_AWS_S3})" -) -storage_fs = get_storage_for_backend(_backend.value) +# Create a single storage instance at module load time. +# In the test environment we skip the real backend so import doesn't require +# MinIO/S3 to be reachable; tests that need storage must inject a real fs. +if ENVIRONMENT == Environment.TEST.value: + logger.info("ENVIRONMENT=test — using NullFileSystem (no storage backend)") + storage_fs: BaseFileSystem = NullFileSystem() +else: + _backend = StorageBackend.get_current_backend() + logger.info( + f"Initializing storage backend: {_backend.name} (value: {_backend.value}, ENABLE_AWS_S3={ENABLE_AWS_S3})" + ) + storage_fs = get_storage_for_backend(_backend.value) # For backward compatibility, keep get_storage() function diff --git a/api/services/telephony/providers/ari/strategies.py b/api/services/telephony/providers/ari/strategies.py index 4e02c8e..288110f 100644 --- a/api/services/telephony/providers/ari/strategies.py +++ b/api/services/telephony/providers/ari/strategies.py @@ -6,7 +6,6 @@ This module contains the business logic for Asterisk ARI call operations. 
from typing import Any, Dict from loguru import logger - from pipecat.serializers.call_strategies import HangupStrategy, TransferStrategy diff --git a/api/services/telephony/providers/ari/transport.py b/api/services/telephony/providers/ari/transport.py index cb35bbf..08f7aed 100644 --- a/api/services/telephony/providers/ari/transport.py +++ b/api/services/telephony/providers/ari/transport.py @@ -1,15 +1,15 @@ """ARI (Asterisk) transport factory.""" from fastapi import WebSocket - -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer +from api.services.telephony.factory import load_credentials_for_transport + from .serializers import AsteriskFrameSerializer from .strategies import ARIBridgeSwapStrategy, ARIHangupStrategy diff --git a/api/services/telephony/providers/cloudonix/routes.py b/api/services/telephony/providers/cloudonix/routes.py index 6f39831..cd4758a 100644 --- a/api/services/telephony/providers/cloudonix/routes.py +++ b/api/services/telephony/providers/cloudonix/routes.py @@ -8,6 +8,7 @@ import json from fastapi import APIRouter, Request from loguru import logger +from pipecat.utils.run_context import set_current_run_id from api.db import db_client from api.services.telephony.factory import get_telephony_provider_for_run @@ -15,7 +16,6 @@ from api.services.telephony.status_processor import ( StatusCallbackRequest, _process_status_update, ) -from pipecat.utils.run_context import set_current_run_id router = APIRouter() diff --git a/api/services/telephony/providers/cloudonix/strategies.py b/api/services/telephony/providers/cloudonix/strategies.py index 1fa4fad..b64cf6d 100644 --- 
a/api/services/telephony/providers/cloudonix/strategies.py +++ b/api/services/telephony/providers/cloudonix/strategies.py @@ -3,9 +3,9 @@ from typing import Any, Dict from loguru import logger +from pipecat.serializers.call_strategies import HangupStrategy from api.services.telephony.providers.cloudonix.provider import CLOUDONIX_API_BASE_URL -from pipecat.serializers.call_strategies import HangupStrategy class CloudonixHangupStrategy(HangupStrategy): diff --git a/api/services/telephony/providers/cloudonix/transport.py b/api/services/telephony/providers/cloudonix/transport.py index 4ad73bd..9e06d16 100644 --- a/api/services/telephony/providers/cloudonix/transport.py +++ b/api/services/telephony/providers/cloudonix/transport.py @@ -1,15 +1,15 @@ """Cloudonix transport factory.""" from fastapi import WebSocket - -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer +from api.services.telephony.factory import load_credentials_for_transport + from .serializers import CloudonixFrameSerializer from .strategies import CloudonixHangupStrategy diff --git a/api/services/telephony/providers/plivo/routes.py b/api/services/telephony/providers/plivo/routes.py index 6fad8c3..be1ecd7 100644 --- a/api/services/telephony/providers/plivo/routes.py +++ b/api/services/telephony/providers/plivo/routes.py @@ -9,6 +9,7 @@ from typing import Optional from fastapi import APIRouter, Header, Request from loguru import logger +from pipecat.utils.run_context import set_current_run_id from starlette.responses import HTMLResponse from api.db import db_client @@ -18,7 +19,6 @@ from api.services.telephony.status_processor 
import ( _process_status_update, ) from api.utils.common import get_backend_endpoints -from pipecat.utils.run_context import set_current_run_id router = APIRouter() diff --git a/api/services/telephony/providers/plivo/transport.py b/api/services/telephony/providers/plivo/transport.py index cd765a2..4a83eb2 100644 --- a/api/services/telephony/providers/plivo/transport.py +++ b/api/services/telephony/providers/plivo/transport.py @@ -1,15 +1,15 @@ """Plivo transport factory.""" from fastapi import WebSocket - -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer +from api.services.telephony.factory import load_credentials_for_transport + from .serializers import PlivoFrameSerializer diff --git a/api/services/telephony/providers/telnyx/routes.py b/api/services/telephony/providers/telnyx/routes.py index 23df07e..0947b14 100644 --- a/api/services/telephony/providers/telnyx/routes.py +++ b/api/services/telephony/providers/telnyx/routes.py @@ -8,6 +8,7 @@ import json from fastapi import APIRouter, Request from loguru import logger +from pipecat.utils.run_context import set_current_run_id from api.db import db_client from api.services.telephony.factory import get_telephony_provider_for_run @@ -16,7 +17,6 @@ from api.services.telephony.status_processor import ( StatusCallbackRequest, _process_status_update, ) -from pipecat.utils.run_context import set_current_run_id router = APIRouter() diff --git a/api/services/telephony/providers/telnyx/transport.py b/api/services/telephony/providers/telnyx/transport.py index f47d0c3..a47102e 100644 --- a/api/services/telephony/providers/telnyx/transport.py +++ 
b/api/services/telephony/providers/telnyx/transport.py @@ -1,15 +1,15 @@ """Telnyx transport factory.""" from fastapi import WebSocket - -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer +from api.services.telephony.factory import load_credentials_for_transport + from .serializers import TelnyxFrameSerializer diff --git a/api/services/telephony/providers/twilio/routes.py b/api/services/telephony/providers/twilio/routes.py index 11fca1b..e8ac939 100644 --- a/api/services/telephony/providers/twilio/routes.py +++ b/api/services/telephony/providers/twilio/routes.py @@ -9,6 +9,7 @@ from typing import Optional from fastapi import APIRouter, Header, Request from loguru import logger +from pipecat.utils.run_context import set_current_run_id from starlette.responses import HTMLResponse from api.db import db_client @@ -18,7 +19,6 @@ from api.services.telephony.status_processor import ( _process_status_update, ) from api.utils.common import get_backend_endpoints -from pipecat.utils.run_context import set_current_run_id router = APIRouter() diff --git a/api/services/telephony/providers/twilio/strategies.py b/api/services/telephony/providers/twilio/strategies.py index 003eb33..e80e1a6 100644 --- a/api/services/telephony/providers/twilio/strategies.py +++ b/api/services/telephony/providers/twilio/strategies.py @@ -8,7 +8,6 @@ from typing import Any, Dict import aiohttp from loguru import logger - from pipecat.serializers.call_strategies import HangupStrategy, TransferStrategy diff --git a/api/services/telephony/providers/twilio/transport.py b/api/services/telephony/providers/twilio/transport.py index 
e363f26..7d3ea2b 100644 --- a/api/services/telephony/providers/twilio/transport.py +++ b/api/services/telephony/providers/twilio/transport.py @@ -1,15 +1,15 @@ """Twilio transport factory.""" from fastapi import WebSocket - -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer +from api.services.telephony.factory import load_credentials_for_transport + from .serializers import TwilioFrameSerializer from .strategies import TwilioConferenceStrategy, TwilioHangupStrategy diff --git a/api/services/telephony/providers/vobiz/routes.py b/api/services/telephony/providers/vobiz/routes.py index d39946c..4fffe5b 100644 --- a/api/services/telephony/providers/vobiz/routes.py +++ b/api/services/telephony/providers/vobiz/routes.py @@ -10,6 +10,7 @@ from typing import Optional from fastapi import APIRouter, Header, Request from loguru import logger +from pipecat.utils.run_context import set_current_run_id from starlette.responses import HTMLResponse from api.db import db_client @@ -24,7 +25,6 @@ from api.utils.common import get_backend_endpoints from api.utils.telephony_helper import ( parse_webhook_request, ) -from pipecat.utils.run_context import set_current_run_id router = APIRouter() diff --git a/api/services/telephony/providers/vobiz/transport.py b/api/services/telephony/providers/vobiz/transport.py index 2a2a042..44c3ccb 100644 --- a/api/services/telephony/providers/vobiz/transport.py +++ b/api/services/telephony/providers/vobiz/transport.py @@ -7,15 +7,15 @@ Vobiz uses Plivo-compatible WebSocket protocol: from fastapi import WebSocket from loguru import logger - -from api.services.pipecat.audio_config 
import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer +from api.services.telephony.factory import load_credentials_for_transport + from .serializers import VobizFrameSerializer diff --git a/api/services/telephony/providers/vonage/routes.py b/api/services/telephony/providers/vonage/routes.py index dff1bba..a4cca35 100644 --- a/api/services/telephony/providers/vonage/routes.py +++ b/api/services/telephony/providers/vonage/routes.py @@ -9,6 +9,7 @@ from typing import Optional from fastapi import APIRouter, Request from loguru import logger +from pipecat.utils.run_context import set_current_run_id from api.db import db_client from api.services.telephony.factory import get_telephony_provider_for_run @@ -16,7 +17,6 @@ from api.services.telephony.status_processor import ( StatusCallbackRequest, _process_status_update, ) -from pipecat.utils.run_context import set_current_run_id router = APIRouter() diff --git a/api/services/telephony/providers/vonage/transport.py b/api/services/telephony/providers/vonage/transport.py index 8e895c3..dc3397a 100644 --- a/api/services/telephony/providers/vonage/transport.py +++ b/api/services/telephony/providers/vonage/transport.py @@ -1,13 +1,14 @@ """Vonage transport factory.""" -from api.services.pipecat.audio_config import AudioConfig -from api.services.pipecat.audio_mixer import build_audio_out_mixer -from api.services.telephony.factory import load_credentials_for_transport from pipecat.transports.websocket.fastapi import ( FastAPIWebsocketParams, FastAPIWebsocketTransport, ) +from api.services.pipecat.audio_config import AudioConfig +from api.services.pipecat.audio_mixer import build_audio_out_mixer 
+from api.services.telephony.factory import load_credentials_for_transport + from .serializers import VonageFrameSerializer diff --git a/api/services/workflow/disposition_mapper.py b/api/services/workflow/disposition_mapper.py index f58abc4..f26c015 100644 --- a/api/services/workflow/disposition_mapper.py +++ b/api/services/workflow/disposition_mapper.py @@ -1,14 +1,12 @@ """Utility module for applying disposition code mapping.""" -from typing import Optional - from loguru import logger from api.db import db_client from api.enums import OrganizationConfigurationKey -async def apply_disposition_mapping(value: str, organization_id: Optional[int]) -> str: +async def apply_disposition_mapping(value: str, organization_id: int | None) -> str: """Apply disposition code mapping if configured. Args: @@ -46,32 +44,3 @@ async def apply_disposition_mapping(value: str, organization_id: Optional[int]) except Exception as e: logger.error(f"Error applying disposition mapping: {e}") return value - - -async def get_organization_id_from_workflow_run( - workflow_run_id: Optional[int], -) -> Optional[int]: - """Get organization_id from workflow_run_id through the model relationships. 
- - Args: - workflow_run_id: The workflow run ID - - Returns: - The organization ID if found, otherwise None - """ - if not workflow_run_id: - return None - - try: - workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) - if not workflow_run or not workflow_run.workflow: - return None - - workflow = workflow_run.workflow - if not workflow.user: - return None - - return workflow.user.selected_organization_id - except Exception as e: - logger.error(f"Error getting organization_id from workflow_run: {e}") - return None diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index fa3ddc7..b4f00cb 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -1,11 +1,5 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union -from api.services.pipecat.audio_playback import play_audio -from api.services.workflow.disposition_mapper import ( - apply_disposition_mapping, - get_organization_id_from_workflow_run, -) -from api.services.workflow.workflow import Node, WorkflowGraph from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.frames.frames import ( BotStartedSpeakingFrame, @@ -21,6 +15,11 @@ from pipecat.services.llm_service import FunctionCallParams from pipecat.services.settings import LLMSettings from pipecat.utils.enums import EndTaskReason +from api.db import db_client +from api.services.pipecat.audio_playback import play_audio +from api.services.workflow.disposition_mapper import apply_disposition_mapping +from api.services.workflow.workflow import Node, WorkflowGraph + if TYPE_CHECKING: from pipecat.frames.frames import Frame from pipecat.services.anthropic.llm import AnthropicLLMService @@ -114,6 +113,9 @@ class PipecatEngine: # Custom tool manager (initialized in initialize()) self._custom_tool_manager: Optional[CustomToolManager] = None + # Cached organization ID (resolved lazily from workflow run) + self._organization_id: 
Optional[int] = None + # Embeddings configuration (passed from run_pipeline.py) self._embeddings_api_key: Optional[str] = embeddings_api_key self._embeddings_model: Optional[str] = embeddings_model @@ -141,10 +143,13 @@ class PipecatEngine: async def _get_organization_id(self) -> Optional[int]: """Get and cache the organization ID from workflow run.""" - if self._custom_tool_manager: - return await self._custom_tool_manager.get_organization_id() - # Fallback for when manager is not yet initialized - return await get_organization_id_from_workflow_run(self._workflow_run_id) + if self._organization_id is None: + self._organization_id = ( + await db_client.get_organization_id_by_workflow_run_id( + self._workflow_run_id + ) + ) + return self._organization_id def _get_otel_context(self): """Extract the OTel Context from the task's TracingContext. @@ -324,11 +329,7 @@ class PipecatEngine: ) # Register function with LLM - self.llm.register_function( - name, - transition_func, - cancel_on_interruption=False, - ) + self.llm.register_function(name, transition_func) async def _register_knowledge_base_function( self, document_uuids: list[str] diff --git a/api/services/workflow/pipecat_engine_callbacks.py b/api/services/workflow/pipecat_engine_callbacks.py index 83990bf..87ff06e 100644 --- a/api/services/workflow/pipecat_engine_callbacks.py +++ b/api/services/workflow/pipecat_engine_callbacks.py @@ -14,7 +14,6 @@ import re from typing import TYPE_CHECKING from loguru import logger - from pipecat.frames.frames import ( LLMMessagesAppendFrame, ) diff --git a/api/services/workflow/pipecat_engine_context_summarizer.py b/api/services/workflow/pipecat_engine_context_summarizer.py index 1ea9f47..abfa7b2 100644 --- a/api/services/workflow/pipecat_engine_context_summarizer.py +++ b/api/services/workflow/pipecat_engine_context_summarizer.py @@ -6,8 +6,6 @@ from typing import TYPE_CHECKING, Optional from loguru import logger from opentelemetry import trace - -from 
api.services.pipecat.tracing_config import ensure_tracing from pipecat.frames.frames import LLMContextSummaryRequestFrame from pipecat.utils.context.llm_context_summarization import ( LLMContextSummarizationUtil, @@ -15,6 +13,8 @@ from pipecat.utils.context.llm_context_summarization import ( ) from pipecat.utils.tracing.service_attributes import add_llm_span_attributes +from api.services.pipecat.tracing_config import ensure_tracing + if TYPE_CHECKING: from api.services.workflow.pipecat_engine import PipecatEngine diff --git a/api/services/workflow/pipecat_engine_custom_tools.py b/api/services/workflow/pipecat_engine_custom_tools.py index e9c4a75..f4b5a2b 100644 --- a/api/services/workflow/pipecat_engine_custom_tools.py +++ b/api/services/workflow/pipecat_engine_custom_tools.py @@ -13,21 +13,6 @@ import uuid from typing import TYPE_CHECKING, Any, Dict, List, Optional from loguru import logger - -from api.db import db_client -from api.enums import ToolCategory, WorkflowRunMode -from api.services.pipecat.audio_playback import play_audio, play_audio_loop -from api.services.telephony.call_transfer_manager import get_call_transfer_manager -from api.services.telephony.factory import get_telephony_provider -from api.services.telephony.transfer_event_protocol import TransferContext -from api.services.workflow.disposition_mapper import ( - get_organization_id_from_workflow_run, -) -from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator -from api.services.workflow.tools.custom_tool import ( - execute_http_tool, - tool_to_function_schema, -) from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.frames.frames import ( FunctionCallResultProperties, @@ -36,6 +21,18 @@ from pipecat.frames.frames import ( from pipecat.services.llm_service import FunctionCallParams from pipecat.utils.enums import EndTaskReason +from api.db import db_client +from api.enums import ToolCategory, WorkflowRunMode +from 
api.services.pipecat.audio_playback import play_audio, play_audio_loop +from api.services.telephony.call_transfer_manager import get_call_transfer_manager +from api.services.telephony.factory import get_telephony_provider +from api.services.telephony.transfer_event_protocol import TransferContext +from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator +from api.services.workflow.tools.custom_tool import ( + execute_http_tool, + tool_to_function_schema, +) + if TYPE_CHECKING: from api.services.workflow.pipecat_engine import PipecatEngine @@ -75,7 +72,6 @@ class CustomToolManager: def __init__(self, engine: "PipecatEngine") -> None: self._engine = engine - self._organization_id: Optional[int] = None async def _play_config_message( self, config: dict, *, append_to_context: bool = False @@ -122,12 +118,8 @@ class CustomToolManager: return False async def get_organization_id(self) -> Optional[int]: - """Get and cache the organization ID from workflow run.""" - if self._organization_id is None: - self._organization_id = await get_organization_id_from_workflow_run( - self._engine._workflow_run_id - ) - return self._organization_id + """Get the organization ID from the engine (shared cache).""" + return await self._engine._get_organization_id() async def get_tool_schemas(self, tool_uuids: list[str]) -> list[FunctionSchema]: """Fetch custom tools and convert them to function schemas. 
@@ -215,13 +207,10 @@ class CustomToolManager: function_name = schema["function"]["name"] # Create and register the handler - handler, timeout_secs, cancel_on_interruption = self._create_handler( - tool, function_name - ) + handler, timeout_secs = self._create_handler(tool, function_name) self._engine.llm.register_function( function_name, handler, - cancel_on_interruption=cancel_on_interruption, timeout_secs=timeout_secs, ) @@ -244,19 +233,16 @@ class CustomToolManager: Async handler function for the tool """ timeout_secs: Optional[float] = None - cancel_on_interruption = True if tool.category == ToolCategory.END_CALL.value: - cancel_on_interruption = False handler = self._create_end_call_handler(tool, function_name) elif tool.category == ToolCategory.TRANSFER_CALL.value: timeout_secs = 120.0 - cancel_on_interruption = False handler = self._create_transfer_call_handler(tool, function_name) else: handler = self._create_http_tool_handler(tool, function_name) - return handler, timeout_secs, cancel_on_interruption + return handler, timeout_secs def _register_calculator_handler(self) -> None: """Register the built-in calculator function with the LLM.""" @@ -335,7 +321,7 @@ class CustomToolManager: tool=tool, arguments=function_call_params.arguments, call_context_vars=self._engine._call_context_vars, - organization_id=self._organization_id, + organization_id=await self.get_organization_id(), ) await function_call_params.result_callback(result) diff --git a/api/services/workflow/pipecat_engine_variable_extractor.py b/api/services/workflow/pipecat_engine_variable_extractor.py index 53996cd..2853403 100644 --- a/api/services/workflow/pipecat_engine_variable_extractor.py +++ b/api/services/workflow/pipecat_engine_variable_extractor.py @@ -5,12 +5,12 @@ from typing import TYPE_CHECKING, Any, List from loguru import logger from opentelemetry import trace +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.utils.tracing.service_attributes import 
add_llm_span_attributes from api.services.gen_ai.json_parser import parse_llm_json from api.services.pipecat.tracing_config import ensure_tracing from api.services.workflow.dto import ExtractionVariableDTO -from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.utils.tracing.service_attributes import add_llm_span_attributes if TYPE_CHECKING: from api.services.workflow.pipecat_engine import PipecatEngine diff --git a/api/services/workflow/qa/analysis.py b/api/services/workflow/qa/analysis.py index b0a171e..0afb2e1 100644 --- a/api/services/workflow/qa/analysis.py +++ b/api/services/workflow/qa/analysis.py @@ -4,6 +4,7 @@ import json from typing import Any from loguru import logger +from pipecat.processors.aggregators.llm_context import LLMContext from api.db.models import WorkflowRunModel from api.services.gen_ai.json_parser import parse_llm_json @@ -26,7 +27,6 @@ from api.services.workflow.qa.tracing import ( setup_langfuse_parent_context, ) from api.utils.template_renderer import render_template -from pipecat.processors.aggregators.llm_context import LLMContext async def _run_llm_inference( diff --git a/api/services/workflow/qa/node_summary.py b/api/services/workflow/qa/node_summary.py index 5896f4c..aaeb7d3 100644 --- a/api/services/workflow/qa/node_summary.py +++ b/api/services/workflow/qa/node_summary.py @@ -3,6 +3,7 @@ from typing import Any from loguru import logger +from pipecat.processors.aggregators.llm_context import LLMContext from api.db import db_client from api.db.models import WorkflowRunModel @@ -10,7 +11,6 @@ from api.services.pipecat.service_factory import create_llm_service_from_provide from api.services.workflow.dto import NodeType, QANodeData from api.services.workflow.qa.llm_config import resolve_llm_config from api.services.workflow.qa.tracing import create_node_summary_trace -from pipecat.processors.aggregators.llm_context import LLMContext NODE_SUMMARY_SYSTEM_PROMPT = ( "You are analyzing a voice AI agent script. 
This is only a part of a larger script. " diff --git a/api/services/workflow/qa/tracing.py b/api/services/workflow/qa/tracing.py index d3a5ff1..58a0843 100644 --- a/api/services/workflow/qa/tracing.py +++ b/api/services/workflow/qa/tracing.py @@ -78,7 +78,6 @@ def add_qa_span_to_trace( return try: from opentelemetry import trace as otel_trace - from pipecat.utils.tracing.service_attributes import add_llm_span_attributes tracer = otel_trace.get_tracer("pipecat") @@ -122,9 +121,9 @@ def create_node_summary_trace( try: from opentelemetry import trace as otel_trace from opentelemetry.context import Context + from pipecat.utils.tracing.service_attributes import add_llm_span_attributes from api.services.pipecat.tracing_config import ensure_tracing - from pipecat.utils.tracing.service_attributes import add_llm_span_attributes if not ensure_tracing(): return None diff --git a/api/tasks/run_integrations.py b/api/tasks/run_integrations.py index da87413..d73a39b 100644 --- a/api/tasks/run_integrations.py +++ b/api/tasks/run_integrations.py @@ -5,6 +5,8 @@ from typing import Any, Dict, Optional import httpx from loguru import logger +from pipecat.utils.enums import EndTaskReason +from pipecat.utils.run_context import set_current_org_id, set_current_run_id from pydantic import ValidationError from api.constants import BACKEND_API_ENDPOINT @@ -21,8 +23,6 @@ from api.services.workflow.dto import ( from api.services.workflow.qa import run_per_node_qa_analysis from api.utils.credential_auth import build_auth_header from api.utils.template_renderer import render_template -from pipecat.utils.enums import EndTaskReason -from pipecat.utils.run_context import set_current_org_id, set_current_run_id def _should_skip_qa( diff --git a/api/tasks/s3_upload.py b/api/tasks/s3_upload.py index c18a9b3..b2086c0 100644 --- a/api/tasks/s3_upload.py +++ b/api/tasks/s3_upload.py @@ -2,12 +2,12 @@ import os from typing import Optional from loguru import logger +from pipecat.utils.run_context import 
set_current_run_id from api.db import db_client from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost from api.services.storage import get_current_storage_backend, storage_fs from api.tasks.run_integrations import run_integrations_post_workflow_run -from pipecat.utils.run_context import set_current_run_id async def upload_voicemail_audio_to_s3( diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 0056ff2..6cf047d 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -9,7 +9,7 @@ the root api/conftest.py. This module provides lightweight, non-DB fixtures: from dataclasses import dataclass, field from typing import Any, Dict, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -123,13 +123,28 @@ class MockToolModel: @pytest.fixture def mock_engine(): - """Create a mock PipecatEngine.""" + """Create a mock PipecatEngine. + + Binds the real `_get_organization_id` method so the fetch + cache logic + runs against a patched `db_client.get_organization_id_by_workflow_run_id` + (returns org_id=1) for the duration of the fixture. 
+ """ + from api.services.workflow.pipecat_engine import PipecatEngine + engine = Mock() engine._workflow_run_id = 1 engine._call_context_vars = {"customer_name": "John Doe"} + engine._organization_id = None + engine._get_organization_id = PipecatEngine._get_organization_id.__get__(engine) engine.llm = Mock() engine.llm.register_function = Mock() - return engine + + with patch( + "api.db:db_client.get_organization_id_by_workflow_run_id", + new_callable=AsyncMock, + return_value=1, + ): + yield engine @pytest.fixture diff --git a/api/tests/integrations/__init__.py b/api/tests/integrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/integrations/_run_pipeline_helpers.py b/api/tests/integrations/_run_pipeline_helpers.py new file mode 100644 index 0000000..0591c09 --- /dev/null +++ b/api/tests/integrations/_run_pipeline_helpers.py @@ -0,0 +1,249 @@ +"""Shared scaffolding for ``_run_pipeline`` integration tests. + +Both ``test_run_pipeline.py`` and ``test_run_pipeline_text_greeting.py`` +drive the real ``_run_pipeline`` end-to-end with the same set of external +boundaries patched out (STT/LLM/TTS factories, S3 recording fetcher, +PostHog publisher, ARQ enqueuer, real-time feedback observer). This +module centralises that scaffolding so each test only declares the bits +that differ — its workflow definition and any preconfigured mocks. + +Provided here: + +- ``USER_CONFIGURATION``: a minimal user-configuration dict with valid + provider/model values; the keys themselves are dummy. +- ``PassthroughProcessor``: an STT stand-in that forwards frames as-is. +- ``NoopFeedbackObserver``: a ``RealtimeFeedbackObserver`` stand-in with + no WebSocket / clock-task side effects. +- ``patch_run_pipeline_externals``: ``contextmanager`` that applies the + full patch set and captures the constructed ``PipelineTask`` for the + caller. 
Optional ``llm`` / ``tts`` arguments inject preconfigured + mocks; otherwise blank ``MockLLMService`` / ``MockTTSService`` + instances are constructed per-call. +- ``create_workflow_run_rows``: helper that creates the org / user / + user-configuration / workflow / workflow-run rows for an integration + test. Each test wires this through its own thin fixture so the + workflow definition stays local to the test. +""" + +from contextlib import ExitStack, contextmanager +from typing import Any +from unittest.mock import AsyncMock, patch + +from pipecat.frames.frames import Frame +from pipecat.observers.base_observer import BaseObserver +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from api.db.models import OrganizationModel, UserModel +from api.enums import WorkflowRunMode +from pipecat.tests import MockLLMService, MockTTSService + +USER_CONFIGURATION: dict[str, Any] = { + "is_realtime": False, + "stt": { + "provider": "deepgram", + "model": "nova-3", + "api_key": "test-key", + }, + "tts": { + "provider": "cartesia", + "model": "sonic-2", + "api_key": "test-key", + "voice_id": "test-voice", + }, + "llm": { + "provider": "openai", + "model": "gpt-4.1", + "api_key": "test-key", + }, +} + + +class PassthroughProcessor(FrameProcessor): + """Stand-in for the STT processor: forwards every frame untouched.""" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + await self.push_frame(frame, direction) + + +class NoopFeedbackObserver(BaseObserver): + """Stand-in for ``RealtimeFeedbackObserver``: no WS / no clock task.""" + + def __init__(self, *_args, **_kwargs): + super().__init__() + + async def cleanup(self): + pass + + +@contextmanager +def patch_run_pipeline_externals( + captured_task: list, + *, + llm: MockLLMService | None = None, + tts: MockTTSService | None = None, +): + """Patch the externally-talking pieces of ``_run_pipeline`` and capture + the constructed 
``PipelineTask`` so tests can drive it from outside. + + Args: + captured_task: A list the constructed ``PipelineTask`` is appended + to. Tests read ``captured_task[0]`` to get a handle on the task + (to wait on its start event, queue frames, cancel it, etc.). + llm: Optional pre-built ``MockLLMService``. When given, every call + to ``create_llm_service`` returns this same instance (so the + test can inspect its ``mock_steps`` / ``current_step``). + When ``None``, a blank ``MockLLMService`` is constructed. + tts: Optional pre-built ``MockTTSService``. Same semantics as + ``llm``: pass an instance to share state with the test, or + ``None`` to use a fresh one. + """ + from api.services.pipecat import pipeline_builder as _pipeline_builder + + original_create_task = _pipeline_builder.create_pipeline_task + + def _capture_task(*args, **kwargs): + task = original_create_task(*args, **kwargs) + captured_task.append(task) + return task + + def _llm_factory(*_args, **_kwargs): + return llm if llm is not None else MockLLMService(api_key="test") + + def _tts_factory(*_args, **_kwargs): + return tts if tts is not None else MockTTSService() + + with ExitStack() as stack: + # Replace service factories with in-process test doubles. + stack.enter_context( + patch( + "api.services.pipecat.run_pipeline.create_llm_service", + _llm_factory, + ) + ) + stack.enter_context( + patch( + "api.services.pipecat.run_pipeline.create_stt_service", + lambda *_args, **_kwargs: PassthroughProcessor(), + ) + ) + stack.enter_context( + patch( + "api.services.pipecat.run_pipeline.create_tts_service", + _tts_factory, + ) + ) + # S3 — the recording fetcher would otherwise resolve org-scoped recordings. + stack.enter_context( + patch( + "api.services.pipecat.run_pipeline.create_recording_audio_fetcher", + lambda *_args, **_kwargs: AsyncMock(return_value=None), + ) + ) + # External fire-and-forget integrations. 
+ stack.enter_context( + patch( + "api.services.pipecat.event_handlers._capture_call_event", + new=AsyncMock(), + ) + ) + stack.enter_context( + patch( + "api.services.pipecat.event_handlers.enqueue_job", + new=AsyncMock(), + ) + ) + # Skip the real-time feedback observer (WebSocket / log-buffer streaming). + stack.enter_context( + patch( + "api.services.pipecat.run_pipeline.RealtimeFeedbackObserver", + NoopFeedbackObserver, + ) + ) + # Stub the disposition mapper so it always returns "completed". NOTE(review): the real mapper appears to read org configuration via db_client rather than call an LLM; confirm. + stack.enter_context( + patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ) + ) + # Capture the PipelineTask so the test can drive it from outside. + stack.enter_context( + patch( + "api.services.pipecat.run_pipeline.create_pipeline_task", + side_effect=_capture_task, + ) + ) + yield + + +async def create_workflow_run_rows( + db_session, + async_session, + *, + workflow_definition: dict, + name_prefix: str, + provider_id_suffix: str, +): + """Create org / user / user-configuration / workflow / workflow-run rows + in the test database for a ``_run_pipeline`` integration test. + + Args: + db_session: The patched ``DBClient`` from the ``db_session`` fixture. + async_session: The raw ``AsyncSession`` from the ``async_session`` + fixture (used to add the org/user rows directly). + workflow_definition: The dict that becomes + ``WorkflowModel.workflow_definition`` and the V1 workflow_json. + name_prefix: Used to build human-readable workflow / run names. + provider_id_suffix: Used to generate unique ``provider_id`` values + for the org and user rows so concurrent or repeated test runs + don't collide. + + Returns: + Tuple of (workflow_run, user, workflow). 
+ """ + from api.schemas.user_configuration import UserConfiguration + + org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}") + async_session.add(org) + await async_session.flush() + + user = UserModel( + provider_id=f"test-user-{provider_id_suffix}", + selected_organization_id=org.id, + ) + async_session.add(user) + await async_session.flush() + + await db_session.update_user_configuration( + user_id=user.id, + configuration=UserConfiguration.model_validate(USER_CONFIGURATION), + ) + + workflow = await db_session.create_workflow( + name=f"{name_prefix} Workflow", + workflow_definition=workflow_definition, + user_id=user.id, + organization_id=org.id, + ) + + workflow_run = await db_session.create_workflow_run( + name=f"{name_prefix} Run", + workflow_id=workflow.id, + mode=WorkflowRunMode.SMALLWEBRTC.value, + user_id=user.id, + ) + + return workflow_run, user, workflow + + +# Keep the module's public surface explicit so ``import *`` doesn't grab +# transitive imports. +__all__ = [ + "USER_CONFIGURATION", + "PassthroughProcessor", + "NoopFeedbackObserver", + "patch_run_pipeline_externals", + "create_workflow_run_rows", +] diff --git a/api/tests/integrations/test_run_pipeline.py b/api/tests/integrations/test_run_pipeline.py new file mode 100644 index 0000000..9a87aa1 --- /dev/null +++ b/api/tests/integrations/test_run_pipeline.py @@ -0,0 +1,134 @@ +"""Integration tests for ``api.services.pipecat.run_pipeline._run_pipeline``. + +Drives the actual ``_run_pipeline`` against the test database with real +DB rows (organization, user, user configuration, workflow, workflow run) +and pipecat's real ``MockTransport`` / ``Pipeline`` / ``PipelineTask``. +The only patches are for things that talk to genuinely external systems; +those are applied via ``patch_run_pipeline_externals`` from the shared +helpers module. 
+ +Verifies that the wiring done by ``_run_pipeline`` (in particular +``register_event_handlers``) produces the right behaviour end-to-end: +``maybe_trigger_initial_response`` fires (``engine.set_node`` runs), and +on shutdown the workflow run is persisted with the expected state, +completion flag, and ``gathered_context`` entries. +""" + +import asyncio + +import pytest +from pipecat.tests.mock_transport import MockTransport +from pipecat.transports.base_transport import TransportParams + +from api.enums import WorkflowRunMode, WorkflowRunState +from api.services.pipecat.audio_config import create_audio_config +from api.services.pipecat.run_pipeline import _run_pipeline +from api.tests.integrations._run_pipeline_helpers import ( + create_workflow_run_rows, + patch_run_pipeline_externals, +) + +WORKFLOW_DEFINITION = { + "nodes": [ + { + "id": "start", + "type": "startCall", + "position": {"x": 0, "y": 0}, + "data": { + "name": "Start", + "prompt": "You are a helpful assistant. Greet the user briefly.", + "is_start": True, + "allow_interrupt": False, + "add_global_prompt": False, + }, + }, + { + "id": "end", + "type": "endCall", + "position": {"x": 0, "y": 200}, + "data": { + "name": "End", + "prompt": "End the call politely.", + "is_end": True, + "allow_interrupt": False, + "add_global_prompt": False, + }, + }, + ], + "edges": [ + { + "id": "start-end", + "source": "start", + "target": "end", + "data": {"label": "End", "condition": "When the user wants to end."}, + } + ], +} + + +@pytest.fixture +async def workflow_run_setup(db_session, async_session): + """Create org/user/user_configuration/workflow/workflow_run rows in the + test database. 
Returns (workflow_run, user, workflow).""" + return await create_workflow_run_rows( + db_session, + async_session, + workflow_definition=WORKFLOW_DEFINITION, + name_prefix="Event Handler Integration", + provider_id_suffix="event-handlers", + ) + + +@pytest.mark.asyncio +async def test_run_pipeline_fires_initial_response_and_completes_run( + workflow_run_setup, db_session +): + """End-to-end: _run_pipeline boots, register_event_handlers wires up, + on_pipeline_started + on_client_connected both fire, the initial + response is triggered (set_node), and on_pipeline_finished updates + the workflow_run row to COMPLETED.""" + workflow_run, user, workflow = workflow_run_setup + transport = MockTransport( + TransportParams(audio_in_enabled=True, audio_out_enabled=True) + ) + + captured_task: list = [] + audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value) + with patch_run_pipeline_externals(captured_task): + run_coro = _run_pipeline( + transport=transport, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + user_id=user.id, + audio_config=audio_config, + user_provider_id=user.provider_id, + ) + run_task = asyncio.create_task(run_coro) + + # Wait until create_pipeline_task is invoked. Surface any + # exception from _run_pipeline immediately rather than swallowing + # it during the wait loop. + for _ in range(60): + if captured_task or run_task.done(): + break + await asyncio.sleep(0.05) + if run_task.done() and not captured_task: + run_task.result() # re-raise the failure + assert captured_task, "create_pipeline_task was never invoked" + pipeline_task = captured_task[0] + await asyncio.wait_for(pipeline_task._pipeline_start_event.wait(), timeout=3.0) + # Let the initial response handler (set_node, queue LLMContextFrame) + # complete before tearing things down. 
+ await asyncio.sleep(0.1) + await pipeline_task.cancel() + await asyncio.wait_for(run_task, timeout=5.0) + + # Verify the run was completed end-to-end via the real on_pipeline_finished + # handler — DB side effects, not mock assertions. + refreshed = await db_session.get_workflow_run_by_id(workflow_run.id) + assert refreshed.is_completed is True + assert refreshed.state == WorkflowRunState.COMPLETED.value + # set_node("start") populates "nodes_visited" via _gathered_context, and + # on_pipeline_finished merges call_tags into gathered_context. + assert "Start" in refreshed.gathered_context.get("nodes_visited", []) + assert "call_tags" in refreshed.gathered_context diff --git a/api/tests/integrations/test_run_pipeline_text_greeting.py b/api/tests/integrations/test_run_pipeline_text_greeting.py new file mode 100644 index 0000000..0da7bf8 --- /dev/null +++ b/api/tests/integrations/test_run_pipeline_text_greeting.py @@ -0,0 +1,289 @@ +"""Integration test for the text-greeting flow through ``_run_pipeline``. + +Drives the full pipeline produced by ``_run_pipeline`` against the test +database with a workflow whose start node has a text greeting configured. +The flow under test: + +1. ``maybe_trigger_initial_response`` (in ``event_handlers.py``) sees a + text greeting and queues ``TTSSpeakFrame(greeting)``. +2. ``MockTTSService`` synthesises audio for the greeting; the real + ``MediaSender`` machinery in ``MockOutputTransport`` emits + ``BotStartedSpeakingFrame`` and ``BotStoppedSpeakingFrame``. +3. The TTS service emits an ``LLMAssistantPushAggregationFrame`` after + ``TTSStoppedFrame``, so the greeting is appended to the assistant + context by ``LLMAssistantAggregator``. +4. We then push a ``TranscriptionFrame`` into the pipeline. 
After the + user-turn-stop timeout, ``LLMUserAggregator`` pushes a context frame + to the LLM, ``MockLLMService`` returns an ``end_call`` tool call, and + the engine's transition function moves to the end node and calls + ``end_call_with_reason``. +5. ``on_pipeline_finished`` records the run as COMPLETED. + +External boundaries are patched via ``patch_run_pipeline_externals`` +from the shared helpers module. Preconfigured ``MockLLMService`` / +``MockTTSService`` instances are passed in so the end_call response is +deterministic and the synthesised audio length is short. +""" + +import asyncio + +import pytest +from pipecat.frames.frames import TranscriptionFrame +from pipecat.tests.mock_transport import MockTransport +from pipecat.transports.base_transport import TransportParams +from pipecat.utils.time import time_now_iso8601 + +from api.enums import WorkflowRunMode, WorkflowRunState +from api.services.pipecat.audio_config import create_audio_config +from api.services.pipecat.run_pipeline import _run_pipeline +from api.tests.integrations._run_pipeline_helpers import ( + create_workflow_run_rows, + patch_run_pipeline_externals, +) +from pipecat.tests import MockLLMService, MockTTSService + +GREETING_TEXT = ( + "Thanks for calling Happy Feet, this is Sarah. How can I help you today?" +) + +WORKFLOW_DEFINITION = { + "nodes": [ + { + "id": "start", + "type": "startCall", + "position": {"x": 0, "y": 0}, + "data": { + "name": "Start", + "prompt": "You are Sarah. 
def _greeting_in_assistant_context(context) -> bool:
    """Report whether any assistant message in ``context`` contains the greeting.

    Only dict messages whose ``role`` is ``"assistant"`` are inspected; a
    missing or falsy ``content`` is treated as an empty string.
    """
    return any(
        GREETING_TEXT in (entry.get("content") or "")
        for entry in context.get_messages()
        if isinstance(entry, dict) and entry.get("role") == "assistant"
    )
class_name: + return processor + sub = getattr(processor, "_processors", None) + if sub: + stack.extend(sub) + return None + + +async def _wait_for(predicate, *, timeout: float, interval: float = 0.05) -> bool: + """Poll ``predicate`` (sync callable returning bool) until it returns True + or the timeout elapses. Returns the final predicate value.""" + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + if predicate(): + return True + await asyncio.sleep(interval) + return predicate() + + +async def _run_test_body(workflow_run_setup, db_session) -> None: + workflow_run, user, workflow = workflow_run_setup + + # Prepare the LLM with one step: the end_call function call. + # Edge label "End Call" maps to function name "end_call". + end_call_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_end_1", + ) + llm = MockLLMService(mock_steps=[end_call_chunks], chunk_delay=0.001) + + # Short audio greeting so the bot finishes speaking quickly in tests. 
+ tts = MockTTSService(mock_audio_duration_ms=50, frame_delay=0) + + transport = MockTransport( + TransportParams(audio_in_enabled=True, audio_out_enabled=True) + ) + + captured_task: list = [] + audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value) + pipeline_task = None + + try: + with patch_run_pipeline_externals(captured_task, llm=llm, tts=tts): + run_coro = _run_pipeline( + transport=transport, + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + user_id=user.id, + audio_config=audio_config, + user_provider_id=user.provider_id, + ) + run_task = asyncio.create_task(run_coro) + + for _ in range(60): + if captured_task or run_task.done(): + break + await asyncio.sleep(0.05) + if run_task.done() and not captured_task: + run_task.result() + assert captured_task, "create_pipeline_task was never invoked" + pipeline_task = captured_task[0] + + await asyncio.wait_for( + pipeline_task._pipeline_start_event.wait(), timeout=3.0 + ) + + # Locate the assistant aggregator's LLM context (downstream of TTS). + # The PipelineTask wraps the user's pipeline inside another Pipeline, + # so we walk the tree recursively. + assistant_aggregator = _find_processor_by_class_name( + pipeline_task, "LLMAssistantAggregator" + ) + assert assistant_aggregator is not None, ( + "LLMAssistantAggregator not found in pipeline" + ) + context = assistant_aggregator.context + + # Wait for the greeting to be appended to the assistant context. The + # TTSSpeakFrame -> audio frames -> BotStoppedSpeaking -> assistant + # aggregation push chain runs through the real pipeline. + appeared = await _wait_for( + lambda: _greeting_in_assistant_context(context), timeout=5.0 + ) + assert appeared, ( + "Greeting was not appended to the assistant context. " + f"Messages: {context.get_messages()}" + ) + + # The LLM must not have been invoked yet — the greeting bypasses + # the LLM entirely (goes straight to TTS via TTSSpeakFrame). 
@pytest.mark.asyncio
async def test_text_greeting_speaks_then_user_transcript_triggers_end_call(
    workflow_run_setup, db_session
):
    """End-to-end:

    - ``maybe_trigger_initial_response`` queues ``TTSSpeakFrame`` for the
      start-node text greeting.
    - ``MockTTSService`` synthesises audio; ``MockOutputTransport`` emits
      bot speaking events; the assistant aggregator appends the greeting
      to the context after the TTS turn ends.
    - We push a ``TranscriptionFrame`` into the pipeline. After the user
      turn stop timeout, ``MockLLMService`` returns an ``end_call`` tool
      call which transitions to the end node and ends the run.

    The whole body is bounded by ``TEST_HARD_TIMEOUT_SECONDS`` so a hung
    pipeline fails the test rather than wedging the test runner. A timeout
    raised by one of the shorter waits inside the body is reported with its
    own message instead of being attributed to the hard cap.
    """

    async def _body_with_inner_timeouts_reported() -> None:
        # A TimeoutError escaping the body comes from one of its own
        # asyncio.wait_for calls. When the hard cap fires, the body is
        # cancelled (CancelledError) instead, so this handler never masks
        # the hard-timeout path below.
        try:
            await _run_test_body(workflow_run_setup, db_session)
        except asyncio.TimeoutError as e:
            raise AssertionError(
                "A wait inside the test body timed out before the hard timeout "
                f"of {TEST_HARD_TIMEOUT_SECONDS}s was reached. Check earlier "
                "debug logs for the last frame to reach the pipeline."
            ) from e

    try:
        await asyncio.wait_for(
            _body_with_inner_timeouts_reported(),
            timeout=TEST_HARD_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError as e:
        raise AssertionError(
            f"Test exceeded hard timeout of {TEST_HARD_TIMEOUT_SECONDS}s — "
            "pipeline likely hung. Check earlier debug logs for the last frame "
            "to reach the pipeline."
        ) from e
TestCustomToolManagerUnit: @pytest.mark.asyncio async def test_get_tool_schemas_returns_correct_format(self): """Test that get_tool_schemas returns FunctionSchema objects.""" - from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager from pipecat.adapters.schemas.function_schema import FunctionSchema # Create a mock engine + from api.services.workflow.pipecat_engine import PipecatEngine + from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager + mock_engine = Mock() mock_engine._workflow_run_id = 1 mock_engine._call_context_vars = {} + mock_engine._organization_id = None + mock_engine._get_organization_id = PipecatEngine._get_organization_id.__get__( + mock_engine + ) manager = CustomToolManager(mock_engine) @@ -754,29 +760,31 @@ class TestCustomToolManagerUnit: }, ) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 - - with patch( + with ( + patch( "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) + ) as mock_db, + patch( + "api.db:db_client.get_organization_id_by_workflow_run_id", + new_callable=AsyncMock, + return_value=1, + ), + ): + mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) - schemas = await manager.get_tool_schemas(["uuid-1"]) + schemas = await manager.get_tool_schemas(["uuid-1"]) - assert len(schemas) == 1 - schema = schemas[0] + assert len(schemas) == 1 + schema = schemas[0] - # Schema should be a FunctionSchema object - assert isinstance(schema, FunctionSchema) + # Schema should be a FunctionSchema object + assert isinstance(schema, FunctionSchema) - # FunctionSchema should have correct attributes - assert schema.name == "test_tool" - assert "param1" in schema.properties - assert schema.properties["param1"]["type"] == "string" - assert "param1" in schema.required + # FunctionSchema 
should have correct attributes + assert schema.name == "test_tool" + assert "param1" in schema.properties + assert schema.properties["param1"]["type"] == "string" + assert "param1" in schema.required @pytest.mark.asyncio async def test_register_handlers_creates_working_handler(self): @@ -792,9 +800,15 @@ class TestCustomToolManagerUnit: mock_llm.register_function = capture_register + from api.services.workflow.pipecat_engine import PipecatEngine + mock_engine = Mock() mock_engine._workflow_run_id = 1 mock_engine._call_context_vars = {} + mock_engine._organization_id = None + mock_engine._get_organization_id = PipecatEngine._get_organization_id.__get__( + mock_engine + ) mock_engine.llm = mock_llm manager = CustomToolManager(mock_engine) @@ -815,20 +829,22 @@ class TestCustomToolManagerUnit: }, ) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 - - with patch( + with ( + patch( "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) + ) as mock_db, + patch( + "api.db:db_client.get_organization_id_by_workflow_run_id", + new_callable=AsyncMock, + return_value=1, + ), + ): + mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool]) - await manager.register_handlers(["uuid-1"]) + await manager.register_handlers(["uuid-1"]) - # Verify handler was registered - assert "api_call" in registered_handlers + # Verify handler was registered + assert "api_call" in registered_handlers # Now test that the handler works handler = registered_handlers["api_call"] diff --git a/api/tests/test_custom_tools_context_integration.py b/api/tests/test_custom_tools_context_integration.py index fcdfc2b..db060fb 100644 --- a/api/tests/test_custom_tools_context_integration.py +++ b/api/tests/test_custom_tools_context_integration.py @@ -9,15 +9,15 @@ This module tests the full flow of: from 
unittest.mock import AsyncMock, patch import pytest +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.processors.aggregators.llm_context import LLMContext from api.services.workflow.pipecat_engine_custom_tools import ( CustomToolManager, get_function_schema, ) from api.tests.conftest import MockToolModel -from pipecat.adapters.schemas.function_schema import FunctionSchema -from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.processors.aggregators.llm_context import LLMContext def _update_llm_context(context, system_message, functions): @@ -45,70 +45,65 @@ class TestCustomToolManagerContextIntegration: manager = CustomToolManager(mock_engine) with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools) + # Get tool schemas via CustomToolManager - now returns FunctionSchema objects + tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"] + schemas = await manager.get_tool_schemas(tool_uuids) - # Get tool schemas via CustomToolManager - now returns FunctionSchema objects - tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"] - schemas = await manager.get_tool_schemas(tool_uuids) + # Verify schemas were returned as FunctionSchema objects + assert len(schemas) == 3 + assert all(isinstance(s, FunctionSchema) for s in schemas) - # Verify schemas were returned as FunctionSchema objects - assert len(schemas) == 3 - assert all(isinstance(s, FunctionSchema) for s in schemas) + # Create context with conversation history + 
context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "I need to check the weather and book an appointment.", + }, + { + "role": "assistant", + "content": "I can help with both. Where would you like to check the weather?", + }, + {"role": "user", "content": "San Francisco"}, + ] + ) - # Create context with conversation history - context = LLMContext() - context.set_messages( - [ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": "I need to check the weather and book an appointment.", - }, - { - "role": "assistant", - "content": "I can help with both. Where would you like to check the weather?", - }, - {"role": "user", "content": "San Francisco"}, - ] - ) + # Update context with new system message and tools + # Now we can pass schemas directly since they're FunctionSchema objects + new_system = { + "role": "system", + "content": "You are a scheduling assistant with access to weather and booking tools.", + } + _update_llm_context(context, new_system, schemas) - # Update context with new system message and tools - # Now we can pass schemas directly since they're FunctionSchema objects - new_system = { - "role": "system", - "content": "You are a scheduling assistant with access to weather and booking tools.", - } - _update_llm_context(context, new_system, schemas) + # Verify context was updated correctly + messages = context.messages + assert len(messages) == 4 + assert ( + messages[0]["content"] + == "You are a scheduling assistant with access to weather and booking tools." + ) + assert messages[1]["role"] == "user" + assert messages[3]["content"] == "San Francisco" - # Verify context was updated correctly - messages = context.messages - assert len(messages) == 4 - assert ( - messages[0]["content"] - == "You are a scheduling assistant with access to weather and booking tools." 
- ) - assert messages[1]["role"] == "user" - assert messages[3]["content"] == "San Francisco" + # Verify tools were set + tools = context.tools + assert tools is not None + assert len(tools.standard_tools) == 3 - # Verify tools were set - tools = context.tools - assert tools is not None - assert len(tools.standard_tools) == 3 - - # Verify tool names - tool_names = {t.name for t in tools.standard_tools} - assert tool_names == { - "get_weather", - "book_appointment", - "customer_lookup", - } + # Verify tool names + tool_names = {t.name for t in tools.standard_tools} + assert tool_names == { + "get_weather", + "book_appointment", + "customer_lookup", + } @pytest.mark.asyncio async def test_tool_schemas_have_correct_properties( @@ -118,39 +113,32 @@ class TestCustomToolManagerContextIntegration: manager = CustomToolManager(mock_engine) with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools) + schemas = await manager.get_tool_schemas( + ["weather-uuid-123", "booking-uuid-456"] + ) - schemas = await manager.get_tool_schemas( - ["weather-uuid-123", "booking-uuid-456"] - ) + # Find the booking schema - now using FunctionSchema attributes + booking_schema = next(s for s in schemas if s.name == "book_appointment") - # Find the booking schema - now using FunctionSchema attributes - booking_schema = next( - s for s in schemas if s.name == "book_appointment" - ) + # Verify parameter properties + assert "customer_name" in booking_schema.properties + assert "date" in booking_schema.properties + assert "time" in booking_schema.properties + assert "notes" in 
booking_schema.properties - # Verify parameter properties - assert "customer_name" in booking_schema.properties - assert "date" in booking_schema.properties - assert "time" in booking_schema.properties - assert "notes" in booking_schema.properties + # Verify types + assert booking_schema.properties["customer_name"]["type"] == "string" + assert booking_schema.properties["date"]["type"] == "string" - # Verify types - assert booking_schema.properties["customer_name"]["type"] == "string" - assert booking_schema.properties["date"]["type"] == "string" - - # Verify required - assert "customer_name" in booking_schema.required - assert "date" in booking_schema.required - assert "time" in booking_schema.required - assert "notes" not in booking_schema.required + # Verify required + assert "customer_name" in booking_schema.required + assert "date" in booking_schema.required + assert "time" in booking_schema.required + assert "notes" not in booking_schema.required @pytest.mark.asyncio async def test_context_update_with_builtin_and_custom_tools( @@ -160,67 +148,62 @@ class TestCustomToolManagerContextIntegration: manager = CustomToolManager(mock_engine) with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock( + return_value=[sample_tools[0]] + ) # Just weather - with patch( - "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock( - return_value=[sample_tools[0]] - ) # Just weather + # Get custom tool schemas - returns FunctionSchema objects + custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"]) - # Get custom tool schemas - returns FunctionSchema objects - custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"]) + # Create built-in function schemas (like calculator, 
timezone) + builtin_functions = [ + get_function_schema( + "safe_calculator", + "Evaluate a mathematical expression safely", + properties={ + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate", + } + }, + required=["expression"], + ), + get_function_schema( + "get_current_time", + "Get the current time in a timezone", + properties={ + "timezone": { + "type": "string", + "description": "Timezone name (e.g., America/New_York)", + } + }, + required=["timezone"], + ), + ] - # Create built-in function schemas (like calculator, timezone) - builtin_functions = [ - get_function_schema( - "safe_calculator", - "Evaluate a mathematical expression safely", - properties={ - "expression": { - "type": "string", - "description": "Mathematical expression to evaluate", - } - }, - required=["expression"], - ), - get_function_schema( - "get_current_time", - "Get the current time in a timezone", - properties={ - "timezone": { - "type": "string", - "description": "Timezone name (e.g., America/New_York)", - } - }, - required=["timezone"], - ), - ] + # Combine built-in and custom functions - both are FunctionSchema objects + all_functions = builtin_functions + custom_schemas - # Combine built-in and custom functions - both are FunctionSchema objects - all_functions = builtin_functions + custom_schemas + # Update context + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old prompt"}]) - # Update context - context = LLMContext() - context.set_messages([{"role": "system", "content": "Old prompt"}]) + new_system = { + "role": "system", + "content": "Assistant with calculator and weather tools", + } + _update_llm_context(context, new_system, all_functions) - new_system = { - "role": "system", - "content": "Assistant with calculator and weather tools", - } - _update_llm_context(context, new_system, all_functions) + # Verify all tools are present + tools = context.tools + assert len(tools.standard_tools) == 3 - # Verify all 
tools are present - tools = context.tools - assert len(tools.standard_tools) == 3 - - tool_names = {t.name for t in tools.standard_tools} - assert "safe_calculator" in tool_names - assert "get_current_time" in tool_names - assert "get_weather" in tool_names + tool_names = {t.name for t in tools.standard_tools} + assert "safe_calculator" in tool_names + assert "get_current_time" in tool_names + assert "get_weather" in tool_names @pytest.mark.asyncio async def test_context_preserves_function_call_history( @@ -230,65 +213,60 @@ class TestCustomToolManagerContextIntegration: manager = CustomToolManager(mock_engine) with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]]) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]]) + # Get schemas - returns FunctionSchema objects + schemas = await manager.get_tool_schemas(["weather-uuid-123"]) - # Get schemas - returns FunctionSchema objects - schemas = await manager.get_tool_schemas(["weather-uuid-123"]) + # Create context with function call history + context = LLMContext() + context.set_messages( + [ + {"role": "system", "content": "Old system prompt"}, + {"role": "user", "content": "What's the weather in NYC?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "New York, NY"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": '{"temperature": 72, "condition": "sunny"}', + }, + { + "role": "assistant", + "content": "The weather in NYC is 72°F and sunny!", + }, + ] + ) - # Create context with function 
call history - context = LLMContext() - context.set_messages( - [ - {"role": "system", "content": "Old system prompt"}, - {"role": "user", "content": "What's the weather in NYC?"}, - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_123", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"location": "New York, NY"}', - }, - } - ], - }, - { - "role": "tool", - "tool_call_id": "call_123", - "content": '{"temperature": 72, "condition": "sunny"}', - }, - { - "role": "assistant", - "content": "The weather in NYC is 72°F and sunny!", - }, - ] - ) + new_system = {"role": "system", "content": "Updated weather assistant"} + _update_llm_context(context, new_system, schemas) - new_system = {"role": "system", "content": "Updated weather assistant"} - _update_llm_context(context, new_system, schemas) + messages = context.messages + # System + user + assistant(tool_call) + tool + assistant = 5 + assert len(messages) == 5 - messages = context.messages - # System + user + assistant(tool_call) + tool + assistant = 5 - assert len(messages) == 5 + # Verify function call messages are preserved + tool_call_msg = messages[2] + assert tool_call_msg["role"] == "assistant" + assert "tool_calls" in tool_call_msg - # Verify function call messages are preserved - tool_call_msg = messages[2] - assert tool_call_msg["role"] == "assistant" - assert "tool_calls" in tool_call_msg - - tool_result_msg = messages[3] - assert tool_result_msg["role"] == "tool" - assert tool_result_msg["tool_call_id"] == "call_123" + tool_result_msg = messages[3] + assert tool_result_msg["role"] == "tool" + assert tool_result_msg["tool_call_id"] == "call_123" @pytest.mark.asyncio async def test_empty_tool_list_does_not_set_tools(self, mock_engine): @@ -296,26 +274,21 @@ class TestCustomToolManagerContextIntegration: manager = CustomToolManager(mock_engine) with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) 
as mock_get_org: - mock_get_org.return_value = 1 + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[]) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=[]) + schemas = await manager.get_tool_schemas([]) + assert schemas == [] - schemas = await manager.get_tool_schemas([]) - assert schemas == [] + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old"}]) - context = LLMContext() - context.set_messages([{"role": "system", "content": "Old"}]) + new_system = {"role": "system", "content": "No tools available"} + _update_llm_context(context, new_system, []) - new_system = {"role": "system", "content": "No tools available"} - _update_llm_context(context, new_system, []) - - # Context should have updated message but no tools set - assert context.messages[0]["content"] == "No tools available" + # Context should have updated message but no tools set + assert context.messages[0]["content"] == "No tools available" @pytest.mark.asyncio async def test_numeric_and_boolean_parameter_types(self, mock_engine): @@ -357,33 +330,28 @@ class TestCustomToolManagerContextIntegration: manager = CustomToolManager(mock_engine) with patch( - "api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run" - ) as mock_get_org: - mock_get_org.return_value = 1 + "api.services.workflow.pipecat_engine_custom_tools.db_client" + ) as mock_db: + mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types]) - with patch( - "api.services.workflow.pipecat_engine_custom_tools.db_client" - ) as mock_db: - mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types]) + # Get schemas - returns FunctionSchema objects + schemas = await manager.get_tool_schemas(["order-uuid"]) + schema = schemas[0] - # Get schemas - returns FunctionSchema objects - schemas = 
await manager.get_tool_schemas(["order-uuid"]) - schema = schemas[0] + # Verify types using FunctionSchema attributes + assert schema.properties["item_id"]["type"] == "string" + assert schema.properties["quantity"]["type"] == "number" + assert schema.properties["express_shipping"]["type"] == "boolean" - # Verify types using FunctionSchema attributes - assert schema.properties["item_id"]["type"] == "string" - assert schema.properties["quantity"]["type"] == "number" - assert schema.properties["express_shipping"]["type"] == "boolean" + # Update context - pass schema directly + context = LLMContext() + context.set_messages([{"role": "system", "content": "Old"}]) + _update_llm_context( + context, {"role": "system", "content": "Order assistant"}, schemas + ) - # Update context - pass schema directly - context = LLMContext() - context.set_messages([{"role": "system", "content": "Old"}]) - _update_llm_context( - context, {"role": "system", "content": "Order assistant"}, schemas - ) - - # Verify tool was set with correct types - tool = context.tools.standard_tools[0] - assert tool.name == "place_order" - assert tool.properties["quantity"]["type"] == "number" - assert tool.properties["express_shipping"]["type"] == "boolean" + # Verify tool was set with correct types + tool = context.tools.standard_tools[0] + assert tool.name == "place_order" + assert tool.properties["quantity"]["type"] == "number" + assert tool.properties["express_shipping"]["type"] == "boolean" diff --git a/api/tests/test_pipecat_engine_context_update.py b/api/tests/test_pipecat_engine_context_update.py index 8ef0e0e..e22a575 100644 --- a/api/tests/test_pipecat_engine_context_update.py +++ b/api/tests/test_pipecat_engine_context_update.py @@ -14,18 +14,10 @@ result in the context when generating the next response. 
""" import asyncio -from typing import Any, Dict, List +from typing import List from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import WorkflowGraph -from api.tests.conftest import ( - AGENT_SYSTEM_PROMPT, - END_CALL_SYSTEM_PROMPT, - START_CALL_SYSTEM_PROMPT, -) from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -35,75 +27,21 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMAssistantAggregatorParams, LLMContextAggregatorPair, ) -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams - -class ContextCapturingMockLLM(MockLLMService): - """A MockLLMService that captures the context state at each generation. - - This allows us to verify that tool call results are present in the context - when the next LLM generation is triggered. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.captured_contexts: List[Dict[str, Any]] = [] - - async def _stream_chat_completions_universal_context(self, context): - """Override to capture context state before streaming chunks.""" - # Deep copy the messages to avoid mutation issues - messages_snapshot = [] - for msg in context.messages: - msg_copy = dict(msg) - # Copy content to avoid reference issues - if "content" in msg_copy: - msg_copy["content"] = ( - str(msg_copy["content"]) if msg_copy["content"] else None - ) - messages_snapshot.append(msg_copy) - - self.captured_contexts.append( - { - "step": self._current_step, - "messages": messages_snapshot, - "system_prompt": self._settings.system_instruction, - } - ) - - # Call parent implementation to stream the mock chunks - return await super()._stream_chat_completions_universal_context(context) - - def get_context_at_step(self, step: int) -> Dict[str, Any]: - """Get the captured context at a specific step (0-indexed).""" - for ctx in self.captured_contexts: - if ctx["step"] == step: - return ctx - return None - - def has_tool_call_result_at_step(self, step: int, function_name: str) -> bool: - """Check if a tool call result for the given function exists in context at step.""" - ctx = self.get_context_at_step(step) - if not ctx: - return False - - for msg in ctx["messages"]: - # Check for tool/function role messages - if msg.get("role") == "tool" and msg.get("name") == function_name: - return True - # Also check for tool_call_id which indicates a tool response - if msg.get("tool_call_id") and function_name in str(msg.get("name", "")): - return True - - return False - - def get_system_prompt_at_step(self, step: int) -> str: - """Get the system prompt from settings at a specific step.""" - ctx = self.get_context_at_step(step) - if ctx: - return ctx.get("system_prompt") or "" - return "" +from api.services.workflow.pipecat_engine import PipecatEngine +from 
api.services.workflow.workflow import WorkflowGraph +from api.tests.conftest import ( + AGENT_SYSTEM_PROMPT, + END_CALL_SYSTEM_PROMPT, + START_CALL_SYSTEM_PROMPT, +) +from pipecat.tests import ( + ContextCapturingMockLLM, + MockLLMService, + MockTTSService, +) async def run_pipeline_and_capture_context( @@ -142,7 +80,7 @@ async def run_pipeline_and_capture_context( context = LLMContext() # Add assistant context aggregator - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params ) @@ -184,7 +122,7 @@ async def run_pipeline_and_capture_context( # Patch DB calls with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_pipecat_engine_end_call.py b/api/tests/test_pipecat_engine_end_call.py index 1384150..a0f8ac1 100644 --- a/api/tests/test_pipecat_engine_end_call.py +++ b/api/tests/test_pipecat_engine_end_call.py @@ -23,6 +23,23 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, patch import pytest +from pipecat.frames.frames import Frame, LLMContextFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import ( + LLMAssistantAggregatorParams, + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.tests.mock_transport import MockTransport +from pipecat.transports.base_transport import TransportParams +from pipecat.turns.user_mute import ( + CallbackUserMuteStrategy, + MuteUntilFirstBotCompleteUserMuteStrategy, +) +from pipecat.utils.enums import EndTaskReason from 
api.enums import ToolCategory from api.services.workflow.dto import ( @@ -42,24 +59,7 @@ from api.services.workflow.pipecat_engine_variable_extractor import ( ) from api.services.workflow.workflow import WorkflowGraph from api.tests.conftest import END_CALL_SYSTEM_PROMPT, START_CALL_SYSTEM_PROMPT -from pipecat.frames.frames import Frame, LLMContextFrame -from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response_universal import ( - LLMAssistantAggregatorParams, - LLMContextAggregatorPair, - LLMUserAggregatorParams, -) from pipecat.tests import MockLLMService, MockTTSService -from pipecat.tests.mock_transport import MockTransport -from pipecat.transports.base_transport import TransportParams -from pipecat.turns.user_mute import ( - CallbackUserMuteStrategy, - MuteUntilFirstBotCompleteUserMuteStrategy, -) -from pipecat.utils.enums import EndTaskReason class EndCallTestHelper: @@ -182,7 +182,7 @@ async def create_engine_with_tracking( engine.end_call_with_reason = tracked_end_call # Create context aggregator with user mute strategies (after engine so we can use its callback) - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() # Wrap should_mute_user to track calls original_should_mute_user = engine.should_mute_user @@ -265,7 +265,7 @@ class TestEndCallViaNodeTransition: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -369,7 +369,7 @@ class TestEndCallViaNodeTransition: # Patch DB calls and extraction manager with patch( - 
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -468,7 +468,7 @@ class TestEndCallViaCustomTool: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -560,7 +560,7 @@ class TestEndCallViaCustomTool: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -638,7 +638,7 @@ class TestEndCallViaClientDisconnect: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -729,7 +729,7 @@ class TestEndCallRaceConditions: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -841,7 +841,7 @@ class TestEndCallRaceConditions: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -937,7 +937,7 @@ class TestEndCallExtractionBehavior: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -1061,7 +1061,7 @@ class TestEndCallExtractionBehavior: # Patch DB calls 
and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_pipecat_engine_node_switch_with_user_speech.py b/api/tests/test_pipecat_engine_node_switch_with_user_speech.py index 205baa1..a19843b 100644 --- a/api/tests/test_pipecat_engine_node_switch_with_user_speech.py +++ b/api/tests/test_pipecat_engine_node_switch_with_user_speech.py @@ -15,9 +15,6 @@ import asyncio from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import WorkflowGraph from pipecat.frames.frames import ( Frame, FunctionCallResultFrame, @@ -36,7 +33,6 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregatorParams, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams from pipecat.turns.user_mute import ( @@ -52,6 +48,10 @@ from pipecat.turns.user_stop import ( from pipecat.turns.user_turn_strategies import UserTurnStrategies from pipecat.utils.time import time_now_iso8601 +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + class UserSpeechInjector(FrameProcessor): """Processor that injects user speaking frames on FunctionCallResultFrame. 
@@ -183,7 +183,7 @@ async def create_test_pipeline( ) # Create context aggregator with user and assistant params - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params, user_params=user_params @@ -277,7 +277,7 @@ class TestNodeSwitchWithUserSpeech: # Patch DB calls with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_pipecat_engine_tool_calls.py b/api/tests/test_pipecat_engine_tool_calls.py index 001de29..aef2df6 100644 --- a/api/tests/test_pipecat_engine_tool_calls.py +++ b/api/tests/test_pipecat_engine_tool_calls.py @@ -9,10 +9,6 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import WorkflowGraph -from api.tests.conftest import END_CALL_SYSTEM_PROMPT from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -22,10 +18,14 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMAssistantAggregatorParams, LLMContextAggregatorPair, ) -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.workflow import WorkflowGraph +from api.tests.conftest import END_CALL_SYSTEM_PROMPT +from pipecat.tests import MockLLMService, MockTTSService + async def run_pipeline_with_tool_calls( workflow: WorkflowGraph, @@ -81,7 +81,7 @@ async def run_pipeline_with_tool_calls( context = LLMContext() 
# Add assistant context aggregator - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params ) @@ -113,7 +113,7 @@ async def run_pipeline_with_tool_calls( # Patch DB calls to avoid actual database access with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_pipecat_engine_transition_mute.py b/api/tests/test_pipecat_engine_transition_mute.py new file mode 100644 index 0000000..3cc5220 --- /dev/null +++ b/api/tests/test_pipecat_engine_transition_mute.py @@ -0,0 +1,280 @@ +"""Tests verifying user is muted while a transition function is executing. + +When the LLM calls a transition function (registered via +``_register_transition_function_with_llm``), pipecat broadcasts a +``FunctionCallsStartedFrame`` that ``FunctionCallUserMuteStrategy`` uses to +mute the user until a ``FunctionCallResultFrame`` arrives. These tests assert +that mute behavior holds end-to-end through the engine's transition flow, +so that user audio doesn't race the node switch / extraction / context update +that runs inside the transition function. 
+""" + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest +from pipecat.frames.frames import LLMContextFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import ( + LLMAssistantAggregatorParams, + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.tests.mock_transport import MockTransport +from pipecat.transports.base_transport import TransportParams +from pipecat.turns.user_mute import ( + CallbackUserMuteStrategy, + FunctionCallUserMuteStrategy, + MuteUntilFirstBotCompleteUserMuteStrategy, +) + +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_variable_extractor import ( + VariableExtractionManager, +) +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + + +async def _build_engine_and_pipeline( + workflow: WorkflowGraph, + mock_llm: MockLLMService, +): + """Set up engine + pipeline mirroring the non-realtime production wiring. + + Returns (engine, task, function_call_mute_strategy, user_context_aggregator). + """ + tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0) + + transport = MockTransport( + params=TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + audio_in_sample_rate=16000, + audio_out_sample_rate=16000, + ), + ) + + context = LLMContext() + + engine = PipecatEngine( + llm=mock_llm, + context=context, + workflow=workflow, + call_context_vars={"customer_name": "Test User"}, + workflow_run_id=1, + ) + + # Hold a reference so the test can introspect the in-progress set. 
+ function_call_mute_strategy = FunctionCallUserMuteStrategy() + + # Match run_pipeline.py's non-realtime mute-strategy stack so the test + # exercises the same wiring that would be active in a real call. + user_mute_strategies = [ + MuteUntilFirstBotCompleteUserMuteStrategy(), + function_call_mute_strategy, + CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user), + ] + + user_params = LLMUserAggregatorParams(user_mute_strategies=user_mute_strategies) + assistant_params = LLMAssistantAggregatorParams() + + context_aggregator = LLMContextAggregatorPair( + context, assistant_params=assistant_params, user_params=user_params + ) + user_context_aggregator = context_aggregator.user() + assistant_context_aggregator = context_aggregator.assistant() + + pipeline = Pipeline( + [ + transport.input(), + user_context_aggregator, + mock_llm, + tts, + transport.output(), + assistant_context_aggregator, + ] + ) + + task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False) + engine.set_task(task) + + return engine, task, function_call_mute_strategy, user_context_aggregator + + +class TestTransitionFunctionMutesUser: + """Verify the user is muted while transition functions execute.""" + + @pytest.mark.asyncio + async def test_user_is_muted_during_transition_function( + self, simple_workflow: WorkflowGraph + ): + """The user must be muted from the moment a transition function starts + until its result is delivered. + + Scenario: + 1. LLM calls the ``end_call`` transition function (start → end edge). + 2. Wrap the registered handler so we can read mute state from inside it. + 3. VERIFY: the function-call mute strategy has the call in flight. + 4. VERIFY: the user aggregator's ``_user_is_muted`` flag is True. 
+ """ + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_end_1", + ) + llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001) + + ( + engine, + task, + function_call_mute_strategy, + user_context_aggregator, + ) = await _build_engine_and_pipeline(simple_workflow, llm) + + captured_states: list[dict] = [] + + # Wrap register_function so we can introspect mute state from inside + # the transition handler. We must wrap *after* the engine is created + # but *before* set_node registers the transition functions. + original_register_function = llm.register_function + + def wrapping_register_function(name, func, *args, **kwargs): + async def wrapped(function_call_params): + # Yield once so the user aggregator has a chance to drain + # the broadcasted FunctionCallsStartedFrame and update its + # mute state before we sample it. + await asyncio.sleep(0.02) + captured_states.append( + { + "name": name, + "function_call_in_progress": bool( + function_call_mute_strategy._function_call_in_progress + ), + "user_is_muted": user_context_aggregator._user_is_muted, + "tool_call_ids": set( + function_call_mute_strategy._function_call_in_progress + ), + } + ) + return await func(function_call_params) + + return original_register_function(name, wrapped, *args, **kwargs) + + llm.register_function = wrapping_register_function + + with patch( + "api.db:db_client.get_organization_id_by_workflow_run_id", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end call"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await 
engine.initialize() + await engine.set_node(engine.workflow.start_node_id) + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.wait_for( + asyncio.gather(run_pipeline(), initialize_engine()), + timeout=10.0, + ) + + assert len(captured_states) == 1, ( + f"Expected the transition function to be invoked exactly once, " + f"got {len(captured_states)}: {captured_states}" + ) + state = captured_states[0] + assert state["name"] == "end_call" + assert state["function_call_in_progress"], ( + "FunctionCallUserMuteStrategy should have the transition call in " + f"progress while the handler runs (state={state})" + ) + assert "call_end_1" in state["tool_call_ids"], ( + f"Expected tool_call_id 'call_end_1' to be tracked, got {state['tool_call_ids']}" + ) + assert state["user_is_muted"], ( + "User aggregator's _user_is_muted should be True during the " + f"transition function (state={state})" + ) + + @pytest.mark.asyncio + async def test_user_is_unmuted_after_transition_function_returns( + self, simple_workflow: WorkflowGraph + ): + """After the transition function's result is delivered, the function-call + mute strategy should clear its in-progress set. Other strategies in the + stack (CallbackUserMuteStrategy via engine.should_mute_user) may still + keep the pipeline muted because end_call_with_reason fires when the + engine reaches the End node, but the function-call strategy itself + must release its hold. 
+ """ + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_end_1", + ) + llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001) + + ( + engine, + task, + function_call_mute_strategy, + _user_context_aggregator, + ) = await _build_engine_and_pipeline(simple_workflow, llm) + + with patch( + "api.db:db_client.get_organization_id_by_workflow_run_id", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_intent": "end call"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.set_node(engine.workflow.start_node_id) + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.wait_for( + asyncio.gather(run_pipeline(), initialize_engine()), + timeout=10.0, + ) + + assert function_call_mute_strategy._function_call_in_progress == set(), ( + "FunctionCallUserMuteStrategy should have cleared its in-progress " + "set after the transition function's result was delivered, got " + f"{function_call_mute_strategy._function_call_in_progress}" + ) diff --git a/api/tests/test_pipecat_engine_variable_extraction.py b/api/tests/test_pipecat_engine_variable_extraction.py index 140c5e8..29581d7 100644 --- a/api/tests/test_pipecat_engine_variable_extraction.py +++ b/api/tests/test_pipecat_engine_variable_extraction.py @@ -16,12 +16,6 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.pipecat_engine_variable_extractor import ( - VariableExtractionManager, 
-) -from api.services.workflow.workflow import WorkflowGraph from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -31,10 +25,16 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMAssistantAggregatorParams, LLMContextAggregatorPair, ) -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_variable_extractor import ( + VariableExtractionManager, +) +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + class TestVariableExtractionDuringTransitions: """Test that variable extraction is triggered for the correct node during transitions.""" @@ -97,7 +97,7 @@ class TestVariableExtractionDuringTransitions: context = LLMContext() # Add assistant context aggregator - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params ) @@ -152,7 +152,7 @@ class TestVariableExtractionDuringTransitions: # Patch DB calls and extraction manager with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_pipeline_cancellation.py b/api/tests/test_pipeline_cancellation.py index 02ea5f4..053c59f 100644 --- a/api/tests/test_pipeline_cancellation.py +++ b/api/tests/test_pipeline_cancellation.py @@ -2,7 +2,6 @@ import asyncio import pytest from loguru import logger - from pipecat.frames.frames import ( EndTaskFrame, Frame, @@ -35,8 +34,10 @@ class 
BusyWaitProcessor(FrameProcessor): # Simulate a delay, which can happen sometimes due to slow LLM Inferencing or # other reasons try: - logger.debug(f"{self} sleeping with frame: {frame}") - await asyncio.sleep(5) + logger.debug( + f"{self} sleeping with frame: {frame} for {self._wait_time} seconds" + ) + await asyncio.sleep(self._wait_time) logger.debug(f"{self} woke up with frame: {frame}") except asyncio.CancelledError: logger.debug(f"{self} was cancelled") @@ -46,7 +47,7 @@ class BusyWaitProcessor(FrameProcessor): @pytest.mark.asyncio async def test_interruption_with_blocked_end_frame(): - busy_wait_processor = BusyWaitProcessor(wait_time=5) + busy_wait_processor = BusyWaitProcessor(wait_time=0.5) transport = MockTransport() pipeline = Pipeline([transport, busy_wait_processor]) @@ -78,11 +79,13 @@ async def test_interruption_with_blocked_end_frame(): # Wait with timeout done, pending = await asyncio.wait( [pipeline_task, queue_task], - timeout=1.0, + timeout=2.0, return_when=asyncio.ALL_COMPLETED, ) # If there are pending tasks, we timed out + # FIXME: Currently I have created an issue on pipecat which talks about + # how this behaviour is not good. https://github.com/pipecat-ai/pipecat/issues/4412 if pending: # Cancel all pending tasks for t in pending: @@ -92,9 +95,9 @@ async def test_interruption_with_blocked_end_frame(): try: await asyncio.wait_for( asyncio.gather(*pending, return_exceptions=True), - timeout=1.0, + timeout=2.0, ) except asyncio.TimeoutError: pass # Cleanup took too long, continue anyway - pytest.fail("Test timed out after 1 second") + pytest.fail("Test timed out after 2 second") diff --git a/api/tests/test_recording_router_processor.py b/api/tests/test_recording_router_processor.py index 5ef2057..c86e6b4 100644 --- a/api/tests/test_recording_router_processor.py +++ b/api/tests/test_recording_router_processor.py @@ -12,6 +12,14 @@ and inspect what arrives downstream.
from typing import Optional import pytest +from pipecat.frames.frames import ( + LLMFullResponseEndFrame, + LLMTextFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, + TTSTextFrame, +) from api.services.pipecat.recording_audio_cache import RecordingAudio from api.services.pipecat.recording_router_processor import ( @@ -21,14 +29,6 @@ from api.services.workflow.pipecat_engine_context_composer import ( RECORDING_MARKER, TTS_MARKER, ) -from pipecat.frames.frames import ( - LLMFullResponseEndFrame, - LLMTextFrame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, - TTSTextFrame, -) from pipecat.tests import run_test # --------------------------------------------------------------------------- diff --git a/api/tests/test_text_and_audio_playback.py b/api/tests/test_text_and_audio_playback.py index 6330fa4..a950c9b 100644 --- a/api/tests/test_text_and_audio_playback.py +++ b/api/tests/test_text_and_audio_playback.py @@ -11,21 +11,6 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch import pytest - -from api.services.pipecat.recording_audio_cache import RecordingAudio -from api.services.workflow.dto import ( - EdgeDataDTO, - EndCallNodeData, - EndCallRFNode, - Position, - ReactFlowDTO, - RFEdgeDTO, - StartCallNodeData, - StartCallRFNode, -) -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager -from api.services.workflow.workflow import WorkflowGraph from pipecat.frames.frames import ( Frame, LLMContextFrame, @@ -42,10 +27,25 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMAssistantAggregatorParams, LLMContextAggregatorPair, ) -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams +from api.services.pipecat.recording_audio_cache import RecordingAudio +from api.services.workflow.dto 
import ( + EdgeDataDTO, + EndCallNodeData, + EndCallRFNode, + Position, + ReactFlowDTO, + RFEdgeDTO, + StartCallNodeData, + StartCallRFNode, +) +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + # ─── Constants ────────────────────────────────────────────────── START_PROMPT = "Start Call System Prompt" @@ -189,7 +189,7 @@ async def run_pipeline_and_capture_frames( ) context = LLMContext() - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params ) @@ -234,7 +234,7 @@ async def run_pipeline_and_capture_frames( with ( patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ), diff --git a/api/tests/test_tts_endframe_with_audio_write_failure.py b/api/tests/test_tts_endframe_with_audio_write_failure.py index d4eb92f..56f9ac6 100644 --- a/api/tests/test_tts_endframe_with_audio_write_failure.py +++ b/api/tests/test_tts_endframe_with_audio_write_failure.py @@ -32,12 +32,6 @@ import asyncio from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.pipecat_engine_variable_extractor import ( - VariableExtractionManager, -) -from api.services.workflow.workflow import WorkflowGraph from pipecat.frames.frames import LLMContextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -48,7 +42,6 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMContextAggregatorPair, LLMUserAggregatorParams, ) -from pipecat.tests import 
MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams from pipecat.turns.user_mute import ( @@ -57,6 +50,13 @@ from pipecat.turns.user_mute import ( ) from pipecat.utils.enums import EndTaskReason +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_variable_extractor import ( + VariableExtractionManager, +) +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + async def create_test_pipeline_with_failing_transport( workflow: WorkflowGraph, @@ -131,7 +131,7 @@ async def create_test_pipeline_with_failing_transport( user_mute_strategies=user_mute_strategies, ) - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params, user_params=user_params @@ -204,7 +204,7 @@ class TestTTSPauseWithAudioWriteFailure: # Patch DB calls with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -324,7 +324,7 @@ class TestTTSPauseWithAudioWriteFailure: ) with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_unregistered_function_call.py b/api/tests/test_unregistered_function_call.py new file mode 100644 index 0000000..24ed9a1 --- /dev/null +++ b/api/tests/test_unregistered_function_call.py @@ -0,0 +1,81 @@ +"""Tests for LLM behavior when calling an unregistered function.""" + +import pytest +from pipecat.frames.frames import ( + FunctionCallInProgressFrame, + FunctionCallResultFrame, + FunctionCallsFromLLMInfoFrame, 
+ FunctionCallsStartedFrame, + LLMContextFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, +) +from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.aggregators.llm_context import LLMContext + +from pipecat.tests import MockLLMService, run_test + + +class TestUnregisteredFunctionCall: + """Tests for LLM behavior when generating a tool call for an unregistered function.""" + + @pytest.mark.asyncio + async def test_unregistered_function_emits_error_result(self): + """LLM calling an unregistered function should still terminate with a + FunctionCallResultFrame whose result is an error string, instead of + crashing the pipeline.""" + chunks = MockLLMService.create_function_call_chunks( + function_name="nonexistent_tool", + arguments={"foo": "bar"}, + tool_call_id="call_missing_1", + ) + + llm = MockLLMService(mock_chunks=chunks, chunk_delay=0.001) + + # Intentionally do NOT register any handler for "nonexistent_tool". + + messages = [{"role": "user", "content": "Please use a tool I never registered"}] + context = LLMContext(messages) + + pipeline = Pipeline([llm]) + + received_down_frames, _ = await run_test( + pipeline, + frames_to_send=[LLMContextFrame(context)], + expected_down_frames=[ + LLMFullResponseStartFrame, + FunctionCallsFromLLMInfoFrame, + FunctionCallsStartedFrame, + LLMFullResponseEndFrame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, + ], + ) + + result_frames = [ + f for f in received_down_frames if isinstance(f, FunctionCallResultFrame) + ] + assert len(result_frames) == 1, ( + "Expected exactly one FunctionCallResultFrame for the unregistered call" + ) + + result_frame = result_frames[0] + assert result_frame.function_name == "nonexistent_tool" + assert result_frame.tool_call_id == "call_missing_1" + assert result_frame.arguments == {"foo": "bar"} + + # Pipecat's missing-function handler returns a string error. 
+ assert isinstance(result_frame.result, str) + assert "not registered" in result_frame.result + assert "nonexistent_tool" in result_frame.result + + # In-progress frame should also be emitted before the result so mute + # strategies can release the tool_call_id. + in_progress_frames = [ + f + for f in received_down_frames + if isinstance(f, FunctionCallInProgressFrame) + ] + assert len(in_progress_frames) == 1 + assert in_progress_frames[0].function_name == "nonexistent_tool" + assert in_progress_frames[0].tool_call_id == "call_missing_1" diff --git a/api/tests/test_user_idle_handler.py b/api/tests/test_user_idle_handler.py index 7c80e7d..47d8eee 100644 --- a/api/tests/test_user_idle_handler.py +++ b/api/tests/test_user_idle_handler.py @@ -13,9 +13,6 @@ import asyncio from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import WorkflowGraph from pipecat.frames.frames import ( BotStoppedSpeakingFrame, Frame, @@ -35,7 +32,6 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregatorParams, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams from pipecat.turns.user_mute import ( @@ -47,6 +43,10 @@ from pipecat.turns.user_stop import ExternalUserTurnStopStrategy from pipecat.turns.user_turn_strategies import UserTurnStrategies from pipecat.utils.time import time_now_iso8601 +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + class UserSpeechInjector(FrameProcessor): """Processor that injects user speaking frames after the bot finishes speaking. 
@@ -161,7 +161,7 @@ async def create_pipeline_with_speech_injection( user_idle_timeout=user_idle_timeout, ) - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params, user_params=user_params @@ -257,7 +257,7 @@ class TestUserIdleHandler: ) with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_user_muting_during_bot_speech.py b/api/tests/test_user_muting_during_bot_speech.py index ee14234..b055385 100644 --- a/api/tests/test_user_muting_during_bot_speech.py +++ b/api/tests/test_user_muting_during_bot_speech.py @@ -15,12 +15,6 @@ from typing import List from unittest.mock import AsyncMock, patch import pytest - -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.pipecat_engine_variable_extractor import ( - VariableExtractionManager, -) -from api.services.workflow.workflow import WorkflowGraph from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, @@ -41,7 +35,6 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregatorParams, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.tests import MockLLMService, MockTTSService from pipecat.tests.mock_transport import MockTransport from pipecat.transports.base_transport import TransportParams from pipecat.turns.user_mute import ( @@ -51,6 +44,13 @@ from pipecat.turns.user_mute import ( from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies from pipecat.utils.time import time_now_iso8601 +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_variable_extractor import ( + 
VariableExtractionManager, +) +from api.services.workflow.workflow import WorkflowGraph +from pipecat.tests import MockLLMService, MockTTSService + class BotSpeakingObserverProcessor(FrameProcessor): """Observer that records mute status when bot speaking events flow upstream. @@ -160,7 +160,7 @@ async def create_engine_for_mute_test( ) # Create context aggregator with user mute strategies - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() user_mute_strategies = [ MuteUntilFirstBotCompleteUserMuteStrategy(), @@ -243,7 +243,7 @@ class TestUserMutingDuringBotSpeech: ) = await create_engine_for_mute_test(simple_workflow, llm, tts_duration_ms=50) with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -334,7 +334,7 @@ class TestUserMutingDuringBotSpeech: ) = await create_engine_for_mute_test(simple_workflow, llm, tts_duration_ms=50) with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): @@ -430,7 +430,7 @@ class TestUserMutingDuringBotSpeech: ) = await create_engine_for_mute_test(simple_workflow, llm, tts_duration_ms=50) with patch( - "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + "api.db:db_client.get_organization_id_by_workflow_run_id", new_callable=AsyncMock, return_value=1, ): diff --git a/api/tests/test_voicemail_detector.py b/api/tests/test_voicemail_detector.py index 9ec34bc..0677c29 100644 --- a/api/tests/test_voicemail_detector.py +++ b/api/tests/test_voicemail_detector.py @@ -8,7 +8,6 @@ incoming speech as CONVERSATION or VOICEMAIL and how the main LLM responds. 
import asyncio import pytest - from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector from pipecat.frames.frames import ( EndTaskFrame, @@ -27,7 +26,6 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMUserAggregatorParams, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.tests import MockLLMService from pipecat.turns.user_start import ( TranscriptionUserTurnStartStrategy, VADUserTurnStartStrategy, @@ -38,6 +36,8 @@ from pipecat.turns.user_stop import ( from pipecat.turns.user_turn_strategies import UserTurnStrategies from pipecat.utils.time import time_now_iso8601 +from pipecat.tests import MockLLMService + class FrameInjector(FrameProcessor): """Simple processor that can inject frames into the pipeline.""" @@ -110,7 +110,7 @@ class TestVoicemailDetectorWithUserAggregator: user_turn_strategies=user_turn_strategies, ) - assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + assistant_params = LLMAssistantAggregatorParams() context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params, user_params=user_params diff --git a/pipecat b/pipecat index 5a3c997..a0e790b 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit 5a3c997405e4f16c4c061bbb06384ad25cd1c5e1 +Subproject commit a0e790b4e3b836425d14834f10e76392ce6fc4cd diff --git a/scripts/setup_pipecat.ps1 b/scripts/setup_pipecat.ps1 index e682413..f831319 100644 --- a/scripts/setup_pipecat.ps1 +++ b/scripts/setup_pipecat.ps1 @@ -13,12 +13,13 @@ Write-Host "Setting up pipecat as a git submodule..." Write-Host "Initializing git submodules..." git submodule update --init --recursive +# Install dograh API requirements first so pipecat's extras win on any +# shared transitive dependencies (matches api/Dockerfile and CI workflow). +Write-Host "Installing dograh API requirements..." 
+pip install -r api/requirements.txt + # Install pipecat in editable mode with all extras Write-Host "Installing pipecat dependencies..." pip install -e './pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb]' -# Install other requirements -Write-Host "Installing dograh API requirements..." -pip install -r api/requirements.txt - Write-Host "Setup complete! Pipecat is now available as a git submodule." diff --git a/scripts/setup_pipecat.sh b/scripts/setup_pipecat.sh index 8ae4e16..cf4cc27 100755 --- a/scripts/setup_pipecat.sh +++ b/scripts/setup_pipecat.sh @@ -14,12 +14,13 @@ echo "Setting up pipecat as a git submodule..." echo "Initializing git submodules..." git submodule update --init --recursive +# Install dograh API requirements first so pipecat's extras win on any +# shared transitive dependencies (matches api/Dockerfile and CI workflow). +echo "Installing dograh API requirements..." +pip install -r api/requirements.txt + # Install pipecat in editable mode with all extras echo "Installing pipecat dependencies..." pip install -e ./pipecat[cartesia,deepgram,openai,elevenlabs,groq,google,azure,sarvam,soundfile,silero,webrtc,speechmatics,openrouter,camb] -# Install other requirements -echo "Installing dograh API requirements..." -pip install -r api/requirements.txt - echo "Setup complete! Pipecat is now available as a git submodule." 
\ No newline at end of file diff --git a/sdk/python/src/dograh_sdk/_generated_models.py b/sdk/python/src/dograh_sdk/_generated_models.py index ed3cfec..8d17b5c 100644 --- a/sdk/python/src/dograh_sdk/_generated_models.py +++ b/sdk/python/src/dograh_sdk/_generated_models.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: -# filename: dograh-openapi-XXXXXX.json.TsrryEqEnE -# timestamp: 2026-05-02T11:32:55+00:00 +# filename: dograh-openapi-XXXXXX.json.YApLaGcbbM +# timestamp: 2026-05-04T09:31:31+00:00 from __future__ import annotations diff --git a/sdk/python/src/dograh_sdk/typed/trigger.py b/sdk/python/src/dograh_sdk/typed/trigger.py index e112144..3a6be4b 100644 --- a/sdk/python/src/dograh_sdk/typed/trigger.py +++ b/sdk/python/src/dograh_sdk/typed/trigger.py @@ -24,7 +24,12 @@ class Trigger(TypedNode): `/api/v1/public/agent/test/` — runs the latest draft, useful for verifying changes before publishing. Falls back to the published agent when no draft exists. Both require an API key in the - `X-API-Key` header. + `X-API-Key` header. Request body fields: • `phone_number` (string, + required) — destination to dial. • `initial_context` (object, + optional) — merged into the run's initial context. • + `telephony_configuration_id` (int, optional) — pick a specific telephony + configuration for the call. Must belong to the same organization as the + trigger. When omitted, the org's default outbound configuration is used. """ type: ClassVar[str] = 'trigger' diff --git a/sdk/typescript/src/typed/trigger.ts b/sdk/typescript/src/typed/trigger.ts index 119c4f7..2dbe27d 100644 --- a/sdk/typescript/src/typed/trigger.ts +++ b/sdk/typescript/src/typed/trigger.ts @@ -12,6 +12,10 @@ * • Production: `/api/v1/public/agent/` — runs the published agent. Use this from production systems. * • Test: `/api/v1/public/agent/test/` — runs the latest draft, useful for verifying changes before publishing. Falls back to the published agent when no draft exists. 
* Both require an API key in the `X-API-Key` header. + * Request body fields: + * • `phone_number` (string, required) — destination to dial. + * • `initial_context` (object, optional) — merged into the run's initial context. + * • `telephony_configuration_id` (int, optional) — pick a specific telephony configuration for the call. Must belong to the same organization as the trigger. When omitted, the org's default outbound configuration is used. */ export interface Trigger { type: "trigger";