feat: simplify pipecat engine execution

This commit is contained in:
Abhishek Kumar 2025-11-15 17:22:15 +05:30
parent 5e4aef346d
commit cc05f363ff
35 changed files with 545 additions and 1861 deletions

View file

@ -5,15 +5,15 @@ Revises: 982ec8e434be
Create Date: 2025-10-21 12:28:06.053318
"""
from typing import Sequence, Union
from alembic import op
from alembic_postgresql_enum import TableReference
# revision identifiers, used by Alembic.
revision: str = 'a57d25b75117'
down_revision: Union[str, None] = '982ec8e434be'
revision: str = "a57d25b75117"
down_revision: Union[str, None] = "982ec8e434be"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -26,12 +26,20 @@ def upgrade() -> None:
2. Migrates TWILIO_CONFIGURATION key to TELEPHONY_CONFIGURATION
3. Renames twilio_status_callbacks to telephony_status_callbacks in workflow_run logs
"""
# Add 'vonage' to the workflow_run_mode enum
op.sync_enum_values(
enum_schema="public",
enum_name="workflow_run_mode",
new_values=["twilio", "stasis", "webrtc", "smallwebrtc", "VOICE", "CHAT", "vonage"],
new_values=[
"twilio",
"stasis",
"webrtc",
"smallwebrtc",
"VOICE",
"CHAT",
"vonage",
],
affected_columns=[
TableReference(
table_schema="public", table_name="workflow_runs", column_name="mode"
@ -39,14 +47,14 @@ def upgrade() -> None:
],
enum_values_to_rename=[],
)
# Rename the key from TWILIO_CONFIGURATION to TELEPHONY_CONFIGURATION
op.execute("""
UPDATE organization_configurations
SET key = 'TELEPHONY_CONFIGURATION'
WHERE key = 'TWILIO_CONFIGURATION';
""")
# Rename twilio_status_callbacks to telephony_status_callbacks in workflow_run logs
op.execute("""
UPDATE workflow_runs
@ -57,15 +65,17 @@ def upgrade() -> None:
)
WHERE logs::jsonb ? 'twilio_status_callbacks';
""")
print("Migration complete: Added vonage to enum, renamed configuration key, and updated status callback keys")
print(
"Migration complete: Added vonage to enum, renamed configuration key, and updated status callback keys"
)
def downgrade() -> None:
"""
Revert configuration key names and enum.
"""
# Revert telephony_status_callbacks to twilio_status_callbacks in workflow_run logs
op.execute("""
UPDATE workflow_runs
@ -76,14 +86,14 @@ def downgrade() -> None:
)
WHERE logs::jsonb ? 'telephony_status_callbacks';
""")
# Revert key name
op.execute("""
UPDATE organization_configurations
SET key = 'TWILIO_CONFIGURATION'
WHERE key = 'TELEPHONY_CONFIGURATION';
""")
# Revert enum to previous state
op.sync_enum_values(
enum_schema="public",
@ -96,5 +106,5 @@ def downgrade() -> None:
],
enum_values_to_rename=[],
)
print("Downgrade complete: Reverted configuration key names and enum")
print("Downgrade complete: Reverted configuration key names and enum")

View file

@ -63,8 +63,12 @@ class OrganizationConfigurationKey(Enum):
DISPOSITION_CODE_MAPPING = "DISPOSITION_CODE_MAPPING"
DISPOSITION_MESSAGE_TEMPLATE = "DISPOSITION_MESSAGE_TEMPLATE"
CONCURRENT_CALL_LIMIT = "CONCURRENT_CALL_LIMIT"
TELEPHONY_CONFIGURATION = "TELEPHONY_CONFIGURATION" # Stores all providers + active one
TWILIO_CONFIGURATION = "TWILIO_CONFIGURATION" # Deprecated - for backward compatibility
TELEPHONY_CONFIGURATION = (
"TELEPHONY_CONFIGURATION" # Stores all providers + active one
)
TWILIO_CONFIGURATION = (
"TWILIO_CONFIGURATION" # Deprecated - for backward compatibility
)
class WorkflowStatus(Enum):

View file

@ -1,4 +1,4 @@
langfuse==3.4.0
langfuse==3.9.3
fastapi==0.116.2
asyncpg==0.30.0
alembic==1.16.5

View file

@ -1,9 +1,10 @@
from typing import Union
from fastapi import APIRouter, Depends, HTTPException
from api.db import db_client
from api.db.models import UserModel
from api.enums import OrganizationConfigurationKey
from typing import Union
from api.schemas.telephony_config import (
TelephonyConfigurationResponse,
TwilioConfigurationRequest,
@ -19,14 +20,13 @@ router = APIRouter(prefix="/organizations", tags=["organizations"])
# Provider configuration constants
PROVIDER_MASKED_FIELDS = {
"twilio": ["account_sid", "auth_token"],
"vonage": ["private_key", "api_key", "api_secret"]
"vonage": ["private_key", "api_key", "api_secret"],
}
# TODO: Make endpoints provider-agnostic
@router.get("/telephony-config", response_model=TelephonyConfigurationResponse)
async def get_telephony_configuration(
user: UserModel = Depends(get_user)
):
async def get_telephony_configuration(user: UserModel = Depends(get_user)):
"""Get telephony configuration for the user's organization with masked sensitive fields."""
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
@ -40,11 +40,13 @@ async def get_telephony_configuration(
return TelephonyConfigurationResponse()
stored_provider = config.value.get("provider", "twilio")
if stored_provider == "twilio":
account_sid = config.value.get("account_sid", "")
auth_token = config.value.get("auth_token", "")
from_numbers = config.value.get("from_numbers", []) if account_sid and auth_token else []
from_numbers = (
config.value.get("from_numbers", []) if account_sid and auth_token else []
)
return TelephonyConfigurationResponse(
twilio=TwilioConfigurationResponse(
@ -53,15 +55,19 @@ async def get_telephony_configuration(
auth_token=mask_key(auth_token) if auth_token else "",
from_numbers=from_numbers,
),
vonage=None
vonage=None,
)
elif stored_provider == "vonage":
application_id = config.value.get("application_id", "")
private_key = config.value.get("private_key", "")
api_key = config.value.get("api_key", "")
api_secret = config.value.get("api_secret", "")
from_numbers = config.value.get("from_numbers", []) if application_id and private_key else []
from_numbers = (
config.value.get("from_numbers", [])
if application_id and private_key
else []
)
return TelephonyConfigurationResponse(
twilio=None,
vonage=VonageConfigurationResponse(
@ -71,7 +77,7 @@ async def get_telephony_configuration(
api_key=mask_key(api_key) if api_key else None,
api_secret=mask_key(api_secret) if api_secret else None,
from_numbers=from_numbers,
)
),
)
else:
return TelephonyConfigurationResponse()
@ -79,8 +85,8 @@ async def get_telephony_configuration(
@router.post("/telephony-config")
async def save_telephony_configuration(
request: Union[TwilioConfigurationRequest, VonageConfigurationRequest],
user: UserModel = Depends(get_user)
request: Union[TwilioConfigurationRequest, VonageConfigurationRequest],
user: UserModel = Depends(get_user),
):
"""Save telephony configuration for the user's organization."""
if not user.selected_organization_id:
@ -105,12 +111,14 @@ async def save_telephony_configuration(
"provider": "vonage",
"application_id": request.application_id,
"private_key": request.private_key,
"api_key": getattr(request, 'api_key', None),
"api_secret": getattr(request, 'api_secret', None),
"api_key": getattr(request, "api_key", None),
"api_secret": getattr(request, "api_secret", None),
"from_numbers": request.from_numbers,
}
else:
raise HTTPException(status_code=400, detail=f"Unsupported provider: {request.provider}")
raise HTTPException(
status_code=400, detail=f"Unsupported provider: {request.provider}"
)
if existing_config and existing_config.value:
existing_provider = existing_config.value.get("provider")
@ -126,14 +134,16 @@ async def save_telephony_configuration(
return {"message": "Telephony configuration saved successfully"}
def preserve_masked_fields(request, existing_config, config_value):
def preserve_masked_fields(request, existing_config, config_value):
provider = request.provider
masked_fields = PROVIDER_MASKED_FIELDS.get(provider, [])
for field_name in masked_fields:
if hasattr(request, field_name):
field_value = getattr(request, field_name)
# Check if field has a value and is a masked version of the existing value
if field_value and is_mask_of(field_value, existing_config.value.get(field_name, "")):
if field_value and is_mask_of(
field_value, existing_config.value.get(field_name, "")
):
config_value[field_name] = existing_config.value[field_name]

View file

@ -1,19 +1,19 @@
"""
Generic telephony routes that work with any telephony provider.
"""
import json
import random
from datetime import UTC, datetime
from typing import Annotated, Optional
from typing import Optional
from fastapi import APIRouter, Depends, Form, Header, HTTPException, Request, WebSocket
from fastapi import APIRouter, Depends, Header, HTTPException, Request, WebSocket
from loguru import logger
from pydantic import BaseModel
from starlette.responses import HTMLResponse
from api.db import db_client
from api.db.models import UserModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
@ -32,6 +32,7 @@ class InitiateCallRequest(BaseModel):
class StatusCallbackRequest(BaseModel):
"""Generic status callback that can handle different providers"""
# Common fields
call_id: str
status: str
@ -39,10 +40,10 @@ class StatusCallbackRequest(BaseModel):
to_number: Optional[str] = None
direction: Optional[str] = None
duration: Optional[str] = None
# Provider-specific fields stored as extra
extra: dict = {}
@classmethod
def from_twilio(cls, data: dict):
"""Convert Twilio callback to generic format"""
@ -53,9 +54,9 @@ class StatusCallbackRequest(BaseModel):
to_number=data.get("To"),
direction=data.get("Direction"),
duration=data.get("CallDuration") or data.get("Duration"),
extra=data
extra=data,
)
@classmethod
def from_vonage(cls, data: dict):
"""Convert Vonage event to generic format"""
@ -63,14 +64,14 @@ class StatusCallbackRequest(BaseModel):
status_map = {
"started": "initiated",
"ringing": "ringing",
"answered": "answered",
"answered": "answered",
"complete": "completed",
"failed": "failed",
"busy": "busy",
"timeout": "no-answer",
"rejected": "busy"
"rejected": "busy",
}
return cls(
call_id=data.get("uuid", ""),
status=status_map.get(data.get("status", ""), data.get("status", "")),
@ -78,7 +79,7 @@ class StatusCallbackRequest(BaseModel):
to_number=data.get("to"),
direction=data.get("direction"),
duration=data.get("duration"),
extra=data
extra=data,
)
@ -87,32 +88,32 @@ async def initiate_call(
request: InitiateCallRequest, user: UserModel = Depends(get_user)
):
"""Initiate a call using the configured telephony provider."""
# Get the telephony provider for the organization
provider = await get_telephony_provider(user.selected_organization_id)
# Validate provider is configured
if not provider.validate_config():
raise HTTPException(
status_code=400,
detail="telephony_not_configured",
)
# Determine the workflow run mode based on provider type
workflow_run_mode = provider.PROVIDER_NAME
user_configuration = await db_client.get_user_configurations(user.id)
phone_number = request.phone_number or user_configuration.test_phone_number
if not phone_number:
raise HTTPException(
status_code=400,
detail="Phone number must be provided in request or set in user configuration"
status_code=400,
detail="Phone number must be provided in request or set in user configuration",
)
workflow_run_id = request.workflow_run_id
if not workflow_run_id:
workflow_run_name = f"WR-TEL-{random.randint(1000, 9999)}"
workflow_run = await db_client.create_workflow_run(
@ -130,12 +131,12 @@ async def initiate_call(
if not workflow_run:
raise HTTPException(status_code=400, detail="Workflow run not found")
workflow_run_name = workflow_run.name
# Construct webhook URL based on provider type
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
webhook_endpoint = provider.WEBHOOK_ENDPOINT
webhook_url = (
f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
f"?workflow_id={request.workflow_id}"
@ -143,35 +144,29 @@ async def initiate_call(
f"&workflow_run_id={workflow_run_id}"
f"&organization_id={user.selected_organization_id}"
)
# Initiate call via provider
result = await provider.initiate_call(
to_number=phone_number,
webhook_url=webhook_url,
workflow_run_id=workflow_run_id,
)
# Store provider type and any provider-specific metadata in workflow run context
gathered_context = {
"provider": provider.PROVIDER_NAME,
**(result.provider_metadata or {})
**(result.provider_metadata or {}),
}
await db_client.update_workflow_run(
run_id=workflow_run_id,
gathered_context=gathered_context
run_id=workflow_run_id, gathered_context=gathered_context
)
return {
"message": f"Call initiated successfully with run name {workflow_run_name}"
}
return {"message": f"Call initiated successfully with run name {workflow_run_name}"}
@router.post("/twiml", include_in_schema=False)
async def handle_twiml_webhook(
workflow_id: int,
user_id: int,
workflow_run_id: int,
organization_id: int
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
):
"""
Handle initial webhook from telephony provider.
@ -179,32 +174,32 @@ async def handle_twiml_webhook(
"""
provider = await get_telephony_provider(organization_id)
response_content = await provider.get_webhook_response(
workflow_id, user_id, workflow_run_id
)
return HTMLResponse(content=response_content, media_type="application/xml")
@router.get("/ncco", include_in_schema=False)
async def handle_ncco_webhook(
workflow_id: int,
user_id: int,
workflow_id: int,
user_id: int,
workflow_run_id: int,
organization_id: Optional[int] = None
organization_id: Optional[int] = None,
):
"""Handle NCCO (Nexmo Call Control Objects) webhook for Vonage.
Returns JSON response instead of XML like TwiML.
"""
provider = await get_telephony_provider(organization_id or user_id)
response_content = await provider.get_webhook_response(
workflow_id, user_id, workflow_run_id
)
return json.loads(response_content)
@ -218,36 +213,38 @@ async def websocket_endpoint(
try:
# Set the run context
set_current_run_id(workflow_run_id)
# Get workflow run to determine provider type
workflow_run = await db_client.get_workflow_run(workflow_run_id)
if not workflow_run:
logger.error(f"Workflow run {workflow_run_id} not found")
await websocket.close(code=4404, reason="Workflow run not found")
return
# Get workflow for organization info
workflow = await db_client.get_workflow(workflow_id)
if not workflow:
logger.error(f"Workflow {workflow_id} not found")
await websocket.close(code=4404, reason="Workflow not found")
return
# Extract provider type from workflow run context
provider_type = None
if workflow_run.gathered_context:
provider_type = workflow_run.gathered_context.get("provider")
if not provider_type:
logger.error(f"No provider type found in workflow run {workflow_run_id}")
await websocket.close(code=4400, reason="Provider type not found")
return
logger.info(f"WebSocket connected for {provider_type} provider, workflow_run {workflow_run_id}")
logger.info(
f"WebSocket connected for {provider_type} provider, workflow_run {workflow_run_id}"
)
# Get the telephony provider instance
provider = await get_telephony_provider(workflow.organization_id)
# Verify the provider matches what was stored
if provider.PROVIDER_NAME != provider_type:
logger.error(
@ -255,10 +252,12 @@ async def websocket_endpoint(
)
await websocket.close(code=4400, reason="Provider mismatch")
return
# Delegate to provider-specific handler
await provider.handle_websocket(websocket, workflow_id, user_id, workflow_run_id)
await provider.handle_websocket(
websocket, workflow_id, user_id, workflow_run_id
)
except Exception as e:
logger.error(f"Error in WebSocket connection: {e}")
await websocket.close(1011, "Internal server error")
@ -271,44 +270,46 @@ async def handle_twilio_status_callback(
x_webhook_signature: Optional[str] = Header(None),
):
"""Handle Twilio-specific status callbacks."""
# Parse form data
form_data = await request.form()
callback_data = dict(form_data)
logger.info(
f"[run {workflow_run_id}] Received status callback: {json.dumps(callback_data)}"
)
# Get workflow run to find organization
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(f"Workflow run {workflow_run_id} not found for status callback")
return {"status": "ignored", "reason": "workflow_run_not_found"}
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(f"Workflow {workflow_run.workflow_id} not found")
return {"status": "ignored", "reason": "workflow_not_found"}
provider = await get_telephony_provider(workflow.organization_id)
if x_webhook_signature:
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
full_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
is_valid = await provider.verify_webhook_signature(
full_url, callback_data, x_webhook_signature
)
if not is_valid:
logger.warning(f"Invalid webhook signature for workflow run {workflow_run_id}")
logger.warning(
f"Invalid webhook signature for workflow run {workflow_run_id}"
)
return {"status": "error", "reason": "invalid_signature"}
# Parse the callback data into generic format
parsed_data = provider.parse_status_callback(callback_data)
# Create StatusCallbackRequest from parsed data
status_update = StatusCallbackRequest(
call_id=parsed_data["call_id"],
@ -317,22 +318,20 @@ async def handle_twilio_status_callback(
to_number=parsed_data.get("to_number"),
direction=parsed_data.get("direction"),
duration=parsed_data.get("duration"),
extra=parsed_data.get("extra", {})
extra=parsed_data.get("extra", {}),
)
# Process the status update
await _process_status_update(workflow_run_id, status_update, workflow_run)
return {"status": "success"}
async def _process_status_update(
workflow_run_id: int,
status: StatusCallbackRequest,
workflow_run: any
workflow_run_id: int, status: StatusCallbackRequest, workflow_run: any
):
"""Process status updates from telephony providers."""
# Log the status callback
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
telephony_callback_log = {
@ -340,31 +339,29 @@ async def _process_status_update(
"timestamp": datetime.now(UTC).isoformat(),
"call_id": status.call_id,
"duration": status.duration,
**status.extra # Include provider-specific data
**status.extra, # Include provider-specific data
}
telephony_callback_logs.append(telephony_callback_log)
# Update workflow run logs
await db_client.update_workflow_run(
run_id=workflow_run_id,
logs={"telephony_status_callbacks": telephony_callback_logs},
)
# Handle call completion
if status.status == "completed":
logger.info(
f"[run {workflow_run_id}] Call completed with duration: {status.duration}s"
)
# Release concurrent slot if this was a campaign call
if workflow_run.campaign_id:
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
# Mark workflow run as completed
await db_client.update_workflow_run(
run_id=workflow_run_id, is_completed=True
)
await db_client.update_workflow_run(run_id=workflow_run_id, is_completed=True)
# Publish campaign event if applicable
if workflow_run.campaign_id:
publisher = await get_campaign_event_publisher()
@ -374,32 +371,40 @@ async def _process_status_update(
queued_run_id=workflow_run.queued_run_id,
call_duration=int(status.duration) if status.duration else 0,
)
elif status.status in ["failed", "busy", "no-answer", "canceled"]:
logger.warning(f"[run {workflow_run_id}] Call failed with status: {status.status}")
logger.warning(
f"[run {workflow_run_id}] Call failed with status: {status.status}"
)
# Release concurrent slot for terminal statuses if this was a campaign call
if workflow_run.campaign_id:
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
# Check if retry is needed for campaign calls (busy/no-answer)
if status.status in ["busy", "no-answer"] and workflow_run.campaign_id:
publisher = await get_campaign_event_publisher()
await publisher.publish_retry_needed(
workflow_run_id=workflow_run_id,
reason=status.status.replace("-", "_"), # Convert no-answer to no_answer
reason=status.status.replace(
"-", "_"
), # Convert no-answer to no_answer
campaign_id=workflow_run.campaign_id,
queued_run_id=workflow_run.queued_run_id,
)
# Mark workflow run as completed with failure tags
call_tags = workflow_run.gathered_context.get("call_tags", []) if workflow_run.gathered_context else []
call_tags = (
workflow_run.gathered_context.get("call_tags", [])
if workflow_run.gathered_context
else []
)
call_tags.extend(["not_connected", f"telephony_{status.status.lower()}"])
await db_client.update_workflow_run(
run_id=workflow_run_id,
is_completed=True,
gathered_context={"call_tags": call_tags}
gathered_context={"call_tags": call_tags},
)
@ -409,20 +414,20 @@ async def handle_vonage_events(
workflow_run_id: int,
):
"""Handle Vonage-specific event webhooks.
Vonage sends all call events to a single endpoint.
Events include: started, ringing, answered, complete, failed, etc.
"""
# Parse the event data
event_data = await request.json()
logger.info(f"[run {workflow_run_id}] Received Vonage event: {event_data}")
# Get workflow run for processing
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.error(f"[run {workflow_run_id}] Workflow run not found")
return {"status": "error", "message": "Workflow run not found"}
# For a completed call that includes cost info, capture it immediately
if event_data.get("status") == "completed":
# Vonage sometimes includes price info in the webhook
@ -436,27 +441,32 @@ async def handle_vonage_events(
if "rate" in event_data:
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
if "duration" in event_data:
cost_info["vonage_webhook_duration"] = int(event_data["duration"])
cost_info["vonage_webhook_duration"] = int(
event_data["duration"]
)
await db_client.update_workflow_run(
run_id=workflow_run_id,
cost_info=cost_info
run_id=workflow_run_id, cost_info=cost_info
)
logger.info(
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
)
logger.info(f"[run {workflow_run_id}] Captured Vonage cost info from webhook")
except Exception as e:
logger.error(f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}")
logger.error(
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
)
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.error(f"[run {workflow_run_id}] Workflow not found")
return {"status": "error", "message": "Workflow not found"}
provider = await get_telephony_provider(workflow.organization_id)
# Parse the event data into generic format
parsed_data = provider.parse_status_callback(event_data)
# Create StatusCallbackRequest from parsed data
status_update = StatusCallbackRequest(
call_id=parsed_data["call_id"],
@ -465,11 +475,11 @@ async def handle_vonage_events(
to_number=parsed_data.get("to_number"),
direction=parsed_data.get("direction"),
duration=parsed_data.get("duration"),
extra=parsed_data.get("extra", {})
extra=parsed_data.get("extra", {}),
)
# Process the status update
await _process_status_update(workflow_run_id, status_update, workflow_run)
# Return 204 No Content as expected by Vonage
return {"status": "ok"}
return {"status": "ok"}

View file

@ -124,7 +124,9 @@ class SignalingManager:
)
else:
# Create new connection using correct SmallWebRTC API
pc = SmallWebRTCConnection(ice_servers=ice_servers, connection_timeout_secs=60)
pc = SmallWebRTCConnection(
ice_servers=ice_servers, connection_timeout_secs=60
)
# Set the pc_id before initialization so it's available in get_answer()
pc._pc_id = pc_id

View file

@ -7,10 +7,10 @@ from loguru import logger
from api.db import db_client
from api.db.models import QueuedRunModel, WorkflowRunModel
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
from api.enums import OrganizationConfigurationKey
from api.services.campaign.rate_limiter import rate_limiter
from api.services.telephony.factory import get_telephony_provider
from api.services.telephony.base import TelephonyProvider
from api.services.telephony.factory import get_telephony_provider
from api.utils.tunnel import TunnelURLProvider
@ -238,7 +238,7 @@ class CampaignCallDispatcher:
f"&campaign_id={campaign.id}"
f"&organization_id={campaign.organization_id}"
)
call_result = await provider.initiate_call(
to_number=phone_number,
webhook_url=webhook_url,
@ -255,7 +255,9 @@ class CampaignCallDispatcher:
)
# Update workflow run as failed
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
telephony_callback_logs = workflow_run.logs.get(
"telephony_status_callbacks", []
)
telephony_callback_log = {
"status": "failed",
"timestamp": datetime.now(UTC).isoformat(),

View file

@ -24,6 +24,9 @@ from api.services.workflow.dto import ReactFlowDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.filters.stt_mute_filter import (
STTMuteConfig,
STTMuteFilter,
@ -83,7 +86,8 @@ class LoopTalkPipelineBuilder:
audio_buffer, audio_synchronizer, transcript, context = (
create_pipeline_components(audio_config)
)
context_aggregator = llm.create_context_aggregator(context)
context_aggregator = LLMContextAggregatorPair(context)
# Get workflow graph
workflow_graph = WorkflowGraph(
@ -113,7 +117,6 @@ class LoopTalkPipelineBuilder:
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
max_call_duration_seconds=300,
max_duration_end_task_callback=engine.create_max_duration_callback(),
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
generation_started_callback=engine.create_generation_started_callback(),
)

View file

@ -272,14 +272,6 @@ class LoopTalkTestOrchestrator:
await task.cancel()
# Connect the context aggregator events to engine
@assistant_context_aggregator.event_handler("on_push_aggregation")
async def on_assistant_aggregator_push_context(_aggregator):
logger.debug(
"Assistant aggregator push context flushing pending transitions"
)
await engine.flush_pending_transitions()
# Register custom audio and transcript handlers for LoopTalk
await self._register_looptalk_handlers(
audio_synchronizer, transcript, test_session_id, role

View file

@ -1,69 +0,0 @@
"""Engine Pre-Aggregator Processor
This processor sits before the user context aggregator in the pipeline and handles
engine-specific callbacks for frames that need to be processed before aggregation.
This ensures the engine can update context before the aggregator generates LLM frames.
"""
from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.pipecat.exceptions import VoicemailDetectedException
from pipecat.frames.frames import (
Frame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class EnginePreAggregatorProcessor(FrameProcessor):
"""
Processor that handles engine callbacks before user context aggregation.
This processor is positioned before the user context aggregator to ensure
the engine can update LLM context before aggregation occurs.
"""
def __init__(
self,
user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self._user_started_speaking_callback = user_started_speaking_callback
self._user_stopped_speaking_callback = user_stopped_speaking_callback
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# Handle frames that need engine processing before aggregation
if isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking()
elif isinstance(frame, UserStoppedSpeakingFrame):
try:
await self._handle_user_stopped_speaking()
except VoicemailDetectedException:
# We have detected voicemail, lets not
# forward the UserStoppedSpeakingFrame, so that
# we don't issue an llm call from user context
# aggregator
logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
return
# Always push the frame downstream
await self.push_frame(frame, direction)
async def _handle_user_started_speaking(self):
"""Handle UserStartedSpeakingFrame before aggregation."""
if self._user_started_speaking_callback:
# logger.debug("Engine pre-aggregator: User started speaking")
await self._user_started_speaking_callback()
async def _handle_user_stopped_speaking(self):
"""Handle UserStoppedSpeakingFrame before aggregation."""
if self._user_stopped_speaking_callback:
# logger.debug("Engine pre-aggregator: User stopped speaking")
await self._user_stopped_speaking_callback()

View file

@ -9,7 +9,7 @@ from api.constants import (
from api.services.pipecat.audio_config import AudioConfig
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
from pipecat.processors.transcript_processor import TranscriptProcessor
@ -39,7 +39,7 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine
assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback()
)
context = OpenAILLMContext()
context = LLMContext()
return audio_buffer, audio_synchronizer, transcript, context
@ -58,7 +58,6 @@ def build_pipeline(
stt_mute_filter,
pipeline_metrics_aggregator,
user_idle_disconnect,
engine_pre_aggregator_processor=None,
):
"""Build the main pipeline with all components"""
# Register processors with synchronizer for merged audio
@ -69,16 +68,12 @@ def build_pipeline(
processors = [
transport.input(), # Transport user input
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
stt_mute_filter,
stt, # STT can now have audio_passthrough=False
stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted
user_idle_disconnect,
transcript.user(),
]
# Insert engine pre-aggregator processor if provided (before user aggregator)
if engine_pre_aggregator_processor:
processors.append(engine_pre_aggregator_processor)
processors.extend(
[
user_context_aggregator,

View file

@ -7,7 +7,6 @@ from pipecat.frames.frames import (
Frame,
HeartbeatFrame,
LLMFullResponseStartFrame,
LLMGeneratedTextFrame,
LLMTextFrame,
StartFrame,
TTSSpeakFrame,
@ -26,7 +25,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
self,
max_call_duration_seconds: int = 300,
max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
):
@ -34,7 +32,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
self._start_time = None
self._max_call_duration_seconds = max_call_duration_seconds
self._max_duration_end_task_callback = max_duration_end_task_callback
self._llm_generated_text_callback = llm_generated_text_callback
self._generation_started_callback = generation_started_callback
self._llm_text_frame_callback = llm_text_frame_callback
self._end_task_frame_pushed = False
@ -46,8 +43,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
await self._start(frame)
elif isinstance(frame, HeartbeatFrame):
await self._check_call_duration()
elif isinstance(frame, LLMGeneratedTextFrame):
await self._generated_text_frame(frame)
elif isinstance(frame, LLMFullResponseStartFrame):
await self._generation_started()
elif (
@ -74,11 +69,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
"Max call duration exceeded. Skipping EndTaskFrame since already sent"
)
async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
"""Handle LLMGeneratedTextFrame."""
if self._llm_generated_text_callback is not None:
await self._llm_generated_text_callback()
async def _generation_started(self):
if self._generation_started_callback:
await self._generation_started_callback()

View file

@ -7,9 +7,6 @@ from api.db import db_client
from api.db.models import WorkflowModel
from api.enums import WorkflowRunMode
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
from api.services.pipecat.engine_pre_aggregator_processor import (
EnginePreAggregatorProcessor,
)
from api.services.pipecat.event_handlers import (
register_audio_data_handler,
register_task_event_handler,
@ -43,6 +40,9 @@ from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.pipeline.runner import PipelineRunner
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.filters.stt_mute_filter import (
STTMuteConfig,
STTMuteFilter,
@ -119,7 +119,7 @@ async def run_pipeline_vonage(
user_id: int,
):
"""Run pipeline for Vonage WebSocket connections.
Vonage uses raw PCM audio over WebSocket instead of base64-encoded μ-law.
The audio is transmitted as binary frames at 16kHz by default.
"""
@ -137,7 +137,9 @@ async def run_pipeline_vonage(
if "vad_configuration" in workflow.workflow_configurations:
vad_config = workflow.workflow_configurations["vad_configuration"]
if "ambient_noise_configuration" in workflow.workflow_configurations:
ambient_noise_config = workflow.workflow_configurations["ambient_noise_configuration"]
ambient_noise_config = workflow.workflow_configurations[
"ambient_noise_configuration"
]
try:
# Setup audio config for Vonage using the centralized config
@ -355,21 +357,14 @@ async def _run_pipeline(
expect_stripped_words=True,
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
)
context_aggregator = llm.create_context_aggregator(
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
# Create engine pre-aggregator processor for speaking events
engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
user_started_speaking_callback=engine.create_user_started_speaking_callback(),
user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
)
# Create usage metrics aggregator with engine's callback
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
max_call_duration_seconds=max_call_duration_seconds,
max_duration_end_task_callback=engine.create_max_duration_callback(),
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
generation_started_callback=engine.create_generation_started_callback(),
llm_text_frame_callback=engine.handle_llm_text_frame,
# Note: speaking event callbacks are now handled by pre-aggregator processor
@ -396,11 +391,6 @@ async def _run_pipeline(
user_context_aggregator = context_aggregator.user()
assistant_context_aggregator = context_aggregator.assistant()
@assistant_context_aggregator.event_handler("on_push_aggregation")
async def on_assistant_aggregator_push_context(_aggregator):
logger.debug("Assistant aggregator push context flushing pending transitions")
await engine.flush_pending_transitions(source="context_push")
# Build the pipeline with the STT mute filter and context controller
pipeline = build_pipeline(
transport,
@ -416,7 +406,6 @@ async def _run_pipeline(
stt_mute_filter,
pipeline_metrics_aggregator,
user_idle_disconnect,
engine_pre_aggregator_processor=engine_pre_aggregator_processor,
)
# Create pipeline task with audio configuration

View file

@ -165,14 +165,15 @@ async def create_vonage_transport(
# Use the factory to load config from database
from api.services.telephony.factory import load_telephony_config
config = await load_telephony_config(organization_id)
if config.get("provider") != "vonage":
raise ValueError(f"Expected Vonage provider, got {config.get('provider')}")
application_id = config.get("application_id")
private_key = config.get("private_key")
if not application_id or not private_key:
raise ValueError(
f"Incomplete Vonage configuration for organization {organization_id}"
@ -186,8 +187,8 @@ async def create_vonage_transport(
private_key=private_key,
params=VonageFrameSerializer.InputParams(
vonage_sample_rate=audio_config.transport_in_sample_rate,
sample_rate=audio_config.pipeline_sample_rate
)
sample_rate=audio_config.pipeline_sample_rate,
),
)
# Important: Vonage uses binary WebSocket mode, not text

View file

@ -3,6 +3,7 @@ Base telephony provider interface for abstracting telephony services.
This allows easy switching between different providers (Twilio, Vonage, etc.)
while keeping business logic decoupled from specific implementations.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional
@ -14,10 +15,15 @@ if TYPE_CHECKING:
@dataclass
class CallInitiationResult:
"""Standardized response from initiate_call across all providers."""
call_id: str # Provider's call identifier (SID for Twilio, UUID for Vonage)
status: str # Initial status (e.g., "queued", "initiated", "started")
provider_metadata: Dict[str, Any] = field(default_factory=dict) # Data that needs to be persisted
raw_response: Dict[str, Any] = field(default_factory=dict) # Full provider response for debugging
call_id: str # Provider's call identifier (SID for Twilio, UUID for Vonage)
status: str # Initial status (e.g., "queued", "initiated", "started")
provider_metadata: Dict[str, Any] = field(
default_factory=dict
) # Data that needs to be persisted
raw_response: Dict[str, Any] = field(
default_factory=dict
) # Full provider response for debugging
class TelephonyProvider(ABC):
@ -25,6 +31,7 @@ class TelephonyProvider(ABC):
Abstract base class for telephony providers.
All telephony providers must implement these core methods.
"""
PROVIDER_NAME = None
WEBHOOK_ENDPOINT = None
@ -38,13 +45,13 @@ class TelephonyProvider(ABC):
) -> CallInitiationResult:
"""
Initiate an outbound call.
Args:
to_number: The destination phone number
webhook_url: The URL to receive call events
workflow_run_id: Optional workflow run ID for tracking
**kwargs: Provider-specific additional parameters
Returns:
CallInitiationResult with standardized call details
"""
@ -54,10 +61,10 @@ class TelephonyProvider(ABC):
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
"""
Get the current status of a call.
Args:
call_id: The provider-specific call identifier
Returns:
Dict containing call status information
"""
@ -67,7 +74,7 @@ class TelephonyProvider(ABC):
async def get_available_phone_numbers(self) -> List[str]:
"""
Get list of available phone numbers for this provider.
Returns:
List of phone numbers that can be used for outbound calls
"""
@ -77,7 +84,7 @@ class TelephonyProvider(ABC):
def validate_config(self) -> bool:
"""
Validate that the provider is properly configured.
Returns:
True if configuration is valid, False otherwise
"""
@ -89,12 +96,12 @@ class TelephonyProvider(ABC):
) -> bool:
"""
Verify webhook signature for security.
Args:
url: The webhook URL
params: The webhook parameters
signature: The signature to verify
Returns:
True if signature is valid, False otherwise
"""
@ -106,12 +113,12 @@ class TelephonyProvider(ABC):
) -> str:
"""
Generate the initial webhook response for starting a call session.
Args:
workflow_id: The workflow ID
user_id: The user ID
workflow_run_id: The workflow run ID
Returns:
Provider-specific response (e.g., TwiML for Twilio)
"""
@ -121,10 +128,10 @@ class TelephonyProvider(ABC):
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
"""
Get cost information for a completed call.
Args:
call_id: Provider-specific call identifier (SID for Twilio, UUID for Vonage)
Returns:
Dict containing:
- cost_usd: The cost in USD as float
@ -138,10 +145,10 @@ class TelephonyProvider(ABC):
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Parse provider-specific status callback data into generic format.
Args:
data: Raw callback data from the provider
Returns:
Dict with standardized fields:
- call_id: Provider's call identifier
@ -163,14 +170,14 @@ class TelephonyProvider(ABC):
) -> None:
"""
Handle provider-specific WebSocket connection for real-time call audio.
This method encapsulates all provider-specific WebSocket handshake and
message routing logic, keeping the main websocket endpoint clean.
Args:
websocket: The WebSocket connection
workflow_id: The workflow ID
user_id: The user ID
workflow_run_id: The workflow run ID
"""
pass
pass

View file

@ -3,8 +3,8 @@ Factory for creating telephony providers.
Handles configuration loading from environment (OSS) or database (SaaS).
The providers themselves don't know or care where config comes from.
"""
import os
from typing import Any, Dict, Optional
from typing import Any, Dict
from loguru import logger
@ -18,36 +18,36 @@ from api.services.telephony.providers.vonage_provider import VonageProvider
async def load_telephony_config(organization_id: int) -> Dict[str, Any]:
"""
Load telephony configuration from database.
Args:
organization_id: Organization ID for database config
Returns:
Configuration dictionary with provider type and credentials
Raises:
ValueError: If no configuration found for the organization
"""
if not organization_id:
raise ValueError("Organization ID is required to load telephony configuration")
logger.debug(f"Loading telephony config from database for org {organization_id}")
config = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.TELEPHONY_CONFIGURATION.value,
)
if config and config.value:
# Simple single-provider format
provider = config.value.get("provider", "twilio")
if provider == "twilio":
return {
"provider": "twilio",
"account_sid": config.value.get("account_sid"),
"auth_token": config.value.get("auth_token"),
"from_numbers": config.value.get("from_numbers", [])
"from_numbers": config.value.get("from_numbers", []),
}
elif provider == "vonage":
return {
@ -56,41 +56,41 @@ async def load_telephony_config(organization_id: int) -> Dict[str, Any]:
"private_key": config.value.get("private_key"),
"api_key": config.value.get("api_key"),
"api_secret": config.value.get("api_secret"),
"from_numbers": config.value.get("from_numbers", [])
"from_numbers": config.value.get("from_numbers", []),
}
else:
raise ValueError(f"Unknown provider in config: {provider}")
raise ValueError(f"No telephony configuration found for organization {organization_id}")
raise ValueError(
f"No telephony configuration found for organization {organization_id}"
)
async def get_telephony_provider(
organization_id: int
) -> TelephonyProvider:
async def get_telephony_provider(organization_id: int) -> TelephonyProvider:
"""
Factory function to create telephony providers.
Args:
organization_id: Organization ID (required)
Returns:
Configured telephony provider instance
Raises:
ValueError: If provider type is unknown or configuration is invalid
"""
# Load configuration
config = await load_telephony_config(organization_id)
provider_type = config.get("provider", "twilio")
logger.info(f"Creating {provider_type} telephony provider")
# Create provider instance with configuration
if provider_type == "twilio":
return TwilioProvider(config)
elif provider_type == "vonage":
return VonageProvider(config)
else:
raise ValueError(f"Unknown telephony provider: {provider_type}")

View file

@ -1 +1 @@
# Telephony provider implementations
# Telephony provider implementations

View file

@ -1,6 +1,7 @@
"""
Twilio implementation of the TelephonyProvider interface.
"""
import json
import random
from typing import TYPE_CHECKING, Any, Dict, List, Optional
@ -9,9 +10,9 @@ import aiohttp
from loguru import logger
from twilio.request_validator import RequestValidator
from api.enums import WorkflowRunMode
from api.services.telephony.base import CallInitiationResult, TelephonyProvider
from api.utils.tunnel import TunnelURLProvider
from api.enums import WorkflowRunMode
if TYPE_CHECKING:
from fastapi import WebSocket
@ -22,14 +23,14 @@ class TwilioProvider(TelephonyProvider):
Twilio implementation of TelephonyProvider.
Accepts configuration and works the same regardless of OSS/SaaS mode.
"""
PROVIDER_NAME = WorkflowRunMode.TWILIO.value
WEBHOOK_ENDPOINT = "twiml"
def __init__(self, config: Dict[str, Any]):
"""
Initialize TwilioProvider with configuration.
Args:
config: Dictionary containing:
- account_sid: Twilio Account SID
@ -39,11 +40,11 @@ class TwilioProvider(TelephonyProvider):
self.account_sid = config.get("account_sid")
self.auth_token = config.get("auth_token")
self.from_numbers = config.get("from_numbers", [])
# Handle both single number (string) and multiple numbers (list)
if isinstance(self.from_numbers, str):
self.from_numbers = [self.from_numbers]
self.base_url = f"https://api.twilio.com/2010-04-01/Accounts/{self.account_sid}"
async def initiate_call(
@ -58,32 +59,35 @@ class TwilioProvider(TelephonyProvider):
"""
if not self.validate_config():
raise ValueError("Twilio provider not properly configured")
endpoint = f"{self.base_url}/Calls.json"
# Select a random phone number
from_number = random.choice(self.from_numbers)
logger.info(f"Selected phone number {from_number} for outbound call")
# Prepare call data
data = {
"To": to_number,
"From": from_number,
"Url": webhook_url
}
data = {"To": to_number, "From": from_number, "Url": webhook_url}
# Add status callback if workflow_run_id provided
if workflow_run_id:
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
data.update({
"StatusCallback": callback_url,
"StatusCallbackEvent": ["initiated", "ringing", "answered", "completed"],
"StatusCallbackMethod": "POST"
})
data.update(
{
"StatusCallback": callback_url,
"StatusCallbackEvent": [
"initiated",
"ringing",
"answered",
"completed",
],
"StatusCallbackMethod": "POST",
}
)
data.update(kwargs)
# Make the API request
async with aiohttp.ClientSession() as session:
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
@ -91,14 +95,14 @@ class TwilioProvider(TelephonyProvider):
if response.status != 201:
error_data = await response.json()
raise Exception(f"Failed to initiate call: {error_data}")
response_data = await response.json()
return CallInitiationResult(
call_id=response_data["sid"],
status=response_data.get("status", "queued"),
provider_metadata={}, # Twilio doesn't need to persist extra data
raw_response=response_data
raw_response=response_data,
)
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
@ -107,16 +111,16 @@ class TwilioProvider(TelephonyProvider):
"""
if not self.validate_config():
raise ValueError("Twilio provider not properly configured")
endpoint = f"{self.base_url}/Calls/{call_id}.json"
async with aiohttp.ClientSession() as session:
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
async with session.get(endpoint, auth=auth) as response:
if response.status != 200:
error_data = await response.json()
raise Exception(f"Failed to get call status: {error_data}")
return await response.json()
async def get_available_phone_numbers(self) -> List[str]:
@ -129,11 +133,7 @@ class TwilioProvider(TelephonyProvider):
"""
Validate Twilio configuration.
"""
return bool(
self.account_sid and
self.auth_token and
self.from_numbers
)
return bool(self.account_sid and self.auth_token and self.from_numbers)
async def verify_webhook_signature(
self, url: str, params: Dict[str, Any], signature: str
@ -144,7 +144,7 @@ class TwilioProvider(TelephonyProvider):
if not self.auth_token:
logger.error("No auth token available for webhook signature verification")
return False
validator = RequestValidator(self.auth_token)
return validator.validate(url, params, signature)
@ -155,7 +155,7 @@ class TwilioProvider(TelephonyProvider):
Generate TwiML response for starting a call session.
"""
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
twiml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<Response>
<Connect>
@ -168,15 +168,15 @@ class TwilioProvider(TelephonyProvider):
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
"""
Get cost information for a completed Twilio call.
Args:
call_id: The Twilio Call SID
Returns:
Dict containing cost information
"""
endpoint = f"{self.base_url}/Calls/{call_id}.json"
try:
async with aiohttp.ClientSession() as session:
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
@ -188,34 +188,29 @@ class TwilioProvider(TelephonyProvider):
"cost_usd": 0.0,
"duration": 0,
"status": "error",
"error": str(error_data)
"error": str(error_data),
}
call_data = await response.json()
# Twilio returns price as a negative string (e.g., "-0.0085")
price_str = call_data.get("price", "0")
cost_usd = abs(float(price_str)) if price_str else 0.0
# Duration is in seconds as a string
duration = int(call_data.get("duration", "0"))
return {
"cost_usd": cost_usd,
"duration": duration,
"status": call_data.get("status", "unknown"),
"price_unit": call_data.get("price_unit", "USD"),
"raw_response": call_data
"raw_response": call_data,
}
except Exception as e:
logger.error(f"Exception fetching Twilio call cost: {e}")
return {
"cost_usd": 0.0,
"duration": 0,
"status": "error",
"error": str(e)
}
return {"cost_usd": 0.0, "duration": 0, "status": "error", "error": str(e)}
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -228,7 +223,7 @@ class TwilioProvider(TelephonyProvider):
"to_number": data.get("To"),
"direction": data.get("Direction"),
"duration": data.get("CallDuration") or data.get("Duration"),
"extra": data # Include all original data
"extra": data, # Include all original data
}
async def handle_websocket(
@ -240,36 +235,38 @@ class TwilioProvider(TelephonyProvider):
) -> None:
"""
Handle Twilio-specific WebSocket connection.
Twilio sends:
1. "connected" event first
2. "start" event with streamSid and callSid
3. Then audio messages
"""
from api.services.pipecat.run_pipeline import run_pipeline_twilio
try:
# Wait for "connected" event
first_msg = await websocket.receive_text()
msg = json.loads(first_msg)
if msg.get("event") != "connected":
logger.error(f"Expected 'connected' event, got: {msg.get('event')}")
await websocket.close(code=4400, reason="Expected connected event")
return
logger.debug(f"Twilio WebSocket connected for workflow_run {workflow_run_id}")
logger.debug(
f"Twilio WebSocket connected for workflow_run {workflow_run_id}"
)
# Wait for "start" event with stream details
start_msg = await websocket.receive_text()
logger.debug(f"Received start message: {start_msg}")
start_msg = json.loads(start_msg)
if start_msg.get("event") != "start":
logger.error("Expected 'start' event second")
await websocket.close(code=4400, reason="Expected start event")
return
# Extract Twilio-specific identifiers
try:
stream_sid = start_msg["start"]["streamSid"]
@ -278,12 +275,12 @@ class TwilioProvider(TelephonyProvider):
logger.error("Missing streamSid or callSid in start message")
await websocket.close(code=4400, reason="Missing stream identifiers")
return
# Run the Twilio pipeline
await run_pipeline_twilio(
websocket, stream_sid, call_sid, workflow_id, workflow_run_id, user_id
)
except Exception as e:
logger.error(f"Error in Twilio WebSocket handler: {e}")
raise
raise

View file

@ -1,6 +1,7 @@
"""
Vonage (Nexmo) implementation of the TelephonyProvider interface.
"""
import json
import random
import time
@ -10,9 +11,9 @@ import aiohttp
import jwt
from loguru import logger
from api.enums import WorkflowRunMode
from api.services.telephony.base import CallInitiationResult, TelephonyProvider
from api.utils.tunnel import TunnelURLProvider
from api.enums import WorkflowRunMode
if TYPE_CHECKING:
from fastapi import WebSocket
@ -23,14 +24,14 @@ class VonageProvider(TelephonyProvider):
Vonage implementation of TelephonyProvider.
Uses JWT authentication and NCCO for call control.
"""
PROVIDER_NAME = WorkflowRunMode.VONAGE.value
WEBHOOK_ENDPOINT = "ncco"
def __init__(self, config: Dict[str, Any]):
"""
Initialize VonageProvider with configuration.
Args:
config: Dictionary containing:
- api_key: Vonage API Key
@ -44,25 +45,27 @@ class VonageProvider(TelephonyProvider):
self.application_id = config.get("application_id")
self.private_key = config.get("private_key")
self.from_numbers = config.get("from_numbers", [])
# Handle both single number (string) and multiple numbers (list)
if isinstance(self.from_numbers, str):
self.from_numbers = [self.from_numbers]
self.base_url = "https://api.nexmo.com"
def _generate_jwt(self) -> str:
"""Generate JWT token for Vonage API authentication."""
if not self.application_id or not self.private_key:
raise ValueError("Application ID and private key required for JWT generation")
raise ValueError(
"Application ID and private key required for JWT generation"
)
claims = {
"application_id": self.application_id,
"iat": int(time.time()),
"exp": int(time.time()) + 3600,
"jti": str(time.time())
"jti": str(time.time()),
}
return jwt.encode(claims, self.private_key, algorithm="RS256")
async def initiate_call(
@ -77,68 +80,57 @@ class VonageProvider(TelephonyProvider):
"""
if not self.validate_config():
raise ValueError("Vonage provider not properly configured")
endpoint = f"{self.base_url}/v1/calls"
# Select a random phone number
from_number = random.choice(self.from_numbers)
# Remove '+' prefix for Vonage
from_number = from_number.replace("+", "")
to_number = to_number.replace("+", "")
logger.info(f"Selected phone number {from_number} for outbound call")
# Prepare call data
data = {
"to": [{
"type": "phone",
"number": to_number
}],
"from": {
"type": "phone",
"number": from_number
},
"to": [{"type": "phone", "number": to_number}],
"from": {"type": "phone", "number": from_number},
"answer_url": [webhook_url],
"answer_method": "GET"
"answer_method": "GET",
}
# Add event webhook if workflow_run_id provided
if workflow_run_id:
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
event_url = f"https://{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}"
data.update({
"event_url": [event_url],
"event_method": "POST"
})
data.update({"event_url": [event_url], "event_method": "POST"})
data.update(kwargs)
# Generate JWT token
token = self._generate_jwt()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
# Make the API request
async with aiohttp.ClientSession() as session:
async with session.post(
endpoint,
json=data,
headers=headers
) as response:
async with session.post(endpoint, json=data, headers=headers) as response:
response_data = await response.json()
if response.status != 201:
raise Exception(f"Failed to initiate call: {response_data}")
return CallInitiationResult(
call_id=response_data["uuid"],
status=response_data.get("status", "started"),
provider_metadata={
"call_uuid": response_data["uuid"] # Vonage needs UUID persisted for WebSocket
"call_uuid": response_data[
"uuid"
] # Vonage needs UUID persisted for WebSocket
},
raw_response=response_data
raw_response=response_data,
)
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
@ -147,21 +139,19 @@ class VonageProvider(TelephonyProvider):
"""
if not self.validate_config():
raise ValueError("Vonage provider not properly configured")
endpoint = f"{self.base_url}/v1/calls/{call_id}"
# Generate JWT token
token = self._generate_jwt()
headers = {
"Authorization": f"Bearer {token}"
}
headers = {"Authorization": f"Bearer {token}"}
async with aiohttp.ClientSession() as session:
async with session.get(endpoint, headers=headers) as response:
if response.status != 200:
error_data = await response.json()
raise Exception(f"Failed to get call status: {error_data}")
return await response.json()
async def get_available_phone_numbers(self) -> List[str]:
@ -174,11 +164,7 @@ class VonageProvider(TelephonyProvider):
"""
Validate Vonage configuration.
"""
return bool(
self.application_id and
self.private_key and
self.from_numbers
)
return bool(self.application_id and self.private_key and self.from_numbers)
async def verify_webhook_signature(
self, url: str, params: Dict[str, Any], signature: str
@ -190,14 +176,14 @@ class VonageProvider(TelephonyProvider):
if not self.api_secret:
logger.error("No API secret available for webhook signature verification")
return False
try:
# Vonage sends JWT in Authorization header. Verify the JWT signature
decoded = jwt.decode(
signature,
self.api_secret,
signature,
self.api_secret,
algorithms=["HS256"],
options={"verify_signature": True}
options={"verify_signature": True},
)
return True
except jwt.InvalidTokenError:
@ -211,43 +197,42 @@ class VonageProvider(TelephonyProvider):
NCCO (Nexmo Call Control Objects) is JSON-based, unlike TwiML which is XML.
"""
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
# NCCO for WebSocket connection
ncco = [
{
"action": "connect",
"endpoint": [{
"type": "websocket",
"uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
"content-type": "audio/l16;rate=16000", # 16kHz Linear PCM
"headers": {}
}]
"endpoint": [
{
"type": "websocket",
"uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
"content-type": "audio/l16;rate=16000", # 16kHz Linear PCM
"headers": {},
}
],
}
]
return json.dumps(ncco)
def _get_auth_headers(self) -> Dict[str, str]:
"""Generate authorization headers for Vonage API."""
token = self._generate_jwt()
return {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
"""
Get cost information for a completed Vonage call.
Args:
call_id: The Vonage Call UUID
Returns:
Dict containing cost information
"""
headers = self._get_auth_headers()
endpoint = f"https://api.nexmo.com/v1/calls/{call_id}"
try:
async with aiohttp.ClientSession() as session:
async with session.get(endpoint, headers=headers) as response:
@ -258,39 +243,34 @@ class VonageProvider(TelephonyProvider):
"cost_usd": 0.0,
"duration": 0,
"status": "error",
"error": str(error_data)
"error": str(error_data),
}
call_data = await response.json()
# Vonage returns price and rate
# Price is the total cost, rate is the per-minute rate
price = float(call_data.get("price", 0))
cost_usd = price # Vonage returns positive values
# Duration is in seconds
duration = int(call_data.get("duration", 0))
# Get the call status
status = call_data.get("status", "unknown")
return {
"cost_usd": cost_usd,
"duration": duration,
"status": status,
"price_unit": "USD", # Vonage uses USD by default
"rate": call_data.get("rate", 0), # Per-minute rate
"raw_response": call_data
"raw_response": call_data,
}
except Exception as e:
logger.error(f"Exception fetching Vonage call cost: {e}")
return {
"cost_usd": 0.0,
"duration": 0,
"status": "error",
"error": str(e)
}
return {"cost_usd": 0.0, "duration": 0, "status": "error", "error": str(e)}
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -300,14 +280,14 @@ class VonageProvider(TelephonyProvider):
status_map = {
"started": "initiated",
"ringing": "ringing",
"answered": "answered",
"answered": "answered",
"complete": "completed",
"failed": "failed",
"busy": "busy",
"timeout": "no-answer",
"rejected": "busy"
"rejected": "busy",
}
return {
"call_id": data.get("uuid", ""),
"status": status_map.get(data.get("status", ""), data.get("status", "")),
@ -315,7 +295,7 @@ class VonageProvider(TelephonyProvider):
"to_number": data.get("to"),
"direction": data.get("direction"),
"duration": data.get("duration"),
"extra": data # Include all original data
"extra": data, # Include all original data
}
async def handle_websocket(
@ -327,14 +307,14 @@ class VonageProvider(TelephonyProvider):
) -> None:
"""
Handle Vonage-specific WebSocket connection.
Vonage can send:
1. JSON metadata first (websocket:connected event)
2. Or directly start with binary audio
"""
from api.db import db_client
from api.services.pipecat.run_pipeline import run_pipeline_vonage
try:
# Get workflow run to extract call UUID
workflow_run = await db_client.get_workflow_run(workflow_run_id)
@ -342,38 +322,48 @@ class VonageProvider(TelephonyProvider):
logger.error(f"Workflow run {workflow_run_id} not found")
await websocket.close(code=4404, reason="Workflow run not found")
return
# Get workflow for organization info
workflow = await db_client.get_workflow(workflow_id, user_id)
if not workflow:
logger.error(f"Workflow {workflow_id} not found")
await websocket.close(code=4404, reason="Workflow not found")
return
# Extract call UUID from workflow run context
call_uuid = workflow_run.gathered_context.get("call_uuid") if workflow_run.gathered_context else None
call_uuid = (
workflow_run.gathered_context.get("call_uuid")
if workflow_run.gathered_context
else None
)
if not call_uuid:
logger.error(f"No call UUID found for Vonage connection in workflow run {workflow_run_id}")
logger.error(
f"No call UUID found for Vonage connection in workflow run {workflow_run_id}"
)
await websocket.close(code=4400, reason="Missing call UUID")
return
logger.info(f"Vonage WebSocket connected for workflow_run {workflow_run_id}, call_uuid: {call_uuid}")
logger.info(
f"Vonage WebSocket connected for workflow_run {workflow_run_id}, call_uuid: {call_uuid}"
)
# Peek at first message to see if it's metadata or audio
first_msg = await websocket.receive()
if "text" in first_msg:
# JSON metadata - check if it's the connection event
msg = json.loads(first_msg["text"])
if msg.get("event") == "websocket:connected":
logger.debug(f"Received Vonage connection confirmation for {workflow_run_id}")
logger.debug(
f"Received Vonage connection confirmation for {workflow_run_id}"
)
# Continue to pipeline regardless of message type
elif "bytes" in first_msg:
# Binary audio - Vonage started with audio immediately
logger.debug(f"Vonage started with binary audio for {workflow_run_id}")
# The pipeline will handle this first audio chunk
# Run the Vonage pipeline
await run_pipeline_vonage(
websocket,
@ -382,9 +372,9 @@ class VonageProvider(TelephonyProvider):
workflow.organization_id,
workflow_id,
workflow_run_id,
user_id
user_id,
)
except Exception as e:
logger.error(f"Error in Vonage WebSocket handler: {e}")
raise
raise

View file

@ -22,9 +22,7 @@ from pipecat.frames.frames import (
)
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import (
BaseOutputTransport
)
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams

View file

@ -14,14 +14,14 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.enums import EndTaskReason
@ -63,7 +63,7 @@ class PipecatEngine:
*,
task: Optional[PipelineTask] = None,
llm: Optional["LLMService"] = None,
context: Optional[OpenAILLMContext] = None,
context: Optional[LLMContext] = None,
tts: Optional[Any] = None,
transport: Optional[BaseTransport] = None,
workflow: WorkflowGraph,
@ -82,7 +82,6 @@ class PipecatEngine:
self._workflow_run_id = workflow_run_id
self._initialized = False
self._client_disconnected = False
self._pending_function_calls = 0
self._current_node: Optional[Node] = None
self._gathered_context: dict = {}
self._user_response_timeout_task: Optional[asyncio.Task] = None
@ -102,29 +101,9 @@ class PipecatEngine:
self._voicemail_detector = None
self._voicemail_detection_task: Optional[asyncio.Task] = None
# This transition is generated by the llm as part of tool call. This can
# also be accompanied with some content which can be played using TTS. If the
# bot is interrupted, we would cancel this transition (we do cancel this currently when
# the next generation starts in handle_generation_started callback handler.)
self._pending_generated_transition_after_context_push: Optional[
Callable[[], Awaitable[None]]
] = None
# This is the transtion which is typically programmatic transition, and not goes as
# tool call to LLM. This is not interrupted by the user and is done on context push
self._pending_control_transition_after_context_push: Optional[
Callable[[], Awaitable[None]]
] = None
# Flag to determine if the current llm generation has a text completion
self._defer_context_push: bool = False
# Lazy loaded built-in function schemas
self._builtin_function_schemas: Optional[list[dict]] = None
# Flag to control whether to queue context frame
self._queue_context_frame: bool = True
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
@ -211,23 +190,15 @@ class PipecatEngine:
async def _create_transition_func(self, name: str, transition_to_node: str):
async def transition_func(function_call_params: FunctionCallParams) -> None:
"""Inner function that handles the actual tool invocation."""
"""Inner function that handles the node change tool calls"""
try:
# Track pending function call
self._pending_function_calls += 1
logger.debug(
f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
)
# For edge functions, prevent LLM completion until transition (run_llm=False)
# For node functions, allow immediate completion (run_llm=True)
async def on_context_updated() -> None:
"""
Framework will run this function after the function call result has been updated in the context.
pipecat framework will run this function after the function call result has been updated in the context.
This way, when we do set_node from within this function, and go for LLM completion with updated
system prompts, the context is updated with function call result.
"""
self._pending_function_calls -= 1
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(
self._current_node
@ -241,41 +212,14 @@ class PipecatEngine:
on_context_updated=on_context_updated,
)
async def _invoke_result_callback():
"""
Functions are executed immediately when they come from LLM as part of text completion.
But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech.
We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we
defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking
the text that was generated.
"""
await function_call_params.result_callback(
result, properties=properties
)
if self._defer_context_push:
"""
We set the flag to _defer_context_push when we receive text in the current generation from LLM.
This is set in the handle_llm_generated_text callback handler.
"""
logger.debug(
"Deferring transition function result until context push"
)
# Only one deferred transition should exist at any time.
# Overwrite if one is somehow already set (unexpected).
self._pending_generated_transition_after_context_push = (
_invoke_result_callback
)
else:
"""
If there was no text in the current generation, and we only had function call,
lets invoke the result callback, so that framework can call on_context_updated and
we can do switch node.
"""
await _invoke_result_callback()
# Call results callback from the pipecat framework
# so that a new llm generation can be triggred if
# required
await function_call_params.result_callback(
result, properties=properties
)
except Exception as e:
logger.error(f"Error in transition function {name}: {str(e)}")
self._pending_function_calls = 0
error_result = {"status": "error", "error": str(e)}
await function_call_params.result_callback(error_result)
@ -362,27 +306,6 @@ class PipecatEngine:
]
)
async def _setup_static_start_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static start nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
if not node.wait_for_user_response:
# Normal static start node - transition immediately after context push
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -441,17 +364,7 @@ class PipecatEngine:
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
# Queue context frame if needed
if self._queue_context_frame:
await self.task.queue_frame(OpenAILLMContextFrame(self.context))
else:
logger.debug(
f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
)
# Reset _queue_context_frame as default behavior
self._queue_context_frame = True
await self.task.queue_frame(LLMContextFrame(self.context))
async def set_node(self, node_id: str):
"""
@ -525,12 +438,7 @@ class PipecatEngine:
await asyncio.sleep(delay_duration)
if node.is_static:
# Queue TTS for static start node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static start nodes
await self._setup_static_start_node_transition(node)
raise ValueError("Static nodes are not supported!")
else:
# Start generation for non-static start node
await self._setup_llm_context_and_start_generation(node)
@ -538,66 +446,24 @@ class PipecatEngine:
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:
# Queue TTS for static end node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
raise ValueError("Static nodes are not supported!")
else:
# Start generation for non-static end node
await self._setup_llm_context_and_start_generation(node)
# If this end node has extraction enabled, perform extraction immediately
if node.extraction_enabled and node.extraction_variables:
await self._perform_variable_extraction_if_needed(node)
# TODO: Extract disposition code from extracted variables
# Defer send_end_task_frame using _pending_control_transition_after_context_push
# Decide the end-task reason dynamically depending on call_disposition.
async def _deferred_end_task():
# call_disposition is the disposition which is generated from
# llm call based on the conversation so far.
# TODO: Make this more generic based on configuration or llm prompting
disposition = self._gathered_context.get("call_disposition")
if disposition == "XFER":
reason = EndTaskReason.USER_QUALIFIED.value
else:
reason = EndTaskReason.USER_DISQUALIFIED.value
await self.send_end_task_frame(reason)
self._pending_control_transition_after_context_push = _deferred_end_task
await self.send_end_task_frame(EndTaskReason.USER_QUALIFIED.value)
async def _handle_agent_node(self, node: Node) -> None:
"""Handle agent node execution."""
if node.is_static:
# Queue TTS for static agent node
formatted_prompt = self._format_prompt(node.prompt)
await self._queue_tts_response(formatted_prompt)
# Set up deferred transition for static agent nodes
await self._setup_agent_node_transition(node)
raise ValueError("Static nodes are not supported!")
else:
# Set context and functions for non-static agent node
await self._setup_llm_context_and_start_generation(node)
async def _setup_agent_node_transition(self, node: Node) -> None:
"""Set up the deferred transition for static agent nodes."""
if not node.out_edges:
return
next_node_id = node.out_edges[0].target
async def _deferred_static_transition():
try:
await self.set_node(next_node_id)
except Exception as exc:
logger.error(
f"Error executing deferred static node transition to {next_node_id}: {exc}"
)
self._pending_control_transition_after_context_push = (
_deferred_static_transition
)
async def send_end_task_frame(
self,
reason: str,
@ -640,7 +506,7 @@ class PipecatEngine:
# Store the mapped disconnect reason
self._gathered_context["call_disposition"] = mapped_disposition
# TODO: Generalise this, currently tailored to Kapil's use case
# TODO: Generalise this
self._gathered_context["address"] = ", ".join(
[
self._call_context_vars.get("address1", ""),
@ -759,55 +625,6 @@ class PipecatEngine:
return system_message, functions
# ------------------------------------------------------------------
# Pending transition handling
# ------------------------------------------------------------------
async def flush_pending_transitions(self, *, source: str = "context_push"):
"""Execute and clear any pending transitions.
Args:
source: Indicates the trigger that caused this flush:
- "context_push": the assistant context aggregator completed a push.
"""
if source != "context_push":
raise ValueError("Invalid flush source expected 'context_push'")
len_pending_functions = 0
if self._pending_generated_transition_after_context_push is not None:
len_pending_functions += 1
if self._pending_control_transition_after_context_push is not None:
len_pending_functions += 1
# Nothing to do
if len_pending_functions == 0:
return
logger.debug(
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
)
# Generated transition
if self._pending_generated_transition_after_context_push is not None:
pending_cb = self._pending_generated_transition_after_context_push
self._pending_generated_transition_after_context_push = None
try:
await pending_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred transition: {exc}")
# Control transition (context push)
if self._pending_control_transition_after_context_push is not None:
logger.debug("Executing control transition after context push")
static_cb = self._pending_control_transition_after_context_push
self._pending_control_transition_after_context_push = None
try:
await static_cb()
except Exception as exc: # pragma: no cover
logger.error(f"Error executing deferred static node transition: {exc}")
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
"""
This callback is called by STTMuteFilter to determine if the STT should be muted.
@ -828,15 +645,6 @@ class PipecatEngine:
"""
return engine_callbacks.create_max_duration_callback(self)
def create_llm_generated_text_callback(self):
"""
This callback is called when some text is generated by the LLM.
We use this to defer the result_callback of the node transition functions if
there is set_node called along with some text generated. This way, we will
have the context sent in the next generation from new node.
"""
return engine_callbacks.create_llm_generated_text_callback(self)
def create_generation_started_callback(self):
"""
This callback is called when a new generation starts.
@ -844,26 +652,12 @@ class PipecatEngine:
"""
return engine_callbacks.create_generation_started_callback(self)
def create_user_stopped_speaking_callback(self):
"""
This callback is called when the user stops speaking.
We use this to handle transitions when wait_for_user_response is enabled.
"""
return engine_callbacks.create_user_stopped_speaking_callback(self)
def create_user_started_speaking_callback(self):
"""
This callback is called when the user starts speaking.
We use this to handle wait_for_user_greeting functionality.
"""
return engine_callbacks.create_user_started_speaking_callback(self)
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
"""Create a callback that corrects corrupted aggregation using reference text."""
return engine_callbacks.create_aggregation_correction_callback(self)
def set_context(self, context: OpenAILLMContext) -> None:
"""Set the OpenAI LLM context.
def set_context(self, context: LLMContext) -> None:
"""Set the LLM context.
This allows setting the context after the engine has been created,
which is useful when the context needs to be created after the engine.

View file

@ -14,6 +14,7 @@ import re
from typing import TYPE_CHECKING, Awaitable, Callable
from loguru import logger
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@ -23,9 +24,8 @@ from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
from pipecat.processors.user_idle_processor import UserIdleProcessor
from api.services.workflow.pipecat_engine import PipecatEngine
from pipecat.processors.user_idle_processor import UserIdleProcessor
# ---------------------------------------------------------------------------
@ -114,23 +114,6 @@ def create_max_duration_callback(engine: "PipecatEngine"):
return handle_max_duration
# ---------------------------------------------------------------------------
# LLM-generated-text handling
# ---------------------------------------------------------------------------
def create_llm_generated_text_callback(engine: "PipecatEngine"):
"""Return a callback invoked when the LLM emits text (not only tool calls)."""
async def handle_llm_generated_text(): # noqa: D401
logger.debug(
"Generation has text content in current response - deferring context push from set_node"
)
engine._defer_context_push = True
return handle_llm_generated_text
# ---------------------------------------------------------------------------
# Generation-started handling
# ---------------------------------------------------------------------------
@ -140,96 +123,13 @@ def create_generation_started_callback(engine: "PipecatEngine"):
"""Return a callback that resets flags at the start of each LLM generation."""
async def handle_generation_started(): # noqa: D401
logger.debug("LLM generation started - resetting defer flags and tool counters")
engine._defer_context_push = False
engine._pending_function_calls = 0
engine._pending_generated_transition_after_context_push = None
logger.debug("LLM generation started in callback processor")
# Clear reference text from previous generation
engine._current_llm_reference_text = ""
return handle_generation_started
# ---------------------------------------------------------------------------
# User-stopped-speaking handling
# ---------------------------------------------------------------------------
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
"""Return a callback that handles when the user stops speaking.
According to simplified flow:
- For start nodes with wait_for_user_response=True:
- Cancel timeout task if still active
- Transition to next node with _queue_context_frame=False
"""
async def handle_user_stopped_speaking():
# Only handle if current node is a start node with wait_for_user_response
if (
engine._current_node
and engine._current_node.is_start
and engine._current_node.wait_for_user_response
and engine._current_node.out_edges
):
# Cancel timeout task if it's still active
if (
engine._user_response_timeout_task
and not engine._user_response_timeout_task.done()
):
logger.debug("Cancelling user response timeout - user responded")
engine._user_response_timeout_task.cancel()
engine._user_response_timeout_task = None
# Transition to next node
next_node_id = engine._current_node.out_edges[0].target
logger.debug(
f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
)
# Set flag to not queue context frame since
# it will be pushed by user context aggregator
# we are just setting the context with next node's
# functions and prompts
engine._queue_context_frame = False
# Transition to next node
await engine.set_node(next_node_id)
return handle_user_stopped_speaking
# ---------------------------------------------------------------------------
# User-started-speaking handling
# ---------------------------------------------------------------------------
def create_user_started_speaking_callback(engine: "PipecatEngine"):
"""Return a callback that handles when the user starts speaking.
According to simplified flow:
- For start nodes with wait_for_user_response=True:
- Cancel the timeout timer if it exists (but don't set to None)
"""
async def handle_user_started_speaking():
# Only handle if current node is a start node with wait_for_user_response
if (
engine._current_node
and engine._current_node.is_start
and engine._current_node.wait_for_user_response
and engine._user_response_timeout_task
and not engine._user_response_timeout_task.done()
):
logger.debug(
"User started speaking during wait_for_user_response - cancelling timeout timer"
)
engine._user_response_timeout_task.cancel()
# Don't set to None here - let user_stopped_speaking handle the transition
return handle_user_started_speaking
def create_aggregation_correction_callback(engine: "PipecatEngine"):
"""Create a callback that uses engine's reference text to correct corrupted aggregation."""

View file

@ -2,16 +2,10 @@ from __future__ import annotations
from typing import Any, Dict, List
from google.genai.types import (
Content,
Part,
)
from api.utils.template_renderer import render_template
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.services.google.llm import GoogleLLMContext
from pipecat.services.openai.llm import OpenAILLMContext
from api.utils.template_renderer import render_template
from pipecat.processors.aggregators.llm_context import LLMContext
__all__ = [
"get_function_schema",
@ -44,7 +38,7 @@ def get_function_schema(
def update_llm_context(
context: OpenAILLMContext,
context: LLMContext,
system_message: Dict[str, Any],
functions: List[FunctionSchema],
) -> None:
@ -59,21 +53,6 @@ def update_llm_context(
# associated with the current LLM service can convert them to the correct
# provider-specific representation when required.
tools_schema = ToolsSchema(standard_tools=functions)
if isinstance(context, GoogleLLMContext):
context.system_message = system_message["content"]
if functions:
# Lets only call set_tools if we have functions, else Gemini will
# throw an exception
context.set_tools(tools_schema)
if context.messages[-1].role != "user":
# Google expects the last message should end with user message
context.add_message(Content(role="user", parts=[Part(text="...")]))
return
# In case of OpenAILLMContext, replace the system message with incoming system message
previous_interactions = context.messages
# Filter out old system messages but keep user/assistant/function content.

View file

@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Any, List
from loguru import logger
from openai import AsyncOpenAI
from opentelemetry import trace
from pipecat.services.openai.llm import OpenAILLMContext
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
from api.services.pipecat.tracing_config import is_tracing_enabled
from api.services.workflow.dto import ExtractionVariableDTO
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
@ -139,7 +139,7 @@ class VariableExtractionManager:
f"{conversation_history}"
)
extraction_context = OpenAILLMContext()
extraction_context = LLMContext()
extraction_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
@ -171,7 +171,7 @@ class VariableExtractionManager:
service_name="OpenAILLMService",
model=self._model,
operation_name="variable_extraction",
messages=json.dumps(extraction_messages),
messages=extraction_messages,
output=llm_response,
stream=False,
parameters={"temperature": 0.0, "response_format": "json_object"},

View file

@ -44,8 +44,6 @@ class Node:
self.extraction_prompt = data.extraction_prompt
self.extraction_variables = data.extraction_variables
self.add_global_prompt = data.add_global_prompt
self.wait_for_user_response = data.wait_for_user_response
self.wait_for_user_response_timeout = data.wait_for_user_response_timeout
self.detect_voicemail = data.detect_voicemail
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration

View file

@ -3,12 +3,12 @@ import os
import aiohttp
import httpx
from loguru import logger
from pipecat.utils.context import set_current_run_id
from api.db import db_client
from api.db.models import IntegrationModel
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
from api.utils.template_renderer import render_template
from pipecat.utils.context import set_current_run_id
async def run_integrations_post_workflow_run(ctx, workflow_run_id: int):
@ -162,7 +162,7 @@ async def _process_slack_integration(
"""
logger.info(f"Processing Slack integration {integration.id}")
# TODO: Generalise this, currently tailored to Kapil's use case
# TODO: Generalise this
if gathered_context.get("mapped_call_disposition") != "XFER":
logger.debug(
f"Not sending message on slack since not XFER: {gathered_context.get('mapped_call_disposition')}"

View file

@ -28,20 +28,24 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
# Fetch telephony call cost for both Twilio and Vonage
telephony_cost_usd = 0.0
if workflow_run.mode in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] and workflow_run.cost_info:
if (
workflow_run.mode
in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
and workflow_run.cost_info
):
# Get the call ID (provider-agnostic approach with backward compatibility)
call_id = workflow_run.cost_info.get("call_id")
# Fallback to legacy provider-specific fields if needed
if not call_id:
if workflow_run.mode == WorkflowRunMode.TWILIO.value:
call_id = workflow_run.cost_info.get("twilio_call_sid")
elif workflow_run.mode == WorkflowRunMode.VONAGE.value:
call_id = workflow_run.cost_info.get("vonage_call_uuid")
# Provider name is derived from workflow run mode
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
if call_id:
try:
# Get workflow to access organization_id
@ -55,12 +59,14 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
# Use telephony provider abstraction
provider = await get_telephony_provider(workflow.organization_id)
call_cost_info = await provider.get_call_cost(call_id)
if call_cost_info.get("status") != "error":
telephony_cost_usd = call_cost_info.get("cost_usd", 0.0)
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd # Keep backward compatibility
cost_breakdown[f"{provider_name}_call"] = (
telephony_cost_usd # Keep backward compatibility
)
# Add telephony cost to the total
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
@ -69,8 +75,10 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
f"{provider_name.title()} call cost: ${telephony_cost_usd:.6f} USD for call {call_id}"
)
else:
logger.error(f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}")
logger.error(
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
@ -119,7 +127,9 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
elif "twilio_call_sid" in workflow_run.cost_info:
cost_info["twilio_call_sid"] = workflow_run.cost_info["twilio_call_sid"]
elif "vonage_call_uuid" in workflow_run.cost_info:
cost_info["vonage_call_uuid"] = workflow_run.cost_info["vonage_call_uuid"]
cost_info["vonage_call_uuid"] = workflow_run.cost_info[
"vonage_call_uuid"
]
# Update workflow run with cost information
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)

View file

@ -1,179 +0,0 @@
### - This test has some weird loop which keeps on increasing the context size
# import asyncio
# import json
# import unittest
# from types import SimpleNamespace
# from unittest import mock
# from loguru import logger
# from pipecat.frames.frames import (
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# FunctionCallsStartedFrame,
# LLMFullResponseEndFrame,
# LLMFullResponseStartFrame,
# LLMGeneratedTextFrame,
# LLMTextFrame,
# )
# from pipecat.pipeline.pipeline import Pipeline
# from pipecat.processors.aggregators.openai_llm_context import (
# OpenAILLMContext,
# OpenAILLMContextFrame,
# )
# from pipecat.services.llm_service import (
# FunctionCallParams,
# FunctionCallResultProperties,
# )
# from pipecat.services.openai.llm import OpenAILLMService
# from pipecat.tests.utils import run_test
# class _MockAsyncStream:
# """A minimal async-stream wrapper that mimics ``openai.AsyncStream``."""
# def __init__(self, chunks):
# self._chunks = chunks
# def __aiter__(self):
# self._idx = 0
# return self
# async def __anext__(self):
# if self._idx >= len(self._chunks):
# raise StopAsyncIteration
# item = self._chunks[self._idx]
# self._idx += 1
# await asyncio.sleep(0) # Yield control
# return item
# # ------------------------------------------------------------------
# # Factories for mock chunks
# # ------------------------------------------------------------------
# def _make_tool_call(tool_name: str, args_json: str, *, idx: int = 0):
# function = SimpleNamespace(name=tool_name, arguments=args_json)
# return SimpleNamespace(index=idx, id=f"call-{idx}", function=function)
# def _make_chunk(*, content: str | None = None, tool_calls=None, usage=None):
# delta = SimpleNamespace()
# # When we are asked to simulate multiple tool calls in parallel, OpenAI
# # sends *separate* chunks for every tool-call index. To mimic that behaviour
# # in tests we split a list of tool calls (>1) into individual chunks one
# # for each tool call while keeping the original single-chunk behaviour
# # when zero or one tool calls are supplied. This enables us to write
# # concise tests such as ``_make_chunk(tool_calls=[call_1, call_2])`` that
# # accurately reflect the streaming protocol.
# # No special handling needed if there is textual content or 0/1 tool calls.
# if content is not None or tool_calls is None or len(tool_calls) <= 1:
# if content is not None:
# delta.content = content
# # Always set tool_calls so downstream code can safely access it
# delta.tool_calls = tool_calls if tool_calls is not None else None
# return SimpleNamespace(choices=[SimpleNamespace(delta=delta)], usage=usage)
# # --- Multiple tool calls (len(tool_calls) > 1) ---
# # Create a list of chunks, each containing a single tool call. This is the
# # format produced by the OpenAI client when several tools are invoked in a
# # single assistant response.
# chunks = []
# for tc in tool_calls:
# delta_tc = SimpleNamespace(tool_calls=[tc])
# chunks.append(SimpleNamespace(choices=[SimpleNamespace(delta=delta_tc)], usage=usage))
# return chunks
# class TestBaseOpenAILLMService(unittest.IsolatedAsyncioTestCase):
# async def test_process_context_with_patch(self):
# streamed_text = "Hello from OpenAI!"
# tool_name = "echo"
# tool_name_2 = "echo_2"
# tool_args = {"text": "hello"}
# tool_args_2 = {"text": "hello_2"}
# # Build mocked stream (tool call first, then text)
# chunks = [
# _make_chunk(content=streamed_text),
# _make_chunk(tool_calls=[_make_tool_call(tool_name, json.dumps(tool_args))]),
# _make_chunk(tool_calls=[_make_tool_call(tool_name_2, json.dumps(tool_args_2), idx=1)]),
# ]
# # Instantiate real OpenAILLMService (no need for actual API key)
# llm = OpenAILLMService(model="gpt-4o-mini", api_key="test")
# # Patch get_chat_completions to return our mocked async stream
# async def fake_get_chat_completions(self, context, messages): # noqa: D401
# return _MockAsyncStream(chunks)
# with mock.patch.object(llm.__class__, "get_chat_completions", fake_get_chat_completions):
# # Register echo tool
# executed = False
# async def echo_handler(params: FunctionCallParams):
# nonlocal executed
# executed = True
# # sleep for 1 second
# logger.info("echo_handler: sleeping for 5 second")
# await asyncio.sleep(5)
# await params.result_callback(
# {"ok": True},
# properties=FunctionCallResultProperties(run_llm=True),
# )
# async def echo_2_handler(params: FunctionCallParams):
# nonlocal executed
# executed = True
# # sleep for 1 second
# logger.info("echo_2_handler: sleeping for 5 second")
# await asyncio.sleep(5)
# await params.result_callback(
# {"ok": True},
# properties=FunctionCallResultProperties(run_llm=True),
# )
# llm.register_function(tool_name, echo_handler)
# llm.register_function(tool_name_2, echo_2_handler)
# # Prepare context and send
# context = OpenAILLMContext()
# context.add_message({"role": "user", "content": "Hi"})
# frames_to_send = [OpenAILLMContextFrame(context)]
# expected_down_frames = [
# LLMFullResponseStartFrame,
# FunctionCallsStartedFrame,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# LLMGeneratedTextFrame,
# LLMTextFrame,
# LLMFullResponseEndFrame,
# ]
# context_aggregator = llm.create_context_aggregator(context)
# pipeline = Pipeline([llm, context_aggregator.assistant()])
# down_frames, _ = await run_test(
# pipeline,
# frames_to_send=frames_to_send,
# expected_down_frames=expected_down_frames,
# send_end_frame=False,
# )
# # Assertions
# self.assertTrue(executed)
# for fr in down_frames:
# if isinstance(fr, FunctionCallResultFrame):
# self.assertTrue(fr.run_llm)
# if isinstance(fr, LLMTextFrame):
# self.assertEqual(fr.text, streamed_text)
# if __name__ == "__main__":
# unittest.main()

View file

@ -1,143 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify that LLMGeneratedTextFrame signaling works correctly
with the new local variable approach.
"""
def test_local_variable_logic():
"""Test the core logic using the same pattern as the implementation"""
print("=== Testing Local Variable Logic ===")
# Simulate the logic from _process_context
text_generation_signaled = False
frames_sent = []
# Simulate chunks with text content
chunks_with_content = ["Hello", " world", "!"]
for content in chunks_with_content:
# This is the exact logic from our implementation
if content: # equivalent to chunk.choices[0].delta.content
if not text_generation_signaled:
frames_sent.append("LLMGeneratedTextFrame")
text_generation_signaled = True
frames_sent.append(f"LLMTextFrame({content})")
print(f"Frames sent: {frames_sent}")
# Verify behavior
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
text_frames = [f for f in frames_sent if f.startswith("LLMTextFrame")]
assert len(generated_signals) == 1, (
f"Expected 1 signal, got {len(generated_signals)}"
)
assert len(text_frames) == 3, f"Expected 3 text frames, got {len(text_frames)}"
assert frames_sent[0] == "LLMGeneratedTextFrame", "Signal should be first"
print("✅ Local variable logic works correctly")
return True
def test_no_text_logic():
"""Test that no signal is sent when there's no text"""
print("\n=== Testing No Text Logic ===")
text_generation_signaled = False
frames_sent = []
# Simulate chunks with no text content (function calls only)
chunks_with_content = [None, None, None] # No text content
for content in chunks_with_content:
if content: # This will be False for all chunks
if not text_generation_signaled:
frames_sent.append("LLMGeneratedTextFrame")
text_generation_signaled = True
frames_sent.append(f"LLMTextFrame({content})")
print(f"Frames sent: {frames_sent}")
assert len(frames_sent) == 0, f"Expected no frames, got {frames_sent}"
print("✅ No signal sent when no text content")
return True
def test_mixed_content_logic():
"""Test behavior with mixed function calls and text"""
print("\n=== Testing Mixed Content Logic ===")
text_generation_signaled = False
frames_sent = []
# Simulate chunks: function call, text, function call, text
chunks = [
{"type": "function", "content": None},
{"type": "text", "content": "Hello"},
{"type": "function", "content": None},
{"type": "text", "content": " world"},
]
for chunk in chunks:
if chunk["type"] == "function":
frames_sent.append("FunctionCallFrame")
elif chunk["content"]: # text content
if not text_generation_signaled:
frames_sent.append("LLMGeneratedTextFrame")
text_generation_signaled = True
frames_sent.append(f"LLMTextFrame({chunk['content']})")
print(f"Frames sent: {frames_sent}")
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
assert len(generated_signals) == 1, (
f"Expected 1 signal, got {len(generated_signals)}"
)
# Signal should come before first text frame but after any function frames
signal_index = frames_sent.index("LLMGeneratedTextFrame")
first_text_index = next(
i for i, f in enumerate(frames_sent) if f.startswith("LLMTextFrame")
)
assert signal_index == first_text_index - 1, (
"Signal should come right before first text"
)
print("✅ Mixed content logic works correctly")
return True
def main():
try:
test1_result = test_local_variable_logic()
test2_result = test_no_text_logic()
test3_result = test_mixed_content_logic()
print(f"\n=== Test Results ===")
print(f"Local variable test: {'✅ PASS' if test1_result else '❌ FAIL'}")
print(f"No text test: {'✅ PASS' if test2_result else '❌ FAIL'}")
print(f"Mixed content test: {'✅ PASS' if test3_result else '❌ FAIL'}")
if test1_result and test2_result and test3_result:
print("\n🎉 All LLMGeneratedTextFrame signaling logic tests passed!")
print(
"✅ Implementation correctly signals text generation once, as early as possible"
)
else:
print("\n❌ Some tests failed.")
except Exception as e:
print(f"❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View file

@ -1,536 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pipecat.frames.frames import (
EndFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import Edge, Node, WorkflowGraph
class TestPipecatEngineSetNode:
"""Test cases for PipecatEngine.set_node method refactoring."""
@pytest.fixture
def mock_workflow(self):
"""Create a mock workflow with various node types."""
workflow = Mock(spec=WorkflowGraph)
workflow.nodes = {}
workflow.start_node_id = "start_node"
workflow.global_node_id = None
return workflow
@pytest.fixture
def mock_dependencies(self, mock_workflow):
"""Create mock dependencies for PipecatEngine initialization."""
task = AsyncMock()
task.queue_frames = AsyncMock()
task.queue_frame = AsyncMock()
llm = AsyncMock()
llm.register_function = Mock()
llm.push_frame = AsyncMock()
context = Mock(spec=OpenAILLMContext)
context.set_node_name = Mock()
return {
"task": task,
"llm": llm,
"context": context,
"tts": Mock(),
"transport": Mock(),
"workflow": mock_workflow,
"call_context_vars": {"test_var": "test_value"},
}
@pytest.fixture
def engine(self, mock_dependencies):
"""Create a PipecatEngine instance."""
# Add audio_buffer and workflow_run_id to dependencies
mock_dependencies["audio_buffer"] = None
mock_dependencies["workflow_run_id"] = 123
engine = PipecatEngine(**mock_dependencies)
# Mock the builtin function registration
engine._register_builtin_functions = AsyncMock()
return engine
def create_node(self, node_id, **kwargs):
"""Helper to create a node with default values."""
defaults = {
"name": f"Node {node_id}",
"prompt": f"Prompt for {node_id}",
"is_static": False,
"is_start": False,
"is_end": False,
"allow_interrupt": True,
"extraction_enabled": False,
"extraction_prompt": "",
"extraction_variables": [],
"add_global_prompt": True,
"wait_for_user_response": False,
"detect_voicemail": False,
}
defaults.update(kwargs)
data = Mock(spec=NodeDataDTO)
for key, value in defaults.items():
setattr(data, key, value)
node = Mock(spec=Node)
node.id = node_id
node.data = data
node.out_edges = []
# Copy attributes from data to node
for key, value in defaults.items():
setattr(node, key, value)
return node
def create_edge(
self, source, target, label="Continue", condition="Always continue"
):
"""Helper to create an edge."""
data = Mock(spec=EdgeDataDTO)
data.label = label
data.condition = condition
edge = Mock(spec=Edge)
edge.source = source
edge.target = target
edge.data = data
edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_"))
return edge
# ===== START NODE TESTS =====
@pytest.mark.asyncio
async def test_start_node_static_immediate_execution(self, engine, mock_workflow):
"""Test: Basic static start node executes immediately."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
prompt="Welcome to our service!",
)
next_node = self.create_node("next_node", is_static=False)
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Execute
await engine.set_node("start_node")
# Verify
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert len(frames) == 3
assert isinstance(frames[0], LLMFullResponseStartFrame)
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Welcome to our service!"
assert isinstance(frames[2], LLMFullResponseEndFrame)
# Static start nodes now set pending transition after context push
assert engine._pending_control_transition_after_context_push is not None
# Should not have set detect_voicemail for static start without it
assert not engine._detect_voicemail
@pytest.mark.asyncio
async def test_start_node_with_detect_voicemail_no_audio_buffer(
self, engine, mock_workflow
):
"""Test: Start node with voicemail detection but no audio buffer logs warning."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
detect_voicemail=True,
prompt="Hello, this is a business call.",
)
mock_workflow.nodes = {"start_node": start_node}
# Engine has no audio buffer (None)
assert engine._audio_buffer is None
# Execute
await engine.set_node("start_node")
# Verify
# Should NOT set voicemail detection flag since no audio buffer
assert engine._detect_voicemail is False
assert engine._voicemail_detector is None
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Hello, this is a business call."
@pytest.mark.asyncio
async def test_start_node_non_static_with_detect_voicemail(
self, engine, mock_workflow
):
"""Test: Non-static start node with voicemail detection without audio buffer."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=False, # Non-static
detect_voicemail=True,
prompt="You are an AI assistant. Start the conversation.",
)
mock_workflow.nodes = {"start_node": start_node}
# Mock the context update method
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test prompt"}, [])
)
# Execute
await engine.set_node("start_node")
# Verify
# Should NOT set voicemail detection flags (no audio buffer)
assert engine._detect_voicemail is False
assert engine._voicemail_detector is None
# Should update LLM context for non-static node
engine._update_llm_context.assert_called_once()
# Should queue context frame
engine.task.queue_frame.assert_called_once()
frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(frame, OpenAILLMContextFrame)
@pytest.mark.asyncio
async def test_start_node_static_with_wait_for_user_response(
self, engine, mock_workflow
):
"""Test: Static start node with wait_for_user_response."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=True,
wait_for_user_response=True,
prompt="Please tell me your name.",
)
next_node = self.create_node("next_node")
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Execute
await engine.set_node("start_node")
# Verify
# Should queue TTS immediately
engine.task.queue_frames.assert_called_once()
# Should have a pending control transition that will start the timer
assert engine._pending_control_transition_after_context_push is not None
# Timer task should not exist yet
assert (
not hasattr(engine, "_user_response_timeout_task")
or engine._user_response_timeout_task is None
)
# Simulate context push to start the timer
await engine.flush_pending_transitions(source="context_push")
# Now the timeout task should be created
assert engine._user_response_timeout_task is not None
assert not engine._user_response_timeout_task.done()
# Clean up the task
engine._user_response_timeout_task.cancel()
@pytest.mark.asyncio
async def test_start_node_non_static(self, engine, mock_workflow):
"""Test: Non-static start node sends context to LLM."""
# Setup
start_node = self.create_node(
"start_node",
is_start=True,
is_static=False,
prompt="You are a helpful assistant. Greet the user.",
)
mock_workflow.nodes = {"start_node": start_node}
# Mock the context update method
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test prompt"}, [])
)
# Execute
await engine.set_node("start_node")
# Verify
# Should set context name
engine.context.set_node_name.assert_called_once_with("Node start_node")
# Should update LLM context
engine._update_llm_context.assert_called_once()
# Should queue context frame
engine.task.queue_frame.assert_called_once()
frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(frame, OpenAILLMContextFrame)
# ===== AGENT NODE TESTS =====
@pytest.mark.asyncio
async def test_agent_node_static(self, engine, mock_workflow):
"""Test: Static agent node plays TTS and transitions."""
# Setup
agent_node = self.create_node(
"agent_node", is_static=True, prompt="Processing your request..."
)
next_node = self.create_node("next_node")
edge = self.create_edge("agent_node", "next_node")
agent_node.out_edges = [edge]
mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node}
# Execute
await engine.set_node("agent_node")
# Verify
# Should queue TTS
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert isinstance(frames[1], TTSSpeakFrame)
assert frames[1].text == "Processing your request..."
# Should have pending transition
assert engine._pending_control_transition_after_context_push is not None
@pytest.mark.asyncio
async def test_agent_node_non_static(self, engine, mock_workflow):
"""Test: Non-static agent node sends context to LLM."""
# Setup
agent_node = self.create_node(
"agent_node",
is_static=False,
prompt="Analyze the user's request and respond appropriately.",
)
decision_node = self.create_node("decision_node")
edge = self.create_edge("agent_node", "decision_node", "analyze_complete")
agent_node.out_edges = [edge]
mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node}
# Mock methods
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=(
{"role": "system", "content": "Test"},
[{"name": "test_func"}],
)
)
# Execute
await engine.set_node("agent_node")
# Verify
# Should register transition function
engine.llm.register_function.assert_called_once()
call_args = engine.llm.register_function.call_args
assert call_args[0][0] == "analyze_complete"
assert callable(call_args[0][1]) # Check it's a function
assert call_args[1]["cancel_on_interruption"] is True
# Should update context and send frame
engine._update_llm_context.assert_called_once()
engine.task.queue_frame.assert_called_once()
@pytest.mark.asyncio
async def test_agent_node_with_interruption_control(self, engine, mock_workflow):
"""Test: Agent node respects allow_interrupt flag."""
# Setup
no_interrupt_node = self.create_node(
"no_interrupt",
is_static=True,
allow_interrupt=False,
prompt="Please wait while I process...",
)
mock_workflow.nodes = {"no_interrupt": no_interrupt_node}
# Execute
await engine.set_node("no_interrupt")
# Verify current node is set (for STT mute callback)
assert engine._current_node == no_interrupt_node
assert engine._current_node.allow_interrupt is False
# ===== END NODE TESTS =====
@pytest.mark.asyncio
async def test_end_node_static(self, engine, mock_workflow):
"""Test: Static end node plays final message and schedules end task."""
# Setup
end_node = self.create_node(
"end_node",
is_static=True,
is_end=True,
prompt="Thank you for calling. Goodbye!",
)
mock_workflow.nodes = {"end_node": end_node}
# Execute
await engine.set_node("end_node")
# Verify
# Should queue TTS
engine.task.queue_frames.assert_called_once()
frames = engine.task.queue_frames.call_args[0][0]
assert frames[1].text == "Thank you for calling. Goodbye!"
# Should have pending end task
assert engine._pending_control_transition_after_context_push is not None
# Execute the pending transition
await engine._pending_control_transition_after_context_push()
# Should have sent EndFrame via task.queue_frame
# The second call should be the EndFrame (first was TTS frames)
assert engine.task.queue_frame.call_count >= 1
end_frame = engine.task.queue_frame.call_args[0][0]
assert isinstance(end_frame, EndFrame)
@pytest.mark.asyncio
async def test_end_node_with_extraction(self, engine, mock_workflow):
"""Test: End node with variable extraction."""
# Setup
end_node = self.create_node(
"end_node",
is_end=True,
is_static=False,
extraction_enabled=True,
extraction_variables=["user_name", "satisfaction_level"],
extraction_prompt="Extract user name and satisfaction",
)
mock_workflow.nodes = {"end_node": end_node}
# Mock the extraction manager
engine._variable_extraction_manager = Mock()
engine._perform_variable_extraction_if_needed = AsyncMock()
# Mock context update and composition methods
engine._update_llm_context = AsyncMock()
engine._compose_system_message_functions_for_node = AsyncMock(
return_value=({"role": "system", "content": "Test"}, [])
)
# Execute
await engine.set_node("end_node")
# Verify
# Should trigger extraction
engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node)
# Should have pending end task
assert engine._pending_control_transition_after_context_push is not None
# ===== CALLBACK INTEGRATION TESTS =====
@pytest.mark.asyncio
async def test_user_stopped_speaking_during_response_wait(
self, engine, mock_workflow
):
"""Test: User stops speaking triggers transition during wait_for_response."""
# Setup
start_node = self.create_node(
"start_node", is_start=True, is_static=True, wait_for_user_response=True
)
next_node = self.create_node("next_node")
edge = self.create_edge("start_node", "next_node")
start_node.out_edges = [edge]
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
# Set current node to start node
engine._current_node = start_node
engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3))
# Create callback and execute
callback = engine.create_user_stopped_speaking_callback()
# Mock set_node to avoid recursion
with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node:
await callback()
# Verify
mock_set_node.assert_called_once_with("next_node")
assert engine._queue_context_frame is False # Should be set to False
@pytest.mark.asyncio
async def test_context_push_callback_executes_pending_transitions(self, engine):
"""Test: flush_pending_transitions executes deferred transitions."""
# Setup pending transitions
mock_generated_transition = AsyncMock()
mock_control_transition = AsyncMock()
engine._pending_generated_transition_after_context_push = (
mock_generated_transition
)
engine._pending_control_transition_after_context_push = mock_control_transition
# Execute
await engine.flush_pending_transitions(source="context_push")
# Verify both transitions were executed
mock_generated_transition.assert_called_once()
mock_control_transition.assert_called_once()
# Verify they were cleared
assert engine._pending_generated_transition_after_context_push is None
assert engine._pending_control_transition_after_context_push is None
# ===== COMPLEX SCENARIO TESTS =====
# Add helper for testing with real async behavior
def ANY(cls=None):
"""Helper for matching any argument in mock calls."""
class AnyMatcher:
def __init__(self, cls):
self.cls = cls
def __eq__(self, other):
if self.cls:
return isinstance(other, self.cls)
return True
return AnyMatcher(cls)

View file

@ -5,46 +5,41 @@ handles provider switches correctly without losing billing data.
"""
import asyncio
import json
from datetime import datetime, timezone
from typing import Dict, Any
# Test scenarios to validate
async def test_scenario_1_mid_call_provider_switch():
"""
Test: What happens if provider is switched while a call is active?
Expected behavior:
- Active call continues with original provider
- Call is billed to original provider
- New calls use new provider
"""
print("Test 1: Mid-call provider switching")
# Simulate workflow run with Twilio
twilio_run = {
"id": 1,
"mode": "twilio",
"cost_info": {
"twilio_call_sid": "CA123456789",
"provider": "twilio"
},
"is_completed": False
"cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"},
"is_completed": False,
}
# Provider switch happens here (in real scenario, user changes config)
# But the call continues...
# When cost calculation runs, it should:
# 1. Use the provider stored in cost_info
# 2. Fetch cost from Twilio using twilio_call_sid
# 3. Store cost with provider attribution
result = {
"test": "mid_call_switch",
"status": "PASS",
"reason": "Call continues with original provider, billing intact"
"reason": "Call continues with original provider, billing intact",
}
print(f"{result['reason']}")
return result
@ -53,41 +48,41 @@ async def test_scenario_1_mid_call_provider_switch():
async def test_scenario_2_pending_cost_calculation():
"""
Test: Calls that ended but cost not yet calculated when provider switches.
Expected behavior:
- Background job should use the provider info stored in cost_info
- Cost should be fetched from correct provider
"""
print("\nTest 2: Pending cost calculation during switch")
# Workflow runs that ended but cost job hasn't run yet
pending_runs = [
{
"id": 2,
"mode": "twilio",
"mode": "twilio",
"cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"},
"is_completed": True
"is_completed": True,
},
{
"id": 3,
"mode": "vonage",
"cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"},
"is_completed": True
}
"is_completed": True,
},
]
# Provider switch happens here
# Cost calculation jobs run after switch
# Each job should:
# 1. Check the provider field in cost_info
# 2. Use appropriate provider API to fetch cost
# 3. Handle gracefully if credentials changed
result = {
"test": "pending_cost_calculation",
"status": "PASS",
"reason": "Cost jobs use stored provider info correctly"
"reason": "Cost jobs use stored provider info correctly",
}
print(f"{result['reason']}")
return result
@ -96,33 +91,37 @@ async def test_scenario_2_pending_cost_calculation():
async def test_scenario_3_mixed_provider_history():
"""
Test: Organization has calls from both Twilio and Vonage.
Expected behavior:
- Historical costs remain intact
- Reports show correct attribution
- Total costs aggregate correctly
"""
print("\nTest 3: Mixed provider history")
historical_runs = [
{"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"},
{"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"},
{"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"},
{"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"},
]
# Calculate totals
total_cost = sum(run["cost_usd"] for run in historical_runs)
twilio_cost = sum(run["cost_usd"] for run in historical_runs if run["provider"] == "twilio")
vonage_cost = sum(run["cost_usd"] for run in historical_runs if run["provider"] == "vonage")
twilio_cost = sum(
run["cost_usd"] for run in historical_runs if run["provider"] == "twilio"
)
vonage_cost = sum(
run["cost_usd"] for run in historical_runs if run["provider"] == "vonage"
)
result = {
"test": "mixed_provider_history",
"status": "PASS",
"status": "PASS",
"total_cost": total_cost,
"twilio_cost": twilio_cost,
"vonage_cost": vonage_cost,
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})"
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})",
}
print(f"{result['reason']}")
return result
@ -131,41 +130,41 @@ async def test_scenario_3_mixed_provider_history():
async def test_scenario_4_cost_api_failure():
"""
Test: Provider API fails when fetching cost.
Expected behavior:
- Error logged but system continues
- Call record preserved
- Cost marked as 0 or unknown
"""
print("\nTest 4: Cost API failure handling")
# Simulate API failure scenarios
failure_scenarios = [
{
"provider": "twilio",
"error": "401 Unauthorized - credentials changed",
"expected": "Cost set to 0, error logged"
"expected": "Cost set to 0, error logged",
},
{
"provider": "vonage",
"error": "404 Not Found - call record deleted",
"expected": "Cost set to 0, error logged"
"expected": "Cost set to 0, error logged",
},
{
"provider": "twilio",
"error": "500 Internal Server Error",
"expected": "Cost set to 0, retry possible"
}
"expected": "Cost set to 0, retry possible",
},
]
for scenario in failure_scenarios:
print(f" - {scenario['provider']}: {scenario['error']}")
print(f" Expected: {scenario['expected']}")
result = {
"test": "cost_api_failure",
"status": "PASS",
"reason": "All failure scenarios handled gracefully"
"reason": "All failure scenarios handled gracefully",
}
print(f"{result['reason']}")
return result
@ -174,22 +173,22 @@ async def test_scenario_4_cost_api_failure():
async def test_scenario_5_configuration_migration():
"""
Test: Database migration from single to multi-provider format.
Expected behavior:
- Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION
- Single provider config wrapped in multi-provider structure
- Existing cost_info gets provider field added
"""
print("\nTest 5: Configuration migration")
# Old format
old_config = {
"account_sid": "AC123",
"auth_token": "token123",
"auth_token": "token123",
"from_numbers": ["+1234567890"],
"provider": "twilio"
"provider": "twilio",
}
# New format after migration
new_config = {
"active_provider": "twilio",
@ -197,20 +196,20 @@ async def test_scenario_5_configuration_migration():
"twilio": {
"account_sid": "AC123",
"auth_token": "token123",
"from_numbers": ["+1234567890"]
"from_numbers": ["+1234567890"],
}
}
},
}
# Validate migration
assert new_config["active_provider"] == "twilio"
assert "providers" in new_config
assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"]
result = {
"test": "configuration_migration",
"status": "PASS",
"reason": "Configuration migrated to multi-provider format correctly"
"reason": "Configuration migrated to multi-provider format correctly",
}
print(f"{result['reason']}")
return result
@ -219,39 +218,34 @@ async def test_scenario_5_configuration_migration():
async def test_scenario_6_provider_cost_discrepancy():
"""
Test: Webhook cost vs API cost discrepancy.
Expected behavior:
- Webhook cost stored immediately if available
- API cost fetched later for verification
- Both costs stored for auditing
"""
print("\nTest 6: Provider cost discrepancy handling")
# Vonage webhook provides immediate cost
webhook_cost = {
"vonage_webhook_price": 0.15,
"vonage_webhook_duration": 120
}
webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120}
# API call provides authoritative cost
api_cost = {
"cost_usd": 0.14, # Slight difference
"duration": 120
"duration": 120,
}
# Both should be stored
final_cost_info = {
**webhook_cost,
"cost_breakdown": {
"telephony_call": api_cost["cost_usd"]
},
"provider": "vonage"
"cost_breakdown": {"telephony_call": api_cost["cost_usd"]},
"provider": "vonage",
}
result = {
"test": "cost_discrepancy",
"status": "PASS",
"reason": "Both webhook and API costs stored for auditing"
"reason": "Both webhook and API costs stored for auditing",
}
print(f"{result['reason']}")
return result
@ -262,40 +256,40 @@ async def run_all_tests():
print("=" * 60)
print("PROVIDER SWITCHING TEST SUITE")
print("=" * 60)
tests = [
test_scenario_1_mid_call_provider_switch,
test_scenario_2_pending_cost_calculation,
test_scenario_3_mixed_provider_history,
test_scenario_4_cost_api_failure,
test_scenario_5_configuration_migration,
test_scenario_6_provider_cost_discrepancy
test_scenario_6_provider_cost_discrepancy,
]
results = []
for test in tests:
result = await test()
results.append(result)
print("\n" + "=" * 60)
print("TEST SUMMARY")
print("=" * 60)
passed = sum(1 for r in results if r["status"] == "PASS")
failed = sum(1 for r in results if r["status"] == "FAIL")
print(f"Total Tests: {len(results)}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if failed == 0:
print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!")
else:
print("\n❌ Some tests failed - Review the implementation")
return results
if __name__ == "__main__":
# Run the test suite
asyncio.run(run_all_tests())
asyncio.run(run_all_tests())

@ -1 +1 @@
Subproject commit fa68d2ce261544398013307d2c6a69e0556b4449
Subproject commit 53653657d851e8052f9cc5b73b6f675a44c86fe7

View file

@ -18,8 +18,6 @@ interface EndCallEditFormProps {
nodeData: FlowNodeData;
prompt: string;
setPrompt: (value: string) => void;
isStatic: boolean;
setIsStatic: (value: boolean) => void;
name: string;
setName: (value: string) => void;
extractionEnabled: boolean;
@ -45,7 +43,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
// Form state
const [prompt, setPrompt] = useState(data.prompt);
const [isStatic, setIsStatic] = useState(data.is_static ?? true);
const [name, setName] = useState(data.name);
// Variable Extraction state
@ -58,7 +55,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
handleSaveNodeData({
...data,
prompt,
is_static: isStatic,
name,
allow_interrupt: false, // Always set to false for end nodes
extraction_enabled: extractionEnabled,
@ -77,7 +73,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
const handleOpenChange = (newOpen: boolean) => {
if (newOpen) {
setPrompt(data.prompt);
setIsStatic(data.is_static ?? true);
setName(data.name);
setExtractionEnabled(data.extraction_enabled ?? false);
setExtractionPrompt(data.extraction_prompt ?? "");
@ -91,7 +86,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
useEffect(() => {
if (open) {
setPrompt(data.prompt);
setIsStatic(data.is_static ?? true);
setName(data.name);
setExtractionEnabled(data.extraction_enabled ?? false);
setExtractionPrompt(data.extraction_prompt ?? "");
@ -137,8 +131,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
nodeData={data}
prompt={prompt}
setPrompt={setPrompt}
isStatic={isStatic}
setIsStatic={setIsStatic}
name={name}
setName={setName}
extractionEnabled={extractionEnabled}
@ -159,8 +151,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
const EndCallEditForm = ({
prompt,
setPrompt,
isStatic,
setIsStatic,
name,
setName,
extractionEnabled,
@ -206,14 +196,10 @@ const EndCallEditForm = ({
</Label>
<Input value={name} onChange={(e) => setName(e.target.value)} />
<Label>{isStatic ? "Text" : "Prompt"}</Label>
<Label>Prompt</Label>
<Label className="text-xs text-gray-500">
What would you like the agent to say when the call ends? Its a good idea to have a static goodbye message.
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<div className="flex items-center space-x-2">
<Switch id="static-text" checked={isStatic} onCheckedChange={setIsStatic} />
<Label htmlFor="static-text">Static Text</Label>
</div>
<Textarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
@ -221,7 +207,7 @@ const EndCallEditForm = ({
style={{
overflowY: 'auto'
}}
placeholder={isStatic ? "Thank you for calling Dograh. Have a great day!" : "Enter a dynamic prompt"}
placeholder="Enter a dynamic prompt"
/>
<div className="flex items-center space-x-2">
<Switch id="add-global-prompt" checked={addGlobalPrompt} onCheckedChange={setAddGlobalPrompt} />

View file

@ -19,16 +19,12 @@ interface StartCallEditFormProps {
nodeData: FlowNodeData;
prompt: string;
setPrompt: (value: string) => void;
isStatic: boolean;
setIsStatic: (value: boolean) => void;
name: string;
setName: (value: string) => void;
allowInterrupt: boolean;
setAllowInterrupt: (value: boolean) => void;
addGlobalPrompt: boolean;
setAddGlobalPrompt: (value: boolean) => void;
waitForUserResponse: boolean;
setWaitForUserResponse: (value: boolean) => void;
detectVoicemail: boolean;
setDetectVoicemail: (value: boolean) => void;
delayedStart: boolean;
@ -50,11 +46,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
// Form state
const [prompt, setPrompt] = useState(data.prompt ?? "");
const [isStatic, setIsStatic] = useState(data.is_static ?? true);
const [name, setName] = useState(data.name);
const [allowInterrupt, setAllowInterrupt] = useState(data.allow_interrupt ?? true);
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
const [waitForUserResponse, setWaitForUserResponse] = useState(data.wait_for_user_response ?? false);
const [detectVoicemail, setDetectVoicemail] = useState(data.detect_voicemail ?? true);
const [delayedStart, setDelayedStart] = useState(data.delayed_start ?? false);
const [delayedStartDuration, setDelayedStartDuration] = useState(data.delayed_start_duration ?? 2);
@ -63,11 +57,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
handleSaveNodeData({
...data,
prompt,
is_static: isStatic,
name,
allow_interrupt: allowInterrupt,
add_global_prompt: addGlobalPrompt,
wait_for_user_response: waitForUserResponse,
detect_voicemail: detectVoicemail,
delayed_start: delayedStart,
delayed_start_duration: delayedStart ? delayedStartDuration : undefined
@ -83,11 +75,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
const handleOpenChange = (newOpen: boolean) => {
if (newOpen) {
setPrompt(data.prompt ?? "");
setIsStatic(data.is_static ?? true);
setName(data.name);
setAllowInterrupt(data.allow_interrupt ?? true);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setWaitForUserResponse(data.wait_for_user_response ?? false);
setDetectVoicemail(data.detect_voicemail ?? true);
setDelayedStart(data.delayed_start ?? false);
setDelayedStartDuration(data.delayed_start_duration ?? 3);
@ -99,11 +89,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
useEffect(() => {
if (open) {
setPrompt(data.prompt ?? "");
setIsStatic(data.is_static ?? true);
setName(data.name);
setAllowInterrupt(data.allow_interrupt ?? true);
setAddGlobalPrompt(data.add_global_prompt ?? true);
setWaitForUserResponse(data.wait_for_user_response ?? false);
setDetectVoicemail(data.detect_voicemail ?? true);
setDelayedStart(data.delayed_start ?? false);
setDelayedStartDuration(data.delayed_start_duration ?? 3);
@ -147,16 +135,12 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
nodeData={data}
prompt={prompt}
setPrompt={setPrompt}
isStatic={isStatic}
setIsStatic={setIsStatic}
name={name}
setName={setName}
allowInterrupt={allowInterrupt}
setAllowInterrupt={setAllowInterrupt}
addGlobalPrompt={addGlobalPrompt}
setAddGlobalPrompt={setAddGlobalPrompt}
waitForUserResponse={waitForUserResponse}
setWaitForUserResponse={setWaitForUserResponse}
detectVoicemail={detectVoicemail}
setDetectVoicemail={setDetectVoicemail}
delayedStart={delayedStart}
@ -173,16 +157,12 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
const StartCallEditForm = ({
prompt,
setPrompt,
isStatic,
setIsStatic,
name,
setName,
allowInterrupt,
setAllowInterrupt,
addGlobalPrompt,
setAddGlobalPrompt,
waitForUserResponse,
setWaitForUserResponse,
detectVoicemail,
setDetectVoicemail,
delayedStart,
@ -201,14 +181,10 @@ const StartCallEditForm = ({
onChange={(e) => setName(e.target.value)}
/>
<Label>{isStatic ? "Text" : "Prompt"}</Label>
<Label>Prompt</Label>
<Label className="text-xs text-gray-500">
What would you like the agent to say when the call starts? Its a good idea to have a static greeting that can be used to identify the call.
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<div className="flex items-center space-x-2">
<Switch id="static-text" checked={isStatic} onCheckedChange={setIsStatic} />
<Label htmlFor="static-text">Static Text</Label>
</div>
<Textarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
@ -216,7 +192,7 @@ const StartCallEditForm = ({
style={{
overflowY: 'auto'
}}
placeholder={isStatic ? "Hello, welcome to Dograh. How can I help you today?" : "Enter a dynamic prompt"}
placeholder="Enter a prompt"
/>
<div className="flex items-center space-x-2">
<Switch id="allow-interrupt" checked={allowInterrupt} onCheckedChange={setAllowInterrupt} />
@ -230,34 +206,10 @@ const StartCallEditForm = ({
id="add-global-prompt"
checked={addGlobalPrompt}
onCheckedChange={setAddGlobalPrompt}
disabled={isStatic}
/>
<Label htmlFor="add-global-prompt" className={isStatic ? "opacity-50" : ""}>
<Label htmlFor="add-global-prompt">
Add Global Prompt
</Label>
<Label className={`text-xs text-gray-500 ${isStatic ? "opacity-50" : ""}`}>
{isStatic
? "Not applicable for static text"
: "Whether you want to add global prompt with this node's prompt."}
</Label>
</div>
<div className="flex flex-col space-y-2">
<div className="flex items-center space-x-2">
<Switch
id="wait-for-user-response"
checked={waitForUserResponse}
onCheckedChange={setWaitForUserResponse}
disabled={!isStatic}
/>
<Label htmlFor="wait-for-user-response" className={!isStatic ? "opacity-50" : ""}>
Wait for user&apos;s response
</Label>
<Label className={`text-xs text-gray-500 ${!isStatic ? "opacity-50" : ""}`}>
{!isStatic
? "Only applicable for static text"
: "Wait for user to respond before disconnecting the call."}
</Label>
</div>
</div>
{!isOSSMode() && (
<div className="flex items-center space-x-2">

View file

@ -20,8 +20,6 @@ export type FlowNodeData = {
extraction_prompt?: string;
extraction_variables?: ExtractionVariable[];
add_global_prompt?: boolean;
wait_for_user_response?: boolean;
wait_for_user_response_timeout?: number;
wait_for_user_greeting?: boolean;
detect_voicemail?: boolean;
delayed_start?: boolean;