mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: simplify pipecat engine execution
This commit is contained in:
parent
5e4aef346d
commit
cc05f363ff
35 changed files with 545 additions and 1861 deletions
|
|
@ -5,15 +5,15 @@ Revises: 982ec8e434be
|
|||
Create Date: 2025-10-21 12:28:06.053318
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a57d25b75117'
|
||||
down_revision: Union[str, None] = '982ec8e434be'
|
||||
revision: str = "a57d25b75117"
|
||||
down_revision: Union[str, None] = "982ec8e434be"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
|
@ -26,12 +26,20 @@ def upgrade() -> None:
|
|||
2. Migrates TWILIO_CONFIGURATION key to TELEPHONY_CONFIGURATION
|
||||
3. Renames twilio_status_callbacks to telephony_status_callbacks in workflow_run logs
|
||||
"""
|
||||
|
||||
|
||||
# Add 'vonage' to the workflow_run_mode enum
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="workflow_run_mode",
|
||||
new_values=["twilio", "stasis", "webrtc", "smallwebrtc", "VOICE", "CHAT", "vonage"],
|
||||
new_values=[
|
||||
"twilio",
|
||||
"stasis",
|
||||
"webrtc",
|
||||
"smallwebrtc",
|
||||
"VOICE",
|
||||
"CHAT",
|
||||
"vonage",
|
||||
],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="workflow_runs", column_name="mode"
|
||||
|
|
@ -39,14 +47,14 @@ def upgrade() -> None:
|
|||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
|
||||
|
||||
# Rename the key from TWILIO_CONFIGURATION to TELEPHONY_CONFIGURATION
|
||||
op.execute("""
|
||||
UPDATE organization_configurations
|
||||
SET key = 'TELEPHONY_CONFIGURATION'
|
||||
WHERE key = 'TWILIO_CONFIGURATION';
|
||||
""")
|
||||
|
||||
|
||||
# Rename twilio_status_callbacks to telephony_status_callbacks in workflow_run logs
|
||||
op.execute("""
|
||||
UPDATE workflow_runs
|
||||
|
|
@ -57,15 +65,17 @@ def upgrade() -> None:
|
|||
)
|
||||
WHERE logs::jsonb ? 'twilio_status_callbacks';
|
||||
""")
|
||||
|
||||
print("Migration complete: Added vonage to enum, renamed configuration key, and updated status callback keys")
|
||||
|
||||
print(
|
||||
"Migration complete: Added vonage to enum, renamed configuration key, and updated status callback keys"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Revert configuration key names and enum.
|
||||
"""
|
||||
|
||||
|
||||
# Revert telephony_status_callbacks to twilio_status_callbacks in workflow_run logs
|
||||
op.execute("""
|
||||
UPDATE workflow_runs
|
||||
|
|
@ -76,14 +86,14 @@ def downgrade() -> None:
|
|||
)
|
||||
WHERE logs::jsonb ? 'telephony_status_callbacks';
|
||||
""")
|
||||
|
||||
|
||||
# Revert key name
|
||||
op.execute("""
|
||||
UPDATE organization_configurations
|
||||
SET key = 'TWILIO_CONFIGURATION'
|
||||
WHERE key = 'TELEPHONY_CONFIGURATION';
|
||||
""")
|
||||
|
||||
|
||||
# Revert enum to previous state
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
|
|
@ -96,5 +106,5 @@ def downgrade() -> None:
|
|||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
|
||||
print("Downgrade complete: Reverted configuration key names and enum")
|
||||
|
||||
print("Downgrade complete: Reverted configuration key names and enum")
|
||||
|
|
|
|||
|
|
@ -63,8 +63,12 @@ class OrganizationConfigurationKey(Enum):
|
|||
DISPOSITION_CODE_MAPPING = "DISPOSITION_CODE_MAPPING"
|
||||
DISPOSITION_MESSAGE_TEMPLATE = "DISPOSITION_MESSAGE_TEMPLATE"
|
||||
CONCURRENT_CALL_LIMIT = "CONCURRENT_CALL_LIMIT"
|
||||
TELEPHONY_CONFIGURATION = "TELEPHONY_CONFIGURATION" # Stores all providers + active one
|
||||
TWILIO_CONFIGURATION = "TWILIO_CONFIGURATION" # Deprecated - for backward compatibility
|
||||
TELEPHONY_CONFIGURATION = (
|
||||
"TELEPHONY_CONFIGURATION" # Stores all providers + active one
|
||||
)
|
||||
TWILIO_CONFIGURATION = (
|
||||
"TWILIO_CONFIGURATION" # Deprecated - for backward compatibility
|
||||
)
|
||||
|
||||
|
||||
class WorkflowStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
langfuse==3.4.0
|
||||
langfuse==3.9.3
|
||||
fastapi==0.116.2
|
||||
asyncpg==0.30.0
|
||||
alembic==1.16.5
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from typing import Union
|
||||
from api.schemas.telephony_config import (
|
||||
TelephonyConfigurationResponse,
|
||||
TwilioConfigurationRequest,
|
||||
|
|
@ -19,14 +20,13 @@ router = APIRouter(prefix="/organizations", tags=["organizations"])
|
|||
# Provider configuration constants
|
||||
PROVIDER_MASKED_FIELDS = {
|
||||
"twilio": ["account_sid", "auth_token"],
|
||||
"vonage": ["private_key", "api_key", "api_secret"]
|
||||
"vonage": ["private_key", "api_key", "api_secret"],
|
||||
}
|
||||
|
||||
|
||||
# TODO: Make endpoints provider-agnostic
|
||||
@router.get("/telephony-config", response_model=TelephonyConfigurationResponse)
|
||||
async def get_telephony_configuration(
|
||||
user: UserModel = Depends(get_user)
|
||||
):
|
||||
async def get_telephony_configuration(user: UserModel = Depends(get_user)):
|
||||
"""Get telephony configuration for the user's organization with masked sensitive fields."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
|
@ -40,11 +40,13 @@ async def get_telephony_configuration(
|
|||
return TelephonyConfigurationResponse()
|
||||
|
||||
stored_provider = config.value.get("provider", "twilio")
|
||||
|
||||
|
||||
if stored_provider == "twilio":
|
||||
account_sid = config.value.get("account_sid", "")
|
||||
auth_token = config.value.get("auth_token", "")
|
||||
from_numbers = config.value.get("from_numbers", []) if account_sid and auth_token else []
|
||||
from_numbers = (
|
||||
config.value.get("from_numbers", []) if account_sid and auth_token else []
|
||||
)
|
||||
|
||||
return TelephonyConfigurationResponse(
|
||||
twilio=TwilioConfigurationResponse(
|
||||
|
|
@ -53,15 +55,19 @@ async def get_telephony_configuration(
|
|||
auth_token=mask_key(auth_token) if auth_token else "",
|
||||
from_numbers=from_numbers,
|
||||
),
|
||||
vonage=None
|
||||
vonage=None,
|
||||
)
|
||||
elif stored_provider == "vonage":
|
||||
application_id = config.value.get("application_id", "")
|
||||
private_key = config.value.get("private_key", "")
|
||||
api_key = config.value.get("api_key", "")
|
||||
api_secret = config.value.get("api_secret", "")
|
||||
from_numbers = config.value.get("from_numbers", []) if application_id and private_key else []
|
||||
|
||||
from_numbers = (
|
||||
config.value.get("from_numbers", [])
|
||||
if application_id and private_key
|
||||
else []
|
||||
)
|
||||
|
||||
return TelephonyConfigurationResponse(
|
||||
twilio=None,
|
||||
vonage=VonageConfigurationResponse(
|
||||
|
|
@ -71,7 +77,7 @@ async def get_telephony_configuration(
|
|||
api_key=mask_key(api_key) if api_key else None,
|
||||
api_secret=mask_key(api_secret) if api_secret else None,
|
||||
from_numbers=from_numbers,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
return TelephonyConfigurationResponse()
|
||||
|
|
@ -79,8 +85,8 @@ async def get_telephony_configuration(
|
|||
|
||||
@router.post("/telephony-config")
|
||||
async def save_telephony_configuration(
|
||||
request: Union[TwilioConfigurationRequest, VonageConfigurationRequest],
|
||||
user: UserModel = Depends(get_user)
|
||||
request: Union[TwilioConfigurationRequest, VonageConfigurationRequest],
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Save telephony configuration for the user's organization."""
|
||||
if not user.selected_organization_id:
|
||||
|
|
@ -105,12 +111,14 @@ async def save_telephony_configuration(
|
|||
"provider": "vonage",
|
||||
"application_id": request.application_id,
|
||||
"private_key": request.private_key,
|
||||
"api_key": getattr(request, 'api_key', None),
|
||||
"api_secret": getattr(request, 'api_secret', None),
|
||||
"api_key": getattr(request, "api_key", None),
|
||||
"api_secret": getattr(request, "api_secret", None),
|
||||
"from_numbers": request.from_numbers,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported provider: {request.provider}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Unsupported provider: {request.provider}"
|
||||
)
|
||||
|
||||
if existing_config and existing_config.value:
|
||||
existing_provider = existing_config.value.get("provider")
|
||||
|
|
@ -126,14 +134,16 @@ async def save_telephony_configuration(
|
|||
|
||||
return {"message": "Telephony configuration saved successfully"}
|
||||
|
||||
def preserve_masked_fields(request, existing_config, config_value):
|
||||
|
||||
def preserve_masked_fields(request, existing_config, config_value):
|
||||
provider = request.provider
|
||||
masked_fields = PROVIDER_MASKED_FIELDS.get(provider, [])
|
||||
|
||||
|
||||
for field_name in masked_fields:
|
||||
if hasattr(request, field_name):
|
||||
field_value = getattr(request, field_name)
|
||||
# Check if field has a value and is a masked version of the existing value
|
||||
if field_value and is_mask_of(field_value, existing_config.value.get(field_name, "")):
|
||||
if field_value and is_mask_of(
|
||||
field_value, existing_config.value.get(field_name, "")
|
||||
):
|
||||
config_value[field_name] = existing_config.value[field_name]
|
||||
|
|
|
|||
|
|
@ -1,19 +1,19 @@
|
|||
"""
|
||||
Generic telephony routes that work with any telephony provider.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Optional
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Header, HTTPException, Request, WebSocket
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
|
|
@ -32,6 +32,7 @@ class InitiateCallRequest(BaseModel):
|
|||
|
||||
class StatusCallbackRequest(BaseModel):
|
||||
"""Generic status callback that can handle different providers"""
|
||||
|
||||
# Common fields
|
||||
call_id: str
|
||||
status: str
|
||||
|
|
@ -39,10 +40,10 @@ class StatusCallbackRequest(BaseModel):
|
|||
to_number: Optional[str] = None
|
||||
direction: Optional[str] = None
|
||||
duration: Optional[str] = None
|
||||
|
||||
|
||||
# Provider-specific fields stored as extra
|
||||
extra: dict = {}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_twilio(cls, data: dict):
|
||||
"""Convert Twilio callback to generic format"""
|
||||
|
|
@ -53,9 +54,9 @@ class StatusCallbackRequest(BaseModel):
|
|||
to_number=data.get("To"),
|
||||
direction=data.get("Direction"),
|
||||
duration=data.get("CallDuration") or data.get("Duration"),
|
||||
extra=data
|
||||
extra=data,
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_vonage(cls, data: dict):
|
||||
"""Convert Vonage event to generic format"""
|
||||
|
|
@ -63,14 +64,14 @@ class StatusCallbackRequest(BaseModel):
|
|||
status_map = {
|
||||
"started": "initiated",
|
||||
"ringing": "ringing",
|
||||
"answered": "answered",
|
||||
"answered": "answered",
|
||||
"complete": "completed",
|
||||
"failed": "failed",
|
||||
"busy": "busy",
|
||||
"timeout": "no-answer",
|
||||
"rejected": "busy"
|
||||
"rejected": "busy",
|
||||
}
|
||||
|
||||
|
||||
return cls(
|
||||
call_id=data.get("uuid", ""),
|
||||
status=status_map.get(data.get("status", ""), data.get("status", "")),
|
||||
|
|
@ -78,7 +79,7 @@ class StatusCallbackRequest(BaseModel):
|
|||
to_number=data.get("to"),
|
||||
direction=data.get("direction"),
|
||||
duration=data.get("duration"),
|
||||
extra=data
|
||||
extra=data,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -87,32 +88,32 @@ async def initiate_call(
|
|||
request: InitiateCallRequest, user: UserModel = Depends(get_user)
|
||||
):
|
||||
"""Initiate a call using the configured telephony provider."""
|
||||
|
||||
|
||||
# Get the telephony provider for the organization
|
||||
provider = await get_telephony_provider(user.selected_organization_id)
|
||||
|
||||
|
||||
# Validate provider is configured
|
||||
if not provider.validate_config():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
|
||||
user_configuration = await db_client.get_user_configurations(user.id)
|
||||
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
|
||||
|
||||
if not phone_number:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Phone number must be provided in request or set in user configuration"
|
||||
status_code=400,
|
||||
detail="Phone number must be provided in request or set in user configuration",
|
||||
)
|
||||
|
||||
|
||||
workflow_run_id = request.workflow_run_id
|
||||
|
||||
|
||||
if not workflow_run_id:
|
||||
workflow_run_name = f"WR-TEL-{random.randint(1000, 9999)}"
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
|
|
@ -130,12 +131,12 @@ async def initiate_call(
|
|||
if not workflow_run:
|
||||
raise HTTPException(status_code=400, detail="Workflow run not found")
|
||||
workflow_run_name = workflow_run.name
|
||||
|
||||
|
||||
# Construct webhook URL based on provider type
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
|
||||
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
||||
|
||||
webhook_url = (
|
||||
f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={request.workflow_id}"
|
||||
|
|
@ -143,35 +144,29 @@ async def initiate_call(
|
|||
f"&workflow_run_id={workflow_run_id}"
|
||||
f"&organization_id={user.selected_organization_id}"
|
||||
)
|
||||
|
||||
|
||||
# Initiate call via provider
|
||||
result = await provider.initiate_call(
|
||||
to_number=phone_number,
|
||||
webhook_url=webhook_url,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
|
||||
# Store provider type and any provider-specific metadata in workflow run context
|
||||
gathered_context = {
|
||||
"provider": provider.PROVIDER_NAME,
|
||||
**(result.provider_metadata or {})
|
||||
**(result.provider_metadata or {}),
|
||||
}
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
gathered_context=gathered_context
|
||||
run_id=workflow_run_id, gathered_context=gathered_context
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Call initiated successfully with run name {workflow_run_name}"
|
||||
}
|
||||
|
||||
return {"message": f"Call initiated successfully with run name {workflow_run_name}"}
|
||||
|
||||
|
||||
@router.post("/twiml", include_in_schema=False)
|
||||
async def handle_twiml_webhook(
|
||||
workflow_id: int,
|
||||
user_id: int,
|
||||
workflow_run_id: int,
|
||||
organization_id: int
|
||||
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
|
||||
):
|
||||
"""
|
||||
Handle initial webhook from telephony provider.
|
||||
|
|
@ -179,32 +174,32 @@ async def handle_twiml_webhook(
|
|||
"""
|
||||
|
||||
provider = await get_telephony_provider(organization_id)
|
||||
|
||||
|
||||
response_content = await provider.get_webhook_response(
|
||||
workflow_id, user_id, workflow_run_id
|
||||
)
|
||||
|
||||
|
||||
return HTMLResponse(content=response_content, media_type="application/xml")
|
||||
|
||||
|
||||
@router.get("/ncco", include_in_schema=False)
|
||||
async def handle_ncco_webhook(
|
||||
workflow_id: int,
|
||||
user_id: int,
|
||||
workflow_id: int,
|
||||
user_id: int,
|
||||
workflow_run_id: int,
|
||||
organization_id: Optional[int] = None
|
||||
organization_id: Optional[int] = None,
|
||||
):
|
||||
"""Handle NCCO (Nexmo Call Control Objects) webhook for Vonage.
|
||||
|
||||
|
||||
Returns JSON response instead of XML like TwiML.
|
||||
"""
|
||||
|
||||
provider = await get_telephony_provider(organization_id or user_id)
|
||||
|
||||
|
||||
response_content = await provider.get_webhook_response(
|
||||
workflow_id, user_id, workflow_run_id
|
||||
)
|
||||
|
||||
|
||||
return json.loads(response_content)
|
||||
|
||||
|
||||
|
|
@ -218,36 +213,38 @@ async def websocket_endpoint(
|
|||
try:
|
||||
# Set the run context
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
|
||||
# Get workflow run to determine provider type
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.error(f"Workflow run {workflow_run_id} not found")
|
||||
await websocket.close(code=4404, reason="Workflow run not found")
|
||||
return
|
||||
|
||||
|
||||
# Get workflow for organization info
|
||||
workflow = await db_client.get_workflow(workflow_id)
|
||||
if not workflow:
|
||||
logger.error(f"Workflow {workflow_id} not found")
|
||||
await websocket.close(code=4404, reason="Workflow not found")
|
||||
return
|
||||
|
||||
|
||||
# Extract provider type from workflow run context
|
||||
provider_type = None
|
||||
if workflow_run.gathered_context:
|
||||
provider_type = workflow_run.gathered_context.get("provider")
|
||||
|
||||
|
||||
if not provider_type:
|
||||
logger.error(f"No provider type found in workflow run {workflow_run_id}")
|
||||
await websocket.close(code=4400, reason="Provider type not found")
|
||||
return
|
||||
|
||||
logger.info(f"WebSocket connected for {provider_type} provider, workflow_run {workflow_run_id}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"WebSocket connected for {provider_type} provider, workflow_run {workflow_run_id}"
|
||||
)
|
||||
|
||||
# Get the telephony provider instance
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
|
||||
|
||||
# Verify the provider matches what was stored
|
||||
if provider.PROVIDER_NAME != provider_type:
|
||||
logger.error(
|
||||
|
|
@ -255,10 +252,12 @@ async def websocket_endpoint(
|
|||
)
|
||||
await websocket.close(code=4400, reason="Provider mismatch")
|
||||
return
|
||||
|
||||
|
||||
# Delegate to provider-specific handler
|
||||
await provider.handle_websocket(websocket, workflow_id, user_id, workflow_run_id)
|
||||
|
||||
await provider.handle_websocket(
|
||||
websocket, workflow_id, user_id, workflow_run_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket connection: {e}")
|
||||
await websocket.close(1011, "Internal server error")
|
||||
|
|
@ -271,44 +270,46 @@ async def handle_twilio_status_callback(
|
|||
x_webhook_signature: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Twilio-specific status callbacks."""
|
||||
|
||||
|
||||
# Parse form data
|
||||
form_data = await request.form()
|
||||
callback_data = dict(form_data)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Received status callback: {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
|
||||
# Get workflow run to find organization
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(f"Workflow run {workflow_run_id} not found for status callback")
|
||||
return {"status": "ignored", "reason": "workflow_run_not_found"}
|
||||
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(f"Workflow {workflow_run.workflow_id} not found")
|
||||
return {"status": "ignored", "reason": "workflow_not_found"}
|
||||
|
||||
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
|
||||
|
||||
if x_webhook_signature:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
full_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
|
||||
|
||||
is_valid = await provider.verify_webhook_signature(
|
||||
full_url, callback_data, x_webhook_signature
|
||||
)
|
||||
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"Invalid webhook signature for workflow run {workflow_run_id}")
|
||||
logger.warning(
|
||||
f"Invalid webhook signature for workflow run {workflow_run_id}"
|
||||
)
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
|
||||
|
||||
# Parse the callback data into generic format
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
|
||||
|
||||
# Create StatusCallbackRequest from parsed data
|
||||
status_update = StatusCallbackRequest(
|
||||
call_id=parsed_data["call_id"],
|
||||
|
|
@ -317,22 +318,20 @@ async def handle_twilio_status_callback(
|
|||
to_number=parsed_data.get("to_number"),
|
||||
direction=parsed_data.get("direction"),
|
||||
duration=parsed_data.get("duration"),
|
||||
extra=parsed_data.get("extra", {})
|
||||
extra=parsed_data.get("extra", {}),
|
||||
)
|
||||
|
||||
|
||||
# Process the status update
|
||||
await _process_status_update(workflow_run_id, status_update, workflow_run)
|
||||
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
async def _process_status_update(
|
||||
workflow_run_id: int,
|
||||
status: StatusCallbackRequest,
|
||||
workflow_run: any
|
||||
workflow_run_id: int, status: StatusCallbackRequest, workflow_run: any
|
||||
):
|
||||
"""Process status updates from telephony providers."""
|
||||
|
||||
|
||||
# Log the status callback
|
||||
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
|
||||
telephony_callback_log = {
|
||||
|
|
@ -340,31 +339,29 @@ async def _process_status_update(
|
|||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"call_id": status.call_id,
|
||||
"duration": status.duration,
|
||||
**status.extra # Include provider-specific data
|
||||
**status.extra, # Include provider-specific data
|
||||
}
|
||||
telephony_callback_logs.append(telephony_callback_log)
|
||||
|
||||
|
||||
# Update workflow run logs
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
logs={"telephony_status_callbacks": telephony_callback_logs},
|
||||
)
|
||||
|
||||
|
||||
# Handle call completion
|
||||
if status.status == "completed":
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Call completed with duration: {status.duration}s"
|
||||
)
|
||||
|
||||
|
||||
# Release concurrent slot if this was a campaign call
|
||||
if workflow_run.campaign_id:
|
||||
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
|
||||
|
||||
|
||||
# Mark workflow run as completed
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, is_completed=True
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, is_completed=True)
|
||||
|
||||
# Publish campaign event if applicable
|
||||
if workflow_run.campaign_id:
|
||||
publisher = await get_campaign_event_publisher()
|
||||
|
|
@ -374,32 +371,40 @@ async def _process_status_update(
|
|||
queued_run_id=workflow_run.queued_run_id,
|
||||
call_duration=int(status.duration) if status.duration else 0,
|
||||
)
|
||||
|
||||
|
||||
elif status.status in ["failed", "busy", "no-answer", "canceled"]:
|
||||
logger.warning(f"[run {workflow_run_id}] Call failed with status: {status.status}")
|
||||
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Call failed with status: {status.status}"
|
||||
)
|
||||
|
||||
# Release concurrent slot for terminal statuses if this was a campaign call
|
||||
if workflow_run.campaign_id:
|
||||
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
|
||||
|
||||
|
||||
# Check if retry is needed for campaign calls (busy/no-answer)
|
||||
if status.status in ["busy", "no-answer"] and workflow_run.campaign_id:
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_retry_needed(
|
||||
workflow_run_id=workflow_run_id,
|
||||
reason=status.status.replace("-", "_"), # Convert no-answer to no_answer
|
||||
reason=status.status.replace(
|
||||
"-", "_"
|
||||
), # Convert no-answer to no_answer
|
||||
campaign_id=workflow_run.campaign_id,
|
||||
queued_run_id=workflow_run.queued_run_id,
|
||||
)
|
||||
|
||||
|
||||
# Mark workflow run as completed with failure tags
|
||||
call_tags = workflow_run.gathered_context.get("call_tags", []) if workflow_run.gathered_context else []
|
||||
call_tags = (
|
||||
workflow_run.gathered_context.get("call_tags", [])
|
||||
if workflow_run.gathered_context
|
||||
else []
|
||||
)
|
||||
call_tags.extend(["not_connected", f"telephony_{status.status.lower()}"])
|
||||
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
is_completed=True,
|
||||
gathered_context={"call_tags": call_tags}
|
||||
gathered_context={"call_tags": call_tags},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -409,20 +414,20 @@ async def handle_vonage_events(
|
|||
workflow_run_id: int,
|
||||
):
|
||||
"""Handle Vonage-specific event webhooks.
|
||||
|
||||
|
||||
Vonage sends all call events to a single endpoint.
|
||||
Events include: started, ringing, answered, complete, failed, etc.
|
||||
"""
|
||||
# Parse the event data
|
||||
event_data = await request.json()
|
||||
logger.info(f"[run {workflow_run_id}] Received Vonage event: {event_data}")
|
||||
|
||||
|
||||
# Get workflow run for processing
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.error(f"[run {workflow_run_id}] Workflow run not found")
|
||||
return {"status": "error", "message": "Workflow run not found"}
|
||||
|
||||
|
||||
# For a completed call that includes cost info, capture it immediately
|
||||
if event_data.get("status") == "completed":
|
||||
# Vonage sometimes includes price info in the webhook
|
||||
|
|
@ -436,27 +441,32 @@ async def handle_vonage_events(
|
|||
if "rate" in event_data:
|
||||
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
|
||||
if "duration" in event_data:
|
||||
cost_info["vonage_webhook_duration"] = int(event_data["duration"])
|
||||
|
||||
cost_info["vonage_webhook_duration"] = int(
|
||||
event_data["duration"]
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
cost_info=cost_info
|
||||
run_id=workflow_run_id, cost_info=cost_info
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
|
||||
)
|
||||
logger.info(f"[run {workflow_run_id}] Captured Vonage cost info from webhook")
|
||||
except Exception as e:
|
||||
logger.error(f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}")
|
||||
|
||||
logger.error(
|
||||
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
|
||||
)
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.error(f"[run {workflow_run_id}] Workflow not found")
|
||||
return {"status": "error", "message": "Workflow not found"}
|
||||
|
||||
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
|
||||
|
||||
# Parse the event data into generic format
|
||||
parsed_data = provider.parse_status_callback(event_data)
|
||||
|
||||
|
||||
# Create StatusCallbackRequest from parsed data
|
||||
status_update = StatusCallbackRequest(
|
||||
call_id=parsed_data["call_id"],
|
||||
|
|
@ -465,11 +475,11 @@ async def handle_vonage_events(
|
|||
to_number=parsed_data.get("to_number"),
|
||||
direction=parsed_data.get("direction"),
|
||||
duration=parsed_data.get("duration"),
|
||||
extra=parsed_data.get("extra", {})
|
||||
extra=parsed_data.get("extra", {}),
|
||||
)
|
||||
|
||||
|
||||
# Process the status update
|
||||
await _process_status_update(workflow_run_id, status_update, workflow_run)
|
||||
|
||||
|
||||
# Return 204 No Content as expected by Vonage
|
||||
return {"status": "ok"}
|
||||
return {"status": "ok"}
|
||||
|
|
|
|||
|
|
@ -124,7 +124,9 @@ class SignalingManager:
|
|||
)
|
||||
else:
|
||||
# Create new connection using correct SmallWebRTC API
|
||||
pc = SmallWebRTCConnection(ice_servers=ice_servers, connection_timeout_secs=60)
|
||||
pc = SmallWebRTCConnection(
|
||||
ice_servers=ice_servers, connection_timeout_secs=60
|
||||
)
|
||||
# Set the pc_id before initialization so it's available in get_answer()
|
||||
pc._pc_id = pc_id
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import QueuedRunModel, WorkflowRunModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.services.telephony.base import TelephonyProvider
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ class CampaignCallDispatcher:
|
|||
f"&campaign_id={campaign.id}"
|
||||
f"&organization_id={campaign.organization_id}"
|
||||
)
|
||||
|
||||
|
||||
call_result = await provider.initiate_call(
|
||||
to_number=phone_number,
|
||||
webhook_url=webhook_url,
|
||||
|
|
@ -255,7 +255,9 @@ class CampaignCallDispatcher:
|
|||
)
|
||||
|
||||
# Update workflow run as failed
|
||||
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
|
||||
telephony_callback_logs = workflow_run.logs.get(
|
||||
"telephony_status_callbacks", []
|
||||
)
|
||||
telephony_callback_log = {
|
||||
"status": "failed",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ from api.services.workflow.dto import ReactFlowDTO
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import (
|
||||
STTMuteConfig,
|
||||
STTMuteFilter,
|
||||
|
|
@ -83,7 +86,8 @@ class LoopTalkPipelineBuilder:
|
|||
audio_buffer, audio_synchronizer, transcript, context = (
|
||||
create_pipeline_components(audio_config)
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Get workflow graph
|
||||
workflow_graph = WorkflowGraph(
|
||||
|
|
@ -113,7 +117,6 @@ class LoopTalkPipelineBuilder:
|
|||
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
|
||||
max_call_duration_seconds=300,
|
||||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -272,14 +272,6 @@ class LoopTalkTestOrchestrator:
|
|||
|
||||
await task.cancel()
|
||||
|
||||
# Connect the context aggregator events to engine
|
||||
@assistant_context_aggregator.event_handler("on_push_aggregation")
|
||||
async def on_assistant_aggregator_push_context(_aggregator):
|
||||
logger.debug(
|
||||
"Assistant aggregator push context – flushing pending transitions"
|
||||
)
|
||||
await engine.flush_pending_transitions()
|
||||
|
||||
# Register custom audio and transcript handlers for LoopTalk
|
||||
await self._register_looptalk_handlers(
|
||||
audio_synchronizer, transcript, test_session_id, role
|
||||
|
|
|
|||
|
|
@ -1,69 +0,0 @@
|
|||
"""Engine Pre-Aggregator Processor
|
||||
|
||||
This processor sits before the user context aggregator in the pipeline and handles
|
||||
engine-specific callbacks for frames that need to be processed before aggregation.
|
||||
This ensures the engine can update context before the aggregator generates LLM frames.
|
||||
"""
|
||||
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.exceptions import VoicemailDetectedException
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class EnginePreAggregatorProcessor(FrameProcessor):
|
||||
"""
|
||||
Processor that handles engine callbacks before user context aggregation.
|
||||
|
||||
This processor is positioned before the user context aggregator to ensure
|
||||
the engine can update LLM context before aggregation occurs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._user_started_speaking_callback = user_started_speaking_callback
|
||||
self._user_stopped_speaking_callback = user_stopped_speaking_callback
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Handle frames that need engine processing before aggregation
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
try:
|
||||
await self._handle_user_stopped_speaking()
|
||||
except VoicemailDetectedException:
|
||||
# We have detected voicemail, lets not
|
||||
# forward the UserStoppedSpeakingFrame, so that
|
||||
# we don't issue an llm call from user context
|
||||
# aggregator
|
||||
logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
|
||||
return
|
||||
|
||||
# Always push the frame downstream
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_user_started_speaking(self):
|
||||
"""Handle UserStartedSpeakingFrame before aggregation."""
|
||||
if self._user_started_speaking_callback:
|
||||
# logger.debug("Engine pre-aggregator: User started speaking")
|
||||
await self._user_started_speaking_callback()
|
||||
|
||||
async def _handle_user_stopped_speaking(self):
|
||||
"""Handle UserStoppedSpeakingFrame before aggregation."""
|
||||
if self._user_stopped_speaking_callback:
|
||||
# logger.debug("Engine pre-aggregator: User stopped speaking")
|
||||
await self._user_stopped_speaking_callback()
|
||||
|
|
@ -9,7 +9,7 @@ from api.constants import (
|
|||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
|
|
@ -39,7 +39,7 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine
|
|||
assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback()
|
||||
)
|
||||
|
||||
context = OpenAILLMContext()
|
||||
context = LLMContext()
|
||||
|
||||
return audio_buffer, audio_synchronizer, transcript, context
|
||||
|
||||
|
|
@ -58,7 +58,6 @@ def build_pipeline(
|
|||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
engine_pre_aggregator_processor=None,
|
||||
):
|
||||
"""Build the main pipeline with all components"""
|
||||
# Register processors with synchronizer for merged audio
|
||||
|
|
@ -69,16 +68,12 @@ def build_pipeline(
|
|||
processors = [
|
||||
transport.input(), # Transport user input
|
||||
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
|
||||
stt_mute_filter,
|
||||
stt, # STT can now have audio_passthrough=False
|
||||
stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted
|
||||
user_idle_disconnect,
|
||||
transcript.user(),
|
||||
]
|
||||
|
||||
# Insert engine pre-aggregator processor if provided (before user aggregator)
|
||||
if engine_pre_aggregator_processor:
|
||||
processors.append(engine_pre_aggregator_processor)
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
user_context_aggregator,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from pipecat.frames.frames import (
|
|||
Frame,
|
||||
HeartbeatFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMGeneratedTextFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TTSSpeakFrame,
|
||||
|
|
@ -26,7 +25,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
self,
|
||||
max_call_duration_seconds: int = 300,
|
||||
max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
):
|
||||
|
|
@ -34,7 +32,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
self._start_time = None
|
||||
self._max_call_duration_seconds = max_call_duration_seconds
|
||||
self._max_duration_end_task_callback = max_duration_end_task_callback
|
||||
self._llm_generated_text_callback = llm_generated_text_callback
|
||||
self._generation_started_callback = generation_started_callback
|
||||
self._llm_text_frame_callback = llm_text_frame_callback
|
||||
self._end_task_frame_pushed = False
|
||||
|
|
@ -46,8 +43,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
await self._start(frame)
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._check_call_duration()
|
||||
elif isinstance(frame, LLMGeneratedTextFrame):
|
||||
await self._generated_text_frame(frame)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._generation_started()
|
||||
elif (
|
||||
|
|
@ -74,11 +69,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
"Max call duration exceeded. Skipping EndTaskFrame since already sent"
|
||||
)
|
||||
|
||||
async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
|
||||
"""Handle LLMGeneratedTextFrame."""
|
||||
if self._llm_generated_text_callback is not None:
|
||||
await self._llm_generated_text_callback()
|
||||
|
||||
async def _generation_started(self):
|
||||
if self._generation_started_callback:
|
||||
await self._generation_started_callback()
|
||||
|
|
|
|||
|
|
@ -7,9 +7,6 @@ from api.db import db_client
|
|||
from api.db.models import WorkflowModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
|
||||
from api.services.pipecat.engine_pre_aggregator_processor import (
|
||||
EnginePreAggregatorProcessor,
|
||||
)
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_task_event_handler,
|
||||
|
|
@ -43,6 +40,9 @@ from api.services.workflow.pipecat_engine import PipecatEngine
|
|||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import (
|
||||
STTMuteConfig,
|
||||
STTMuteFilter,
|
||||
|
|
@ -119,7 +119,7 @@ async def run_pipeline_vonage(
|
|||
user_id: int,
|
||||
):
|
||||
"""Run pipeline for Vonage WebSocket connections.
|
||||
|
||||
|
||||
Vonage uses raw PCM audio over WebSocket instead of base64-encoded μ-law.
|
||||
The audio is transmitted as binary frames at 16kHz by default.
|
||||
"""
|
||||
|
|
@ -137,7 +137,9 @@ async def run_pipeline_vonage(
|
|||
if "vad_configuration" in workflow.workflow_configurations:
|
||||
vad_config = workflow.workflow_configurations["vad_configuration"]
|
||||
if "ambient_noise_configuration" in workflow.workflow_configurations:
|
||||
ambient_noise_config = workflow.workflow_configurations["ambient_noise_configuration"]
|
||||
ambient_noise_config = workflow.workflow_configurations[
|
||||
"ambient_noise_configuration"
|
||||
]
|
||||
|
||||
try:
|
||||
# Setup audio config for Vonage using the centralized config
|
||||
|
|
@ -355,21 +357,14 @@ async def _run_pipeline(
|
|||
expect_stripped_words=True,
|
||||
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
# Create engine pre-aggregator processor for speaking events
|
||||
engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
|
||||
user_started_speaking_callback=engine.create_user_started_speaking_callback(),
|
||||
user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
|
||||
)
|
||||
|
||||
# Create usage metrics aggregator with engine's callback
|
||||
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
|
||||
max_call_duration_seconds=max_call_duration_seconds,
|
||||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
llm_text_frame_callback=engine.handle_llm_text_frame,
|
||||
# Note: speaking event callbacks are now handled by pre-aggregator processor
|
||||
|
|
@ -396,11 +391,6 @@ async def _run_pipeline(
|
|||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
@assistant_context_aggregator.event_handler("on_push_aggregation")
|
||||
async def on_assistant_aggregator_push_context(_aggregator):
|
||||
logger.debug("Assistant aggregator push context – flushing pending transitions")
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
|
|
@ -416,7 +406,6 @@ async def _run_pipeline(
|
|||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
engine_pre_aggregator_processor=engine_pre_aggregator_processor,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
|
|
|
|||
|
|
@ -165,14 +165,15 @@ async def create_vonage_transport(
|
|||
|
||||
# Use the factory to load config from database
|
||||
from api.services.telephony.factory import load_telephony_config
|
||||
|
||||
config = await load_telephony_config(organization_id)
|
||||
|
||||
|
||||
if config.get("provider") != "vonage":
|
||||
raise ValueError(f"Expected Vonage provider, got {config.get('provider')}")
|
||||
|
||||
application_id = config.get("application_id")
|
||||
private_key = config.get("private_key")
|
||||
|
||||
|
||||
if not application_id or not private_key:
|
||||
raise ValueError(
|
||||
f"Incomplete Vonage configuration for organization {organization_id}"
|
||||
|
|
@ -186,8 +187,8 @@ async def create_vonage_transport(
|
|||
private_key=private_key,
|
||||
params=VonageFrameSerializer.InputParams(
|
||||
vonage_sample_rate=audio_config.transport_in_sample_rate,
|
||||
sample_rate=audio_config.pipeline_sample_rate
|
||||
)
|
||||
sample_rate=audio_config.pipeline_sample_rate,
|
||||
),
|
||||
)
|
||||
|
||||
# Important: Vonage uses binary WebSocket mode, not text
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ Base telephony provider interface for abstracting telephony services.
|
|||
This allows easy switching between different providers (Twilio, Vonage, etc.)
|
||||
while keeping business logic decoupled from specific implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
|
@ -14,10 +15,15 @@ if TYPE_CHECKING:
|
|||
@dataclass
|
||||
class CallInitiationResult:
|
||||
"""Standardized response from initiate_call across all providers."""
|
||||
call_id: str # Provider's call identifier (SID for Twilio, UUID for Vonage)
|
||||
status: str # Initial status (e.g., "queued", "initiated", "started")
|
||||
provider_metadata: Dict[str, Any] = field(default_factory=dict) # Data that needs to be persisted
|
||||
raw_response: Dict[str, Any] = field(default_factory=dict) # Full provider response for debugging
|
||||
|
||||
call_id: str # Provider's call identifier (SID for Twilio, UUID for Vonage)
|
||||
status: str # Initial status (e.g., "queued", "initiated", "started")
|
||||
provider_metadata: Dict[str, Any] = field(
|
||||
default_factory=dict
|
||||
) # Data that needs to be persisted
|
||||
raw_response: Dict[str, Any] = field(
|
||||
default_factory=dict
|
||||
) # Full provider response for debugging
|
||||
|
||||
|
||||
class TelephonyProvider(ABC):
|
||||
|
|
@ -25,6 +31,7 @@ class TelephonyProvider(ABC):
|
|||
Abstract base class for telephony providers.
|
||||
All telephony providers must implement these core methods.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = None
|
||||
WEBHOOK_ENDPOINT = None
|
||||
|
||||
|
|
@ -38,13 +45,13 @@ class TelephonyProvider(ABC):
|
|||
) -> CallInitiationResult:
|
||||
"""
|
||||
Initiate an outbound call.
|
||||
|
||||
|
||||
Args:
|
||||
to_number: The destination phone number
|
||||
webhook_url: The URL to receive call events
|
||||
workflow_run_id: Optional workflow run ID for tracking
|
||||
**kwargs: Provider-specific additional parameters
|
||||
|
||||
|
||||
Returns:
|
||||
CallInitiationResult with standardized call details
|
||||
"""
|
||||
|
|
@ -54,10 +61,10 @@ class TelephonyProvider(ABC):
|
|||
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current status of a call.
|
||||
|
||||
|
||||
Args:
|
||||
call_id: The provider-specific call identifier
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing call status information
|
||||
"""
|
||||
|
|
@ -67,7 +74,7 @@ class TelephonyProvider(ABC):
|
|||
async def get_available_phone_numbers(self) -> List[str]:
|
||||
"""
|
||||
Get list of available phone numbers for this provider.
|
||||
|
||||
|
||||
Returns:
|
||||
List of phone numbers that can be used for outbound calls
|
||||
"""
|
||||
|
|
@ -77,7 +84,7 @@ class TelephonyProvider(ABC):
|
|||
def validate_config(self) -> bool:
|
||||
"""
|
||||
Validate that the provider is properly configured.
|
||||
|
||||
|
||||
Returns:
|
||||
True if configuration is valid, False otherwise
|
||||
"""
|
||||
|
|
@ -89,12 +96,12 @@ class TelephonyProvider(ABC):
|
|||
) -> bool:
|
||||
"""
|
||||
Verify webhook signature for security.
|
||||
|
||||
|
||||
Args:
|
||||
url: The webhook URL
|
||||
params: The webhook parameters
|
||||
signature: The signature to verify
|
||||
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise
|
||||
"""
|
||||
|
|
@ -106,12 +113,12 @@ class TelephonyProvider(ABC):
|
|||
) -> str:
|
||||
"""
|
||||
Generate the initial webhook response for starting a call session.
|
||||
|
||||
|
||||
Args:
|
||||
workflow_id: The workflow ID
|
||||
user_id: The user ID
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
|
||||
Returns:
|
||||
Provider-specific response (e.g., TwiML for Twilio)
|
||||
"""
|
||||
|
|
@ -121,10 +128,10 @@ class TelephonyProvider(ABC):
|
|||
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cost information for a completed call.
|
||||
|
||||
|
||||
Args:
|
||||
call_id: Provider-specific call identifier (SID for Twilio, UUID for Vonage)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
- cost_usd: The cost in USD as float
|
||||
|
|
@ -138,10 +145,10 @@ class TelephonyProvider(ABC):
|
|||
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse provider-specific status callback data into generic format.
|
||||
|
||||
|
||||
Args:
|
||||
data: Raw callback data from the provider
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with standardized fields:
|
||||
- call_id: Provider's call identifier
|
||||
|
|
@ -163,14 +170,14 @@ class TelephonyProvider(ABC):
|
|||
) -> None:
|
||||
"""
|
||||
Handle provider-specific WebSocket connection for real-time call audio.
|
||||
|
||||
|
||||
This method encapsulates all provider-specific WebSocket handshake and
|
||||
message routing logic, keeping the main websocket endpoint clean.
|
||||
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
workflow_id: The workflow ID
|
||||
user_id: The user ID
|
||||
workflow_run_id: The workflow run ID
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ Factory for creating telephony providers.
|
|||
Handles configuration loading from environment (OSS) or database (SaaS).
|
||||
The providers themselves don't know or care where config comes from.
|
||||
"""
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -18,36 +18,36 @@ from api.services.telephony.providers.vonage_provider import VonageProvider
|
|||
async def load_telephony_config(organization_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Load telephony configuration from database.
|
||||
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID for database config
|
||||
|
||||
|
||||
Returns:
|
||||
Configuration dictionary with provider type and credentials
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If no configuration found for the organization
|
||||
"""
|
||||
if not organization_id:
|
||||
raise ValueError("Organization ID is required to load telephony configuration")
|
||||
|
||||
|
||||
logger.debug(f"Loading telephony config from database for org {organization_id}")
|
||||
|
||||
|
||||
config = await db_client.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.TELEPHONY_CONFIGURATION.value,
|
||||
)
|
||||
|
||||
|
||||
if config and config.value:
|
||||
# Simple single-provider format
|
||||
provider = config.value.get("provider", "twilio")
|
||||
|
||||
|
||||
if provider == "twilio":
|
||||
return {
|
||||
"provider": "twilio",
|
||||
"account_sid": config.value.get("account_sid"),
|
||||
"auth_token": config.value.get("auth_token"),
|
||||
"from_numbers": config.value.get("from_numbers", [])
|
||||
"from_numbers": config.value.get("from_numbers", []),
|
||||
}
|
||||
elif provider == "vonage":
|
||||
return {
|
||||
|
|
@ -56,41 +56,41 @@ async def load_telephony_config(organization_id: int) -> Dict[str, Any]:
|
|||
"private_key": config.value.get("private_key"),
|
||||
"api_key": config.value.get("api_key"),
|
||||
"api_secret": config.value.get("api_secret"),
|
||||
"from_numbers": config.value.get("from_numbers", [])
|
||||
"from_numbers": config.value.get("from_numbers", []),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown provider in config: {provider}")
|
||||
|
||||
raise ValueError(f"No telephony configuration found for organization {organization_id}")
|
||||
|
||||
raise ValueError(
|
||||
f"No telephony configuration found for organization {organization_id}"
|
||||
)
|
||||
|
||||
|
||||
async def get_telephony_provider(
|
||||
organization_id: int
|
||||
) -> TelephonyProvider:
|
||||
async def get_telephony_provider(organization_id: int) -> TelephonyProvider:
|
||||
"""
|
||||
Factory function to create telephony providers.
|
||||
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID (required)
|
||||
|
||||
|
||||
Returns:
|
||||
Configured telephony provider instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If provider type is unknown or configuration is invalid
|
||||
"""
|
||||
# Load configuration
|
||||
config = await load_telephony_config(organization_id)
|
||||
|
||||
|
||||
provider_type = config.get("provider", "twilio")
|
||||
logger.info(f"Creating {provider_type} telephony provider")
|
||||
|
||||
|
||||
# Create provider instance with configuration
|
||||
if provider_type == "twilio":
|
||||
return TwilioProvider(config)
|
||||
|
||||
|
||||
elif provider_type == "vonage":
|
||||
return VonageProvider(config)
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown telephony provider: {provider_type}")
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# Telephony provider implementations
|
||||
# Telephony provider implementations
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Twilio implementation of the TelephonyProvider interface.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
|
@ -9,9 +10,9 @@ import aiohttp
|
|||
from loguru import logger
|
||||
from twilio.request_validator import RequestValidator
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.telephony.base import CallInitiationResult, TelephonyProvider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.enums import WorkflowRunMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
@ -22,14 +23,14 @@ class TwilioProvider(TelephonyProvider):
|
|||
Twilio implementation of TelephonyProvider.
|
||||
Accepts configuration and works the same regardless of OSS/SaaS mode.
|
||||
"""
|
||||
|
||||
|
||||
PROVIDER_NAME = WorkflowRunMode.TWILIO.value
|
||||
WEBHOOK_ENDPOINT = "twiml"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialize TwilioProvider with configuration.
|
||||
|
||||
|
||||
Args:
|
||||
config: Dictionary containing:
|
||||
- account_sid: Twilio Account SID
|
||||
|
|
@ -39,11 +40,11 @@ class TwilioProvider(TelephonyProvider):
|
|||
self.account_sid = config.get("account_sid")
|
||||
self.auth_token = config.get("auth_token")
|
||||
self.from_numbers = config.get("from_numbers", [])
|
||||
|
||||
|
||||
# Handle both single number (string) and multiple numbers (list)
|
||||
if isinstance(self.from_numbers, str):
|
||||
self.from_numbers = [self.from_numbers]
|
||||
|
||||
|
||||
self.base_url = f"https://api.twilio.com/2010-04-01/Accounts/{self.account_sid}"
|
||||
|
||||
async def initiate_call(
|
||||
|
|
@ -58,32 +59,35 @@ class TwilioProvider(TelephonyProvider):
|
|||
"""
|
||||
if not self.validate_config():
|
||||
raise ValueError("Twilio provider not properly configured")
|
||||
|
||||
|
||||
endpoint = f"{self.base_url}/Calls.json"
|
||||
|
||||
|
||||
# Select a random phone number
|
||||
from_number = random.choice(self.from_numbers)
|
||||
logger.info(f"Selected phone number {from_number} for outbound call")
|
||||
|
||||
|
||||
# Prepare call data
|
||||
data = {
|
||||
"To": to_number,
|
||||
"From": from_number,
|
||||
"Url": webhook_url
|
||||
}
|
||||
|
||||
data = {"To": to_number, "From": from_number, "Url": webhook_url}
|
||||
|
||||
# Add status callback if workflow_run_id provided
|
||||
if workflow_run_id:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
data.update({
|
||||
"StatusCallback": callback_url,
|
||||
"StatusCallbackEvent": ["initiated", "ringing", "answered", "completed"],
|
||||
"StatusCallbackMethod": "POST"
|
||||
})
|
||||
|
||||
data.update(
|
||||
{
|
||||
"StatusCallback": callback_url,
|
||||
"StatusCallbackEvent": [
|
||||
"initiated",
|
||||
"ringing",
|
||||
"answered",
|
||||
"completed",
|
||||
],
|
||||
"StatusCallbackMethod": "POST",
|
||||
}
|
||||
)
|
||||
|
||||
data.update(kwargs)
|
||||
|
||||
|
||||
# Make the API request
|
||||
async with aiohttp.ClientSession() as session:
|
||||
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
|
||||
|
|
@ -91,14 +95,14 @@ class TwilioProvider(TelephonyProvider):
|
|||
if response.status != 201:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"Failed to initiate call: {error_data}")
|
||||
|
||||
|
||||
response_data = await response.json()
|
||||
|
||||
|
||||
return CallInitiationResult(
|
||||
call_id=response_data["sid"],
|
||||
status=response_data.get("status", "queued"),
|
||||
provider_metadata={}, # Twilio doesn't need to persist extra data
|
||||
raw_response=response_data
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
||||
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
|
||||
|
|
@ -107,16 +111,16 @@ class TwilioProvider(TelephonyProvider):
|
|||
"""
|
||||
if not self.validate_config():
|
||||
raise ValueError("Twilio provider not properly configured")
|
||||
|
||||
|
||||
endpoint = f"{self.base_url}/Calls/{call_id}.json"
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
|
||||
async with session.get(endpoint, auth=auth) as response:
|
||||
if response.status != 200:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"Failed to get call status: {error_data}")
|
||||
|
||||
|
||||
return await response.json()
|
||||
|
||||
async def get_available_phone_numbers(self) -> List[str]:
|
||||
|
|
@ -129,11 +133,7 @@ class TwilioProvider(TelephonyProvider):
|
|||
"""
|
||||
Validate Twilio configuration.
|
||||
"""
|
||||
return bool(
|
||||
self.account_sid and
|
||||
self.auth_token and
|
||||
self.from_numbers
|
||||
)
|
||||
return bool(self.account_sid and self.auth_token and self.from_numbers)
|
||||
|
||||
async def verify_webhook_signature(
|
||||
self, url: str, params: Dict[str, Any], signature: str
|
||||
|
|
@ -144,7 +144,7 @@ class TwilioProvider(TelephonyProvider):
|
|||
if not self.auth_token:
|
||||
logger.error("No auth token available for webhook signature verification")
|
||||
return False
|
||||
|
||||
|
||||
validator = RequestValidator(self.auth_token)
|
||||
return validator.validate(url, params, signature)
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ class TwilioProvider(TelephonyProvider):
|
|||
Generate TwiML response for starting a call session.
|
||||
"""
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
|
||||
|
||||
twiml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Connect>
|
||||
|
|
@ -168,15 +168,15 @@ class TwilioProvider(TelephonyProvider):
|
|||
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cost information for a completed Twilio call.
|
||||
|
||||
|
||||
Args:
|
||||
call_id: The Twilio Call SID
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing cost information
|
||||
"""
|
||||
endpoint = f"{self.base_url}/Calls/{call_id}.json"
|
||||
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
auth = aiohttp.BasicAuth(self.account_sid, self.auth_token)
|
||||
|
|
@ -188,34 +188,29 @@ class TwilioProvider(TelephonyProvider):
|
|||
"cost_usd": 0.0,
|
||||
"duration": 0,
|
||||
"status": "error",
|
||||
"error": str(error_data)
|
||||
"error": str(error_data),
|
||||
}
|
||||
|
||||
|
||||
call_data = await response.json()
|
||||
|
||||
|
||||
# Twilio returns price as a negative string (e.g., "-0.0085")
|
||||
price_str = call_data.get("price", "0")
|
||||
cost_usd = abs(float(price_str)) if price_str else 0.0
|
||||
|
||||
|
||||
# Duration is in seconds as a string
|
||||
duration = int(call_data.get("duration", "0"))
|
||||
|
||||
|
||||
return {
|
||||
"cost_usd": cost_usd,
|
||||
"duration": duration,
|
||||
"status": call_data.get("status", "unknown"),
|
||||
"price_unit": call_data.get("price_unit", "USD"),
|
||||
"raw_response": call_data
|
||||
"raw_response": call_data,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception fetching Twilio call cost: {e}")
|
||||
return {
|
||||
"cost_usd": 0.0,
|
||||
"duration": 0,
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
return {"cost_usd": 0.0, "duration": 0, "status": "error", "error": str(e)}
|
||||
|
||||
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -228,7 +223,7 @@ class TwilioProvider(TelephonyProvider):
|
|||
"to_number": data.get("To"),
|
||||
"direction": data.get("Direction"),
|
||||
"duration": data.get("CallDuration") or data.get("Duration"),
|
||||
"extra": data # Include all original data
|
||||
"extra": data, # Include all original data
|
||||
}
|
||||
|
||||
async def handle_websocket(
|
||||
|
|
@ -240,36 +235,38 @@ class TwilioProvider(TelephonyProvider):
|
|||
) -> None:
|
||||
"""
|
||||
Handle Twilio-specific WebSocket connection.
|
||||
|
||||
|
||||
Twilio sends:
|
||||
1. "connected" event first
|
||||
2. "start" event with streamSid and callSid
|
||||
3. Then audio messages
|
||||
"""
|
||||
from api.services.pipecat.run_pipeline import run_pipeline_twilio
|
||||
|
||||
|
||||
try:
|
||||
# Wait for "connected" event
|
||||
first_msg = await websocket.receive_text()
|
||||
msg = json.loads(first_msg)
|
||||
|
||||
|
||||
if msg.get("event") != "connected":
|
||||
logger.error(f"Expected 'connected' event, got: {msg.get('event')}")
|
||||
await websocket.close(code=4400, reason="Expected connected event")
|
||||
return
|
||||
|
||||
logger.debug(f"Twilio WebSocket connected for workflow_run {workflow_run_id}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Twilio WebSocket connected for workflow_run {workflow_run_id}"
|
||||
)
|
||||
|
||||
# Wait for "start" event with stream details
|
||||
start_msg = await websocket.receive_text()
|
||||
logger.debug(f"Received start message: {start_msg}")
|
||||
|
||||
|
||||
start_msg = json.loads(start_msg)
|
||||
if start_msg.get("event") != "start":
|
||||
logger.error("Expected 'start' event second")
|
||||
await websocket.close(code=4400, reason="Expected start event")
|
||||
return
|
||||
|
||||
|
||||
# Extract Twilio-specific identifiers
|
||||
try:
|
||||
stream_sid = start_msg["start"]["streamSid"]
|
||||
|
|
@ -278,12 +275,12 @@ class TwilioProvider(TelephonyProvider):
|
|||
logger.error("Missing streamSid or callSid in start message")
|
||||
await websocket.close(code=4400, reason="Missing stream identifiers")
|
||||
return
|
||||
|
||||
|
||||
# Run the Twilio pipeline
|
||||
await run_pipeline_twilio(
|
||||
websocket, stream_sid, call_sid, workflow_id, workflow_run_id, user_id
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Twilio WebSocket handler: {e}")
|
||||
raise
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Vonage (Nexmo) implementation of the TelephonyProvider interface.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
|
|
@ -10,9 +11,9 @@ import aiohttp
|
|||
import jwt
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.telephony.base import CallInitiationResult, TelephonyProvider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.enums import WorkflowRunMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
@ -23,14 +24,14 @@ class VonageProvider(TelephonyProvider):
|
|||
Vonage implementation of TelephonyProvider.
|
||||
Uses JWT authentication and NCCO for call control.
|
||||
"""
|
||||
|
||||
|
||||
PROVIDER_NAME = WorkflowRunMode.VONAGE.value
|
||||
WEBHOOK_ENDPOINT = "ncco"
|
||||
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialize VonageProvider with configuration.
|
||||
|
||||
|
||||
Args:
|
||||
config: Dictionary containing:
|
||||
- api_key: Vonage API Key
|
||||
|
|
@ -44,25 +45,27 @@ class VonageProvider(TelephonyProvider):
|
|||
self.application_id = config.get("application_id")
|
||||
self.private_key = config.get("private_key")
|
||||
self.from_numbers = config.get("from_numbers", [])
|
||||
|
||||
|
||||
# Handle both single number (string) and multiple numbers (list)
|
||||
if isinstance(self.from_numbers, str):
|
||||
self.from_numbers = [self.from_numbers]
|
||||
|
||||
|
||||
self.base_url = "https://api.nexmo.com"
|
||||
|
||||
def _generate_jwt(self) -> str:
|
||||
"""Generate JWT token for Vonage API authentication."""
|
||||
if not self.application_id or not self.private_key:
|
||||
raise ValueError("Application ID and private key required for JWT generation")
|
||||
|
||||
raise ValueError(
|
||||
"Application ID and private key required for JWT generation"
|
||||
)
|
||||
|
||||
claims = {
|
||||
"application_id": self.application_id,
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 3600,
|
||||
"jti": str(time.time())
|
||||
"jti": str(time.time()),
|
||||
}
|
||||
|
||||
|
||||
return jwt.encode(claims, self.private_key, algorithm="RS256")
|
||||
|
||||
async def initiate_call(
|
||||
|
|
@ -77,68 +80,57 @@ class VonageProvider(TelephonyProvider):
|
|||
"""
|
||||
if not self.validate_config():
|
||||
raise ValueError("Vonage provider not properly configured")
|
||||
|
||||
|
||||
endpoint = f"{self.base_url}/v1/calls"
|
||||
|
||||
|
||||
# Select a random phone number
|
||||
from_number = random.choice(self.from_numbers)
|
||||
# Remove '+' prefix for Vonage
|
||||
from_number = from_number.replace("+", "")
|
||||
to_number = to_number.replace("+", "")
|
||||
|
||||
|
||||
logger.info(f"Selected phone number {from_number} for outbound call")
|
||||
|
||||
|
||||
# Prepare call data
|
||||
data = {
|
||||
"to": [{
|
||||
"type": "phone",
|
||||
"number": to_number
|
||||
}],
|
||||
"from": {
|
||||
"type": "phone",
|
||||
"number": from_number
|
||||
},
|
||||
"to": [{"type": "phone", "number": to_number}],
|
||||
"from": {"type": "phone", "number": from_number},
|
||||
"answer_url": [webhook_url],
|
||||
"answer_method": "GET"
|
||||
"answer_method": "GET",
|
||||
}
|
||||
|
||||
|
||||
# Add event webhook if workflow_run_id provided
|
||||
if workflow_run_id:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
event_url = f"https://{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}"
|
||||
data.update({
|
||||
"event_url": [event_url],
|
||||
"event_method": "POST"
|
||||
})
|
||||
|
||||
data.update({"event_url": [event_url], "event_method": "POST"})
|
||||
|
||||
data.update(kwargs)
|
||||
|
||||
|
||||
# Generate JWT token
|
||||
token = self._generate_jwt()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
# Make the API request
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
endpoint,
|
||||
json=data,
|
||||
headers=headers
|
||||
) as response:
|
||||
async with session.post(endpoint, json=data, headers=headers) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
|
||||
if response.status != 201:
|
||||
raise Exception(f"Failed to initiate call: {response_data}")
|
||||
|
||||
|
||||
return CallInitiationResult(
|
||||
call_id=response_data["uuid"],
|
||||
status=response_data.get("status", "started"),
|
||||
provider_metadata={
|
||||
"call_uuid": response_data["uuid"] # Vonage needs UUID persisted for WebSocket
|
||||
"call_uuid": response_data[
|
||||
"uuid"
|
||||
] # Vonage needs UUID persisted for WebSocket
|
||||
},
|
||||
raw_response=response_data
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
||||
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
|
||||
|
|
@ -147,21 +139,19 @@ class VonageProvider(TelephonyProvider):
|
|||
"""
|
||||
if not self.validate_config():
|
||||
raise ValueError("Vonage provider not properly configured")
|
||||
|
||||
|
||||
endpoint = f"{self.base_url}/v1/calls/{call_id}"
|
||||
|
||||
|
||||
# Generate JWT token
|
||||
token = self._generate_jwt()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(endpoint, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"Failed to get call status: {error_data}")
|
||||
|
||||
|
||||
return await response.json()
|
||||
|
||||
async def get_available_phone_numbers(self) -> List[str]:
|
||||
|
|
@ -174,11 +164,7 @@ class VonageProvider(TelephonyProvider):
|
|||
"""
|
||||
Validate Vonage configuration.
|
||||
"""
|
||||
return bool(
|
||||
self.application_id and
|
||||
self.private_key and
|
||||
self.from_numbers
|
||||
)
|
||||
return bool(self.application_id and self.private_key and self.from_numbers)
|
||||
|
||||
async def verify_webhook_signature(
|
||||
self, url: str, params: Dict[str, Any], signature: str
|
||||
|
|
@ -190,14 +176,14 @@ class VonageProvider(TelephonyProvider):
|
|||
if not self.api_secret:
|
||||
logger.error("No API secret available for webhook signature verification")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Vonage sends JWT in Authorization header. Verify the JWT signature
|
||||
decoded = jwt.decode(
|
||||
signature,
|
||||
self.api_secret,
|
||||
signature,
|
||||
self.api_secret,
|
||||
algorithms=["HS256"],
|
||||
options={"verify_signature": True}
|
||||
options={"verify_signature": True},
|
||||
)
|
||||
return True
|
||||
except jwt.InvalidTokenError:
|
||||
|
|
@ -211,43 +197,42 @@ class VonageProvider(TelephonyProvider):
|
|||
NCCO (Nexmo Call Control Objects) is JSON-based, unlike TwiML which is XML.
|
||||
"""
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
|
||||
|
||||
# NCCO for WebSocket connection
|
||||
ncco = [
|
||||
{
|
||||
"action": "connect",
|
||||
"endpoint": [{
|
||||
"type": "websocket",
|
||||
"uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
|
||||
"content-type": "audio/l16;rate=16000", # 16kHz Linear PCM
|
||||
"headers": {}
|
||||
}]
|
||||
"endpoint": [
|
||||
{
|
||||
"type": "websocket",
|
||||
"uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
|
||||
"content-type": "audio/l16;rate=16000", # 16kHz Linear PCM
|
||||
"headers": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
return json.dumps(ncco)
|
||||
|
||||
def _get_auth_headers(self) -> Dict[str, str]:
|
||||
"""Generate authorization headers for Vonage API."""
|
||||
token = self._generate_jwt()
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cost information for a completed Vonage call.
|
||||
|
||||
|
||||
Args:
|
||||
call_id: The Vonage Call UUID
|
||||
|
||||
|
||||
Returns:
|
||||
Dict containing cost information
|
||||
"""
|
||||
headers = self._get_auth_headers()
|
||||
endpoint = f"https://api.nexmo.com/v1/calls/{call_id}"
|
||||
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(endpoint, headers=headers) as response:
|
||||
|
|
@ -258,39 +243,34 @@ class VonageProvider(TelephonyProvider):
|
|||
"cost_usd": 0.0,
|
||||
"duration": 0,
|
||||
"status": "error",
|
||||
"error": str(error_data)
|
||||
"error": str(error_data),
|
||||
}
|
||||
|
||||
|
||||
call_data = await response.json()
|
||||
|
||||
|
||||
# Vonage returns price and rate
|
||||
# Price is the total cost, rate is the per-minute rate
|
||||
price = float(call_data.get("price", 0))
|
||||
cost_usd = price # Vonage returns positive values
|
||||
|
||||
|
||||
# Duration is in seconds
|
||||
duration = int(call_data.get("duration", 0))
|
||||
|
||||
|
||||
# Get the call status
|
||||
status = call_data.get("status", "unknown")
|
||||
|
||||
|
||||
return {
|
||||
"cost_usd": cost_usd,
|
||||
"duration": duration,
|
||||
"status": status,
|
||||
"price_unit": "USD", # Vonage uses USD by default
|
||||
"rate": call_data.get("rate", 0), # Per-minute rate
|
||||
"raw_response": call_data
|
||||
"raw_response": call_data,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception fetching Vonage call cost: {e}")
|
||||
return {
|
||||
"cost_usd": 0.0,
|
||||
"duration": 0,
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
return {"cost_usd": 0.0, "duration": 0, "status": "error", "error": str(e)}
|
||||
|
||||
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -300,14 +280,14 @@ class VonageProvider(TelephonyProvider):
|
|||
status_map = {
|
||||
"started": "initiated",
|
||||
"ringing": "ringing",
|
||||
"answered": "answered",
|
||||
"answered": "answered",
|
||||
"complete": "completed",
|
||||
"failed": "failed",
|
||||
"busy": "busy",
|
||||
"timeout": "no-answer",
|
||||
"rejected": "busy"
|
||||
"rejected": "busy",
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"call_id": data.get("uuid", ""),
|
||||
"status": status_map.get(data.get("status", ""), data.get("status", "")),
|
||||
|
|
@ -315,7 +295,7 @@ class VonageProvider(TelephonyProvider):
|
|||
"to_number": data.get("to"),
|
||||
"direction": data.get("direction"),
|
||||
"duration": data.get("duration"),
|
||||
"extra": data # Include all original data
|
||||
"extra": data, # Include all original data
|
||||
}
|
||||
|
||||
async def handle_websocket(
|
||||
|
|
@ -327,14 +307,14 @@ class VonageProvider(TelephonyProvider):
|
|||
) -> None:
|
||||
"""
|
||||
Handle Vonage-specific WebSocket connection.
|
||||
|
||||
|
||||
Vonage can send:
|
||||
1. JSON metadata first (websocket:connected event)
|
||||
2. Or directly start with binary audio
|
||||
"""
|
||||
from api.db import db_client
|
||||
from api.services.pipecat.run_pipeline import run_pipeline_vonage
|
||||
|
||||
|
||||
try:
|
||||
# Get workflow run to extract call UUID
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id)
|
||||
|
|
@ -342,38 +322,48 @@ class VonageProvider(TelephonyProvider):
|
|||
logger.error(f"Workflow run {workflow_run_id} not found")
|
||||
await websocket.close(code=4404, reason="Workflow run not found")
|
||||
return
|
||||
|
||||
|
||||
# Get workflow for organization info
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if not workflow:
|
||||
logger.error(f"Workflow {workflow_id} not found")
|
||||
await websocket.close(code=4404, reason="Workflow not found")
|
||||
return
|
||||
|
||||
|
||||
# Extract call UUID from workflow run context
|
||||
call_uuid = workflow_run.gathered_context.get("call_uuid") if workflow_run.gathered_context else None
|
||||
|
||||
call_uuid = (
|
||||
workflow_run.gathered_context.get("call_uuid")
|
||||
if workflow_run.gathered_context
|
||||
else None
|
||||
)
|
||||
|
||||
if not call_uuid:
|
||||
logger.error(f"No call UUID found for Vonage connection in workflow run {workflow_run_id}")
|
||||
logger.error(
|
||||
f"No call UUID found for Vonage connection in workflow run {workflow_run_id}"
|
||||
)
|
||||
await websocket.close(code=4400, reason="Missing call UUID")
|
||||
return
|
||||
|
||||
logger.info(f"Vonage WebSocket connected for workflow_run {workflow_run_id}, call_uuid: {call_uuid}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Vonage WebSocket connected for workflow_run {workflow_run_id}, call_uuid: {call_uuid}"
|
||||
)
|
||||
|
||||
# Peek at first message to see if it's metadata or audio
|
||||
first_msg = await websocket.receive()
|
||||
|
||||
|
||||
if "text" in first_msg:
|
||||
# JSON metadata - check if it's the connection event
|
||||
msg = json.loads(first_msg["text"])
|
||||
if msg.get("event") == "websocket:connected":
|
||||
logger.debug(f"Received Vonage connection confirmation for {workflow_run_id}")
|
||||
logger.debug(
|
||||
f"Received Vonage connection confirmation for {workflow_run_id}"
|
||||
)
|
||||
# Continue to pipeline regardless of message type
|
||||
elif "bytes" in first_msg:
|
||||
# Binary audio - Vonage started with audio immediately
|
||||
logger.debug(f"Vonage started with binary audio for {workflow_run_id}")
|
||||
# The pipeline will handle this first audio chunk
|
||||
|
||||
|
||||
# Run the Vonage pipeline
|
||||
await run_pipeline_vonage(
|
||||
websocket,
|
||||
|
|
@ -382,9 +372,9 @@ class VonageProvider(TelephonyProvider):
|
|||
workflow.organization_id,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user_id
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Vonage WebSocket handler: {e}")
|
||||
raise
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -22,9 +22,7 @@ from pipecat.frames.frames import (
|
|||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import (
|
||||
BaseOutputTransport
|
||||
)
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ class PipecatEngine:
|
|||
*,
|
||||
task: Optional[PipelineTask] = None,
|
||||
llm: Optional["LLMService"] = None,
|
||||
context: Optional[OpenAILLMContext] = None,
|
||||
context: Optional[LLMContext] = None,
|
||||
tts: Optional[Any] = None,
|
||||
transport: Optional[BaseTransport] = None,
|
||||
workflow: WorkflowGraph,
|
||||
|
|
@ -82,7 +82,6 @@ class PipecatEngine:
|
|||
self._workflow_run_id = workflow_run_id
|
||||
self._initialized = False
|
||||
self._client_disconnected = False
|
||||
self._pending_function_calls = 0
|
||||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
|
|
@ -102,29 +101,9 @@ class PipecatEngine:
|
|||
self._voicemail_detector = None
|
||||
self._voicemail_detection_task: Optional[asyncio.Task] = None
|
||||
|
||||
# This transition is generated by the llm as part of tool call. This can
|
||||
# also be accompanied with some content which can be played using TTS. If the
|
||||
# bot is interrupted, we would cancel this transition (we do cancel this currently when
|
||||
# the next generation starts in handle_generation_started callback handler.)
|
||||
self._pending_generated_transition_after_context_push: Optional[
|
||||
Callable[[], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
# This is the transtion which is typically programmatic transition, and not goes as
|
||||
# tool call to LLM. This is not interrupted by the user and is done on context push
|
||||
self._pending_control_transition_after_context_push: Optional[
|
||||
Callable[[], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
# Flag to determine if the current llm generation has a text completion
|
||||
self._defer_context_push: bool = False
|
||||
|
||||
# Lazy loaded built-in function schemas
|
||||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Flag to control whether to queue context frame
|
||||
self._queue_context_frame: bool = True
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
|
||||
|
|
@ -211,23 +190,15 @@ class PipecatEngine:
|
|||
|
||||
async def _create_transition_func(self, name: str, transition_to_node: str):
|
||||
async def transition_func(function_call_params: FunctionCallParams) -> None:
|
||||
"""Inner function that handles the actual tool invocation."""
|
||||
"""Inner function that handles the node change tool calls"""
|
||||
try:
|
||||
# Track pending function call
|
||||
self._pending_function_calls += 1
|
||||
logger.debug(
|
||||
f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})"
|
||||
)
|
||||
|
||||
# For edge functions, prevent LLM completion until transition (run_llm=False)
|
||||
# For node functions, allow immediate completion (run_llm=True)
|
||||
async def on_context_updated() -> None:
|
||||
"""
|
||||
Framework will run this function after the function call result has been updated in the context.
|
||||
pipecat framework will run this function after the function call result has been updated in the context.
|
||||
This way, when we do set_node from within this function, and go for LLM completion with updated
|
||||
system prompts, the context is updated with function call result.
|
||||
"""
|
||||
self._pending_function_calls -= 1
|
||||
# Perform variable extraction before transitioning to new node
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node
|
||||
|
|
@ -241,41 +212,14 @@ class PipecatEngine:
|
|||
on_context_updated=on_context_updated,
|
||||
)
|
||||
|
||||
async def _invoke_result_callback():
|
||||
"""
|
||||
Functions are executed immediately when they come from LLM as part of text completion.
|
||||
But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech.
|
||||
We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we
|
||||
defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking
|
||||
the text that was generated.
|
||||
"""
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
|
||||
if self._defer_context_push:
|
||||
"""
|
||||
We set the flag to _defer_context_push when we receive text in the current generation from LLM.
|
||||
This is set in the handle_llm_generated_text callback handler.
|
||||
"""
|
||||
logger.debug(
|
||||
"Deferring transition function result until context push"
|
||||
)
|
||||
# Only one deferred transition should exist at any time.
|
||||
# Overwrite if one is somehow already set (unexpected).
|
||||
self._pending_generated_transition_after_context_push = (
|
||||
_invoke_result_callback
|
||||
)
|
||||
else:
|
||||
"""
|
||||
If there was no text in the current generation, and we only had function call,
|
||||
lets invoke the result callback, so that framework can call on_context_updated and
|
||||
we can do switch node.
|
||||
"""
|
||||
await _invoke_result_callback()
|
||||
# Call results callback from the pipecat framework
|
||||
# so that a new llm generation can be triggred if
|
||||
# required
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in transition function {name}: {str(e)}")
|
||||
self._pending_function_calls = 0
|
||||
error_result = {"status": "error", "error": str(e)}
|
||||
await function_call_params.result_callback(error_result)
|
||||
|
||||
|
|
@ -362,27 +306,6 @@ class PipecatEngine:
|
|||
]
|
||||
)
|
||||
|
||||
async def _setup_static_start_node_transition(self, node: Node) -> None:
|
||||
"""Set up the deferred transition for static start nodes."""
|
||||
if not node.out_edges:
|
||||
return
|
||||
|
||||
next_node_id = node.out_edges[0].target
|
||||
|
||||
if not node.wait_for_user_response:
|
||||
# Normal static start node - transition immediately after context push
|
||||
async def _deferred_static_transition():
|
||||
try:
|
||||
await self.set_node(next_node_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error executing deferred static node transition to {next_node_id}: {exc}"
|
||||
)
|
||||
|
||||
self._pending_control_transition_after_context_push = (
|
||||
_deferred_static_transition
|
||||
)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -441,17 +364,7 @@ class PipecatEngine:
|
|||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
|
||||
# Queue context frame if needed
|
||||
if self._queue_context_frame:
|
||||
await self.task.queue_frame(OpenAILLMContextFrame(self.context))
|
||||
else:
|
||||
logger.debug(
|
||||
f"Not queueing context frame for node: {node.name} as _queue_context_frame is False"
|
||||
)
|
||||
|
||||
# Reset _queue_context_frame as default behavior
|
||||
self._queue_context_frame = True
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -525,12 +438,7 @@ class PipecatEngine:
|
|||
await asyncio.sleep(delay_duration)
|
||||
|
||||
if node.is_static:
|
||||
# Queue TTS for static start node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
|
||||
# Set up deferred transition for static start nodes
|
||||
await self._setup_static_start_node_transition(node)
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Start generation for non-static start node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
|
@ -538,66 +446,24 @@ class PipecatEngine:
|
|||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
if node.is_static:
|
||||
# Queue TTS for static end node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Start generation for non-static end node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
# If this end node has extraction enabled, perform extraction immediately
|
||||
if node.extraction_enabled and node.extraction_variables:
|
||||
await self._perform_variable_extraction_if_needed(node)
|
||||
|
||||
# TODO: Extract disposition code from extracted variables
|
||||
# Defer send_end_task_frame using _pending_control_transition_after_context_push
|
||||
|
||||
# Decide the end-task reason dynamically depending on call_disposition.
|
||||
async def _deferred_end_task():
|
||||
# call_disposition is the disposition which is generated from
|
||||
# llm call based on the conversation so far.
|
||||
# TODO: Make this more generic based on configuration or llm prompting
|
||||
disposition = self._gathered_context.get("call_disposition")
|
||||
if disposition == "XFER":
|
||||
reason = EndTaskReason.USER_QUALIFIED.value
|
||||
else:
|
||||
reason = EndTaskReason.USER_DISQUALIFIED.value
|
||||
await self.send_end_task_frame(reason)
|
||||
|
||||
self._pending_control_transition_after_context_push = _deferred_end_task
|
||||
await self.send_end_task_frame(EndTaskReason.USER_QUALIFIED.value)
|
||||
|
||||
async def _handle_agent_node(self, node: Node) -> None:
|
||||
"""Handle agent node execution."""
|
||||
if node.is_static:
|
||||
# Queue TTS for static agent node
|
||||
formatted_prompt = self._format_prompt(node.prompt)
|
||||
await self._queue_tts_response(formatted_prompt)
|
||||
|
||||
# Set up deferred transition for static agent nodes
|
||||
await self._setup_agent_node_transition(node)
|
||||
raise ValueError("Static nodes are not supported!")
|
||||
else:
|
||||
# Set context and functions for non-static agent node
|
||||
await self._setup_llm_context_and_start_generation(node)
|
||||
|
||||
async def _setup_agent_node_transition(self, node: Node) -> None:
|
||||
"""Set up the deferred transition for static agent nodes."""
|
||||
if not node.out_edges:
|
||||
return
|
||||
|
||||
next_node_id = node.out_edges[0].target
|
||||
|
||||
async def _deferred_static_transition():
|
||||
try:
|
||||
await self.set_node(next_node_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error executing deferred static node transition to {next_node_id}: {exc}"
|
||||
)
|
||||
|
||||
self._pending_control_transition_after_context_push = (
|
||||
_deferred_static_transition
|
||||
)
|
||||
|
||||
async def send_end_task_frame(
|
||||
self,
|
||||
reason: str,
|
||||
|
|
@ -640,7 +506,7 @@ class PipecatEngine:
|
|||
# Store the mapped disconnect reason
|
||||
self._gathered_context["call_disposition"] = mapped_disposition
|
||||
|
||||
# TODO: Generalise this, currently tailored to Kapil's use case
|
||||
# TODO: Generalise this
|
||||
self._gathered_context["address"] = ", ".join(
|
||||
[
|
||||
self._call_context_vars.get("address1", ""),
|
||||
|
|
@ -759,55 +625,6 @@ class PipecatEngine:
|
|||
|
||||
return system_message, functions
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pending transition handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def flush_pending_transitions(self, *, source: str = "context_push"):
|
||||
"""Execute and clear any pending transitions.
|
||||
|
||||
Args:
|
||||
source: Indicates the trigger that caused this flush:
|
||||
- "context_push": the assistant context aggregator completed a push.
|
||||
"""
|
||||
|
||||
if source != "context_push":
|
||||
raise ValueError("Invalid flush source – expected 'context_push'")
|
||||
|
||||
len_pending_functions = 0
|
||||
|
||||
if self._pending_generated_transition_after_context_push is not None:
|
||||
len_pending_functions += 1
|
||||
if self._pending_control_transition_after_context_push is not None:
|
||||
len_pending_functions += 1
|
||||
|
||||
# Nothing to do
|
||||
if len_pending_functions == 0:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}"
|
||||
)
|
||||
|
||||
# Generated transition
|
||||
if self._pending_generated_transition_after_context_push is not None:
|
||||
pending_cb = self._pending_generated_transition_after_context_push
|
||||
self._pending_generated_transition_after_context_push = None
|
||||
try:
|
||||
await pending_cb()
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error executing deferred transition: {exc}")
|
||||
|
||||
# Control transition (context push)
|
||||
if self._pending_control_transition_after_context_push is not None:
|
||||
logger.debug("Executing control transition after context push")
|
||||
static_cb = self._pending_control_transition_after_context_push
|
||||
self._pending_control_transition_after_context_push = None
|
||||
try:
|
||||
await static_cb()
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error executing deferred static node transition: {exc}")
|
||||
|
||||
def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]:
|
||||
"""
|
||||
This callback is called by STTMuteFilter to determine if the STT should be muted.
|
||||
|
|
@ -828,15 +645,6 @@ class PipecatEngine:
|
|||
"""
|
||||
return engine_callbacks.create_max_duration_callback(self)
|
||||
|
||||
def create_llm_generated_text_callback(self):
|
||||
"""
|
||||
This callback is called when some text is generated by the LLM.
|
||||
We use this to defer the result_callback of the node transition functions if
|
||||
there is set_node called along with some text generated. This way, we will
|
||||
have the context sent in the next generation from new node.
|
||||
"""
|
||||
return engine_callbacks.create_llm_generated_text_callback(self)
|
||||
|
||||
def create_generation_started_callback(self):
|
||||
"""
|
||||
This callback is called when a new generation starts.
|
||||
|
|
@ -844,26 +652,12 @@ class PipecatEngine:
|
|||
"""
|
||||
return engine_callbacks.create_generation_started_callback(self)
|
||||
|
||||
def create_user_stopped_speaking_callback(self):
|
||||
"""
|
||||
This callback is called when the user stops speaking.
|
||||
We use this to handle transitions when wait_for_user_response is enabled.
|
||||
"""
|
||||
return engine_callbacks.create_user_stopped_speaking_callback(self)
|
||||
|
||||
def create_user_started_speaking_callback(self):
|
||||
"""
|
||||
This callback is called when the user starts speaking.
|
||||
We use this to handle wait_for_user_greeting functionality.
|
||||
"""
|
||||
return engine_callbacks.create_user_started_speaking_callback(self)
|
||||
|
||||
def create_aggregation_correction_callback(self) -> Callable[[str], str]:
|
||||
"""Create a callback that corrects corrupted aggregation using reference text."""
|
||||
return engine_callbacks.create_aggregation_correction_callback(self)
|
||||
|
||||
def set_context(self, context: OpenAILLMContext) -> None:
|
||||
"""Set the OpenAI LLM context.
|
||||
def set_context(self, context: LLMContext) -> None:
|
||||
"""Set the LLM context.
|
||||
|
||||
This allows setting the context after the engine has been created,
|
||||
which is useful when the context needs to be created after the engine.
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import re
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
|
|
@ -23,9 +24,8 @@ from pipecat.processors.filters.stt_mute_filter import STTMuteFilter
|
|||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -114,23 +114,6 @@ def create_max_duration_callback(engine: "PipecatEngine"):
|
|||
return handle_max_duration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM-generated-text handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_llm_generated_text_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback invoked when the LLM emits text (not only tool calls)."""
|
||||
|
||||
async def handle_llm_generated_text(): # noqa: D401
|
||||
logger.debug(
|
||||
"Generation has text content in current response - deferring context push from set_node"
|
||||
)
|
||||
engine._defer_context_push = True
|
||||
|
||||
return handle_llm_generated_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation-started handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -140,96 +123,13 @@ def create_generation_started_callback(engine: "PipecatEngine"):
|
|||
"""Return a callback that resets flags at the start of each LLM generation."""
|
||||
|
||||
async def handle_generation_started(): # noqa: D401
|
||||
logger.debug("LLM generation started - resetting defer flags and tool counters")
|
||||
engine._defer_context_push = False
|
||||
engine._pending_function_calls = 0
|
||||
engine._pending_generated_transition_after_context_push = None
|
||||
logger.debug("LLM generation started in callback processor")
|
||||
# Clear reference text from previous generation
|
||||
engine._current_llm_reference_text = ""
|
||||
|
||||
return handle_generation_started
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-stopped-speaking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_stopped_speaking_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that handles when the user stops speaking.
|
||||
|
||||
According to simplified flow:
|
||||
- For start nodes with wait_for_user_response=True:
|
||||
- Cancel timeout task if still active
|
||||
- Transition to next node with _queue_context_frame=False
|
||||
"""
|
||||
|
||||
async def handle_user_stopped_speaking():
|
||||
# Only handle if current node is a start node with wait_for_user_response
|
||||
if (
|
||||
engine._current_node
|
||||
and engine._current_node.is_start
|
||||
and engine._current_node.wait_for_user_response
|
||||
and engine._current_node.out_edges
|
||||
):
|
||||
# Cancel timeout task if it's still active
|
||||
if (
|
||||
engine._user_response_timeout_task
|
||||
and not engine._user_response_timeout_task.done()
|
||||
):
|
||||
logger.debug("Cancelling user response timeout - user responded")
|
||||
engine._user_response_timeout_task.cancel()
|
||||
engine._user_response_timeout_task = None
|
||||
|
||||
# Transition to next node
|
||||
next_node_id = engine._current_node.out_edges[0].target
|
||||
logger.debug(
|
||||
f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}"
|
||||
)
|
||||
|
||||
# Set flag to not queue context frame since
|
||||
# it will be pushed by user context aggregator
|
||||
# we are just setting the context with next node's
|
||||
# functions and prompts
|
||||
engine._queue_context_frame = False
|
||||
|
||||
# Transition to next node
|
||||
await engine.set_node(next_node_id)
|
||||
|
||||
return handle_user_stopped_speaking
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-started-speaking handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_user_started_speaking_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that handles when the user starts speaking.
|
||||
|
||||
According to simplified flow:
|
||||
- For start nodes with wait_for_user_response=True:
|
||||
- Cancel the timeout timer if it exists (but don't set to None)
|
||||
"""
|
||||
|
||||
async def handle_user_started_speaking():
|
||||
# Only handle if current node is a start node with wait_for_user_response
|
||||
if (
|
||||
engine._current_node
|
||||
and engine._current_node.is_start
|
||||
and engine._current_node.wait_for_user_response
|
||||
and engine._user_response_timeout_task
|
||||
and not engine._user_response_timeout_task.done()
|
||||
):
|
||||
logger.debug(
|
||||
"User started speaking during wait_for_user_response - cancelling timeout timer"
|
||||
)
|
||||
engine._user_response_timeout_task.cancel()
|
||||
# Don't set to None here - let user_stopped_speaking handle the transition
|
||||
|
||||
return handle_user_started_speaking
|
||||
|
||||
|
||||
def create_aggregation_correction_callback(engine: "PipecatEngine"):
|
||||
"""Create a callback that uses engine's reference text to correct corrupted aggregation."""
|
||||
|
||||
|
|
|
|||
|
|
@ -2,16 +2,10 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from google.genai.types import (
|
||||
Content,
|
||||
Part,
|
||||
)
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.services.google.llm import GoogleLLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
__all__ = [
|
||||
"get_function_schema",
|
||||
|
|
@ -44,7 +38,7 @@ def get_function_schema(
|
|||
|
||||
|
||||
def update_llm_context(
|
||||
context: OpenAILLMContext,
|
||||
context: LLMContext,
|
||||
system_message: Dict[str, Any],
|
||||
functions: List[FunctionSchema],
|
||||
) -> None:
|
||||
|
|
@ -59,21 +53,6 @@ def update_llm_context(
|
|||
# associated with the current LLM service can convert them to the correct
|
||||
# provider-specific representation when required.
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
|
||||
if isinstance(context, GoogleLLMContext):
|
||||
context.system_message = system_message["content"]
|
||||
|
||||
if functions:
|
||||
# Lets only call set_tools if we have functions, else Gemini will
|
||||
# throw an exception
|
||||
context.set_tools(tools_schema)
|
||||
|
||||
if context.messages[-1].role != "user":
|
||||
# Google expects the last message should end with user message
|
||||
context.add_message(Content(role="user", parts=[Part(text="...")]))
|
||||
return
|
||||
|
||||
# In case of OpenAILLMContext, replace the system message with incoming system message
|
||||
previous_interactions = context.messages
|
||||
|
||||
# Filter out old system messages but keep user/assistant/function content.
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Any, List
|
|||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from opentelemetry import trace
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
from api.services.pipecat.tracing_config import is_tracing_enabled
|
||||
from api.services.workflow.dto import ExtractionVariableDTO
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
|
@ -139,7 +139,7 @@ class VariableExtractionManager:
|
|||
f"{conversation_history}"
|
||||
)
|
||||
|
||||
extraction_context = OpenAILLMContext()
|
||||
extraction_context = LLMContext()
|
||||
extraction_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
|
|
@ -171,7 +171,7 @@ class VariableExtractionManager:
|
|||
service_name="OpenAILLMService",
|
||||
model=self._model,
|
||||
operation_name="variable_extraction",
|
||||
messages=json.dumps(extraction_messages),
|
||||
messages=extraction_messages,
|
||||
output=llm_response,
|
||||
stream=False,
|
||||
parameters={"temperature": 0.0, "response_format": "json_object"},
|
||||
|
|
|
|||
|
|
@ -44,8 +44,6 @@ class Node:
|
|||
self.extraction_prompt = data.extraction_prompt
|
||||
self.extraction_variables = data.extraction_variables
|
||||
self.add_global_prompt = data.add_global_prompt
|
||||
self.wait_for_user_response = data.wait_for_user_response
|
||||
self.wait_for_user_response_timeout = data.wait_for_user_response_timeout
|
||||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ import os
|
|||
import aiohttp
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import IntegrationModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
|
||||
async def run_integrations_post_workflow_run(ctx, workflow_run_id: int):
|
||||
|
|
@ -162,7 +162,7 @@ async def _process_slack_integration(
|
|||
"""
|
||||
logger.info(f"Processing Slack integration {integration.id}")
|
||||
|
||||
# TODO: Generalise this, currently tailored to Kapil's use case
|
||||
# TODO: Generalise this
|
||||
if gathered_context.get("mapped_call_disposition") != "XFER":
|
||||
logger.debug(
|
||||
f"Not sending message on slack since not XFER: {gathered_context.get('mapped_call_disposition')}"
|
||||
|
|
|
|||
|
|
@ -28,20 +28,24 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
|
||||
# Fetch telephony call cost for both Twilio and Vonage
|
||||
telephony_cost_usd = 0.0
|
||||
if workflow_run.mode in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] and workflow_run.cost_info:
|
||||
if (
|
||||
workflow_run.mode
|
||||
in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
and workflow_run.cost_info
|
||||
):
|
||||
# Get the call ID (provider-agnostic approach with backward compatibility)
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
|
||||
|
||||
# Fallback to legacy provider-specific fields if needed
|
||||
if not call_id:
|
||||
if workflow_run.mode == WorkflowRunMode.TWILIO.value:
|
||||
call_id = workflow_run.cost_info.get("twilio_call_sid")
|
||||
elif workflow_run.mode == WorkflowRunMode.VONAGE.value:
|
||||
call_id = workflow_run.cost_info.get("vonage_call_uuid")
|
||||
|
||||
|
||||
# Provider name is derived from workflow run mode
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
|
||||
if call_id:
|
||||
try:
|
||||
# Get workflow to access organization_id
|
||||
|
|
@ -55,12 +59,14 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
# Use telephony provider abstraction
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
|
||||
if call_cost_info.get("status") != "error":
|
||||
telephony_cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd # Keep backward compatibility
|
||||
|
||||
cost_breakdown[f"{provider_name}_call"] = (
|
||||
telephony_cost_usd # Keep backward compatibility
|
||||
)
|
||||
|
||||
# Add telephony cost to the total
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
|
|
@ -69,8 +75,10 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
f"{provider_name.title()} call cost: ${telephony_cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}")
|
||||
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
|
@ -119,7 +127,9 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
elif "twilio_call_sid" in workflow_run.cost_info:
|
||||
cost_info["twilio_call_sid"] = workflow_run.cost_info["twilio_call_sid"]
|
||||
elif "vonage_call_uuid" in workflow_run.cost_info:
|
||||
cost_info["vonage_call_uuid"] = workflow_run.cost_info["vonage_call_uuid"]
|
||||
cost_info["vonage_call_uuid"] = workflow_run.cost_info[
|
||||
"vonage_call_uuid"
|
||||
]
|
||||
|
||||
# Update workflow run with cost information
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
|
|
|||
|
|
@ -1,179 +0,0 @@
|
|||
### - This test has some weird loop which keeps on increasing the context size
|
||||
|
||||
# import asyncio
|
||||
# import json
|
||||
# import unittest
|
||||
# from types import SimpleNamespace
|
||||
# from unittest import mock
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# from pipecat.frames.frames import (
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMGeneratedTextFrame,
|
||||
# LLMTextFrame,
|
||||
# )
|
||||
# from pipecat.pipeline.pipeline import Pipeline
|
||||
# from pipecat.processors.aggregators.openai_llm_context import (
|
||||
# OpenAILLMContext,
|
||||
# OpenAILLMContextFrame,
|
||||
# )
|
||||
# from pipecat.services.llm_service import (
|
||||
# FunctionCallParams,
|
||||
# FunctionCallResultProperties,
|
||||
# )
|
||||
# from pipecat.services.openai.llm import OpenAILLMService
|
||||
# from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
# class _MockAsyncStream:
|
||||
# """A minimal async-stream wrapper that mimics ``openai.AsyncStream``."""
|
||||
|
||||
# def __init__(self, chunks):
|
||||
# self._chunks = chunks
|
||||
|
||||
# def __aiter__(self):
|
||||
# self._idx = 0
|
||||
# return self
|
||||
|
||||
# async def __anext__(self):
|
||||
# if self._idx >= len(self._chunks):
|
||||
# raise StopAsyncIteration
|
||||
# item = self._chunks[self._idx]
|
||||
# self._idx += 1
|
||||
# await asyncio.sleep(0) # Yield control
|
||||
# return item
|
||||
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # Factories for mock chunks
|
||||
# # ------------------------------------------------------------------
|
||||
|
||||
|
||||
# def _make_tool_call(tool_name: str, args_json: str, *, idx: int = 0):
|
||||
# function = SimpleNamespace(name=tool_name, arguments=args_json)
|
||||
# return SimpleNamespace(index=idx, id=f"call-{idx}", function=function)
|
||||
|
||||
|
||||
# def _make_chunk(*, content: str | None = None, tool_calls=None, usage=None):
|
||||
# delta = SimpleNamespace()
|
||||
# # When we are asked to simulate multiple tool calls in parallel, OpenAI
|
||||
# # sends *separate* chunks for every tool-call index. To mimic that behaviour
|
||||
# # in tests we split a list of tool calls (>1) into individual chunks – one
|
||||
# # for each tool call – while keeping the original single-chunk behaviour
|
||||
# # when zero or one tool calls are supplied. This enables us to write
|
||||
# # concise tests such as ``_make_chunk(tool_calls=[call_1, call_2])`` that
|
||||
# # accurately reflect the streaming protocol.
|
||||
|
||||
# # No special handling needed if there is textual content or 0/1 tool calls.
|
||||
# if content is not None or tool_calls is None or len(tool_calls) <= 1:
|
||||
# if content is not None:
|
||||
# delta.content = content
|
||||
# # Always set tool_calls so downstream code can safely access it
|
||||
# delta.tool_calls = tool_calls if tool_calls is not None else None
|
||||
# return SimpleNamespace(choices=[SimpleNamespace(delta=delta)], usage=usage)
|
||||
|
||||
# # --- Multiple tool calls (len(tool_calls) > 1) ---
|
||||
# # Create a list of chunks, each containing a single tool call. This is the
|
||||
# # format produced by the OpenAI client when several tools are invoked in a
|
||||
# # single assistant response.
|
||||
# chunks = []
|
||||
# for tc in tool_calls:
|
||||
# delta_tc = SimpleNamespace(tool_calls=[tc])
|
||||
# chunks.append(SimpleNamespace(choices=[SimpleNamespace(delta=delta_tc)], usage=usage))
|
||||
|
||||
# return chunks
|
||||
|
||||
|
||||
# class TestBaseOpenAILLMService(unittest.IsolatedAsyncioTestCase):
|
||||
# async def test_process_context_with_patch(self):
|
||||
# streamed_text = "Hello from OpenAI!"
|
||||
# tool_name = "echo"
|
||||
# tool_name_2 = "echo_2"
|
||||
# tool_args = {"text": "hello"}
|
||||
# tool_args_2 = {"text": "hello_2"}
|
||||
|
||||
# # Build mocked stream (tool call first, then text)
|
||||
# chunks = [
|
||||
# _make_chunk(content=streamed_text),
|
||||
# _make_chunk(tool_calls=[_make_tool_call(tool_name, json.dumps(tool_args))]),
|
||||
# _make_chunk(tool_calls=[_make_tool_call(tool_name_2, json.dumps(tool_args_2), idx=1)]),
|
||||
# ]
|
||||
|
||||
# # Instantiate real OpenAILLMService (no need for actual API key)
|
||||
# llm = OpenAILLMService(model="gpt-4o-mini", api_key="test")
|
||||
|
||||
# # Patch get_chat_completions to return our mocked async stream
|
||||
# async def fake_get_chat_completions(self, context, messages): # noqa: D401
|
||||
# return _MockAsyncStream(chunks)
|
||||
|
||||
# with mock.patch.object(llm.__class__, "get_chat_completions", fake_get_chat_completions):
|
||||
# # Register echo tool
|
||||
# executed = False
|
||||
|
||||
# async def echo_handler(params: FunctionCallParams):
|
||||
# nonlocal executed
|
||||
# executed = True
|
||||
# # sleep for 1 second
|
||||
# logger.info("echo_handler: sleeping for 5 second")
|
||||
# await asyncio.sleep(5)
|
||||
# await params.result_callback(
|
||||
# {"ok": True},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# async def echo_2_handler(params: FunctionCallParams):
|
||||
# nonlocal executed
|
||||
# executed = True
|
||||
# # sleep for 1 second
|
||||
# logger.info("echo_2_handler: sleeping for 5 second")
|
||||
# await asyncio.sleep(5)
|
||||
# await params.result_callback(
|
||||
# {"ok": True},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# llm.register_function(tool_name, echo_handler)
|
||||
# llm.register_function(tool_name_2, echo_2_handler)
|
||||
|
||||
# # Prepare context and send
|
||||
# context = OpenAILLMContext()
|
||||
# context.add_message({"role": "user", "content": "Hi"})
|
||||
# frames_to_send = [OpenAILLMContextFrame(context)]
|
||||
|
||||
# expected_down_frames = [
|
||||
# LLMFullResponseStartFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# LLMGeneratedTextFrame,
|
||||
# LLMTextFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# ]
|
||||
|
||||
# context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# pipeline = Pipeline([llm, context_aggregator.assistant()])
|
||||
|
||||
# down_frames, _ = await run_test(
|
||||
# pipeline,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_down_frames,
|
||||
# send_end_frame=False,
|
||||
# )
|
||||
|
||||
# # Assertions
|
||||
# self.assertTrue(executed)
|
||||
# for fr in down_frames:
|
||||
# if isinstance(fr, FunctionCallResultFrame):
|
||||
# self.assertTrue(fr.run_llm)
|
||||
# if isinstance(fr, LLMTextFrame):
|
||||
# self.assertEqual(fr.text, streamed_text)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# unittest.main()
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Test script to verify that LLMGeneratedTextFrame signaling works correctly
|
||||
with the new local variable approach.
|
||||
"""
|
||||
|
||||
|
||||
def test_local_variable_logic():
|
||||
"""Test the core logic using the same pattern as the implementation"""
|
||||
|
||||
print("=== Testing Local Variable Logic ===")
|
||||
|
||||
# Simulate the logic from _process_context
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks with text content
|
||||
chunks_with_content = ["Hello", " world", "!"]
|
||||
|
||||
for content in chunks_with_content:
|
||||
# This is the exact logic from our implementation
|
||||
if content: # equivalent to chunk.choices[0].delta.content
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({content})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
# Verify behavior
|
||||
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
|
||||
text_frames = [f for f in frames_sent if f.startswith("LLMTextFrame")]
|
||||
|
||||
assert len(generated_signals) == 1, (
|
||||
f"Expected 1 signal, got {len(generated_signals)}"
|
||||
)
|
||||
assert len(text_frames) == 3, f"Expected 3 text frames, got {len(text_frames)}"
|
||||
assert frames_sent[0] == "LLMGeneratedTextFrame", "Signal should be first"
|
||||
|
||||
print("✅ Local variable logic works correctly")
|
||||
return True
|
||||
|
||||
|
||||
def test_no_text_logic():
|
||||
"""Test that no signal is sent when there's no text"""
|
||||
|
||||
print("\n=== Testing No Text Logic ===")
|
||||
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks with no text content (function calls only)
|
||||
chunks_with_content = [None, None, None] # No text content
|
||||
|
||||
for content in chunks_with_content:
|
||||
if content: # This will be False for all chunks
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({content})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
assert len(frames_sent) == 0, f"Expected no frames, got {frames_sent}"
|
||||
|
||||
print("✅ No signal sent when no text content")
|
||||
return True
|
||||
|
||||
|
||||
def test_mixed_content_logic():
|
||||
"""Test behavior with mixed function calls and text"""
|
||||
|
||||
print("\n=== Testing Mixed Content Logic ===")
|
||||
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks: function call, text, function call, text
|
||||
chunks = [
|
||||
{"type": "function", "content": None},
|
||||
{"type": "text", "content": "Hello"},
|
||||
{"type": "function", "content": None},
|
||||
{"type": "text", "content": " world"},
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "function":
|
||||
frames_sent.append("FunctionCallFrame")
|
||||
elif chunk["content"]: # text content
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({chunk['content']})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
|
||||
|
||||
assert len(generated_signals) == 1, (
|
||||
f"Expected 1 signal, got {len(generated_signals)}"
|
||||
)
|
||||
# Signal should come before first text frame but after any function frames
|
||||
signal_index = frames_sent.index("LLMGeneratedTextFrame")
|
||||
first_text_index = next(
|
||||
i for i, f in enumerate(frames_sent) if f.startswith("LLMTextFrame")
|
||||
)
|
||||
assert signal_index == first_text_index - 1, (
|
||||
"Signal should come right before first text"
|
||||
)
|
||||
|
||||
print("✅ Mixed content logic works correctly")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
test1_result = test_local_variable_logic()
|
||||
test2_result = test_no_text_logic()
|
||||
test3_result = test_mixed_content_logic()
|
||||
|
||||
print(f"\n=== Test Results ===")
|
||||
print(f"Local variable test: {'✅ PASS' if test1_result else '❌ FAIL'}")
|
||||
print(f"No text test: {'✅ PASS' if test2_result else '❌ FAIL'}")
|
||||
print(f"Mixed content test: {'✅ PASS' if test3_result else '❌ FAIL'}")
|
||||
|
||||
if test1_result and test2_result and test3_result:
|
||||
print("\n🎉 All LLMGeneratedTextFrame signaling logic tests passed!")
|
||||
print(
|
||||
"✅ Implementation correctly signals text generation once, as early as possible"
|
||||
)
|
||||
else:
|
||||
print("\n❌ Some tests failed.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,536 +0,0 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import Edge, Node, WorkflowGraph
|
||||
|
||||
|
||||
class TestPipecatEngineSetNode:
|
||||
"""Test cases for PipecatEngine.set_node method refactoring."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow(self):
|
||||
"""Create a mock workflow with various node types."""
|
||||
workflow = Mock(spec=WorkflowGraph)
|
||||
workflow.nodes = {}
|
||||
workflow.start_node_id = "start_node"
|
||||
workflow.global_node_id = None
|
||||
return workflow
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self, mock_workflow):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
task = AsyncMock()
|
||||
task.queue_frames = AsyncMock()
|
||||
task.queue_frame = AsyncMock()
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.register_function = Mock()
|
||||
llm.push_frame = AsyncMock()
|
||||
|
||||
context = Mock(spec=OpenAILLMContext)
|
||||
context.set_node_name = Mock()
|
||||
|
||||
return {
|
||||
"task": task,
|
||||
"llm": llm,
|
||||
"context": context,
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": mock_workflow,
|
||||
"call_context_vars": {"test_var": "test_value"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance."""
|
||||
# Add audio_buffer and workflow_run_id to dependencies
|
||||
mock_dependencies["audio_buffer"] = None
|
||||
mock_dependencies["workflow_run_id"] = 123
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
# Mock the builtin function registration
|
||||
engine._register_builtin_functions = AsyncMock()
|
||||
return engine
|
||||
|
||||
def create_node(self, node_id, **kwargs):
|
||||
"""Helper to create a node with default values."""
|
||||
defaults = {
|
||||
"name": f"Node {node_id}",
|
||||
"prompt": f"Prompt for {node_id}",
|
||||
"is_static": False,
|
||||
"is_start": False,
|
||||
"is_end": False,
|
||||
"allow_interrupt": True,
|
||||
"extraction_enabled": False,
|
||||
"extraction_prompt": "",
|
||||
"extraction_variables": [],
|
||||
"add_global_prompt": True,
|
||||
"wait_for_user_response": False,
|
||||
"detect_voicemail": False,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
||||
data = Mock(spec=NodeDataDTO)
|
||||
for key, value in defaults.items():
|
||||
setattr(data, key, value)
|
||||
|
||||
node = Mock(spec=Node)
|
||||
node.id = node_id
|
||||
node.data = data
|
||||
node.out_edges = []
|
||||
|
||||
# Copy attributes from data to node
|
||||
for key, value in defaults.items():
|
||||
setattr(node, key, value)
|
||||
|
||||
return node
|
||||
|
||||
def create_edge(
|
||||
self, source, target, label="Continue", condition="Always continue"
|
||||
):
|
||||
"""Helper to create an edge."""
|
||||
data = Mock(spec=EdgeDataDTO)
|
||||
data.label = label
|
||||
data.condition = condition
|
||||
|
||||
edge = Mock(spec=Edge)
|
||||
edge.source = source
|
||||
edge.target = target
|
||||
edge.data = data
|
||||
edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_"))
|
||||
|
||||
return edge
|
||||
|
||||
# ===== START NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_static_immediate_execution(self, engine, mock_workflow):
|
||||
"""Test: Basic static start node executes immediately."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
prompt="Welcome to our service!",
|
||||
)
|
||||
next_node = self.create_node("next_node", is_static=False)
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert len(frames) == 3
|
||||
assert isinstance(frames[0], LLMFullResponseStartFrame)
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Welcome to our service!"
|
||||
assert isinstance(frames[2], LLMFullResponseEndFrame)
|
||||
|
||||
# Static start nodes now set pending transition after context push
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Should not have set detect_voicemail for static start without it
|
||||
assert not engine._detect_voicemail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_with_detect_voicemail_no_audio_buffer(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Start node with voicemail detection but no audio buffer logs warning."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
detect_voicemail=True,
|
||||
prompt="Hello, this is a business call.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Engine has no audio buffer (None)
|
||||
assert engine._audio_buffer is None
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should NOT set voicemail detection flag since no audio buffer
|
||||
assert engine._detect_voicemail is False
|
||||
assert engine._voicemail_detector is None
|
||||
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Hello, this is a business call."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_non_static_with_detect_voicemail(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Non-static start node with voicemail detection without audio buffer."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=False, # Non-static
|
||||
detect_voicemail=True,
|
||||
prompt="You are an AI assistant. Start the conversation.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Mock the context update method
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test prompt"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should NOT set voicemail detection flags (no audio buffer)
|
||||
assert engine._detect_voicemail is False
|
||||
assert engine._voicemail_detector is None
|
||||
|
||||
# Should update LLM context for non-static node
|
||||
engine._update_llm_context.assert_called_once()
|
||||
|
||||
# Should queue context frame
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(frame, OpenAILLMContextFrame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_static_with_wait_for_user_response(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Static start node with wait_for_user_response."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
wait_for_user_response=True,
|
||||
prompt="Please tell me your name.",
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
|
||||
# Should have a pending control transition that will start the timer
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Timer task should not exist yet
|
||||
assert (
|
||||
not hasattr(engine, "_user_response_timeout_task")
|
||||
or engine._user_response_timeout_task is None
|
||||
)
|
||||
|
||||
# Simulate context push to start the timer
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Now the timeout task should be created
|
||||
assert engine._user_response_timeout_task is not None
|
||||
assert not engine._user_response_timeout_task.done()
|
||||
|
||||
# Clean up the task
|
||||
engine._user_response_timeout_task.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_non_static(self, engine, mock_workflow):
|
||||
"""Test: Non-static start node sends context to LLM."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=False,
|
||||
prompt="You are a helpful assistant. Greet the user.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Mock the context update method
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test prompt"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should set context name
|
||||
engine.context.set_node_name.assert_called_once_with("Node start_node")
|
||||
|
||||
# Should update LLM context
|
||||
engine._update_llm_context.assert_called_once()
|
||||
|
||||
# Should queue context frame
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(frame, OpenAILLMContextFrame)
|
||||
|
||||
# ===== AGENT NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_static(self, engine, mock_workflow):
|
||||
"""Test: Static agent node plays TTS and transitions."""
|
||||
# Setup
|
||||
agent_node = self.create_node(
|
||||
"agent_node", is_static=True, prompt="Processing your request..."
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("agent_node", "next_node")
|
||||
agent_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("agent_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Processing your request..."
|
||||
|
||||
# Should have pending transition
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_non_static(self, engine, mock_workflow):
|
||||
"""Test: Non-static agent node sends context to LLM."""
|
||||
# Setup
|
||||
agent_node = self.create_node(
|
||||
"agent_node",
|
||||
is_static=False,
|
||||
prompt="Analyze the user's request and respond appropriately.",
|
||||
)
|
||||
decision_node = self.create_node("decision_node")
|
||||
|
||||
edge = self.create_edge("agent_node", "decision_node", "analyze_complete")
|
||||
agent_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node}
|
||||
|
||||
# Mock methods
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=(
|
||||
{"role": "system", "content": "Test"},
|
||||
[{"name": "test_func"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("agent_node")
|
||||
|
||||
# Verify
|
||||
# Should register transition function
|
||||
engine.llm.register_function.assert_called_once()
|
||||
call_args = engine.llm.register_function.call_args
|
||||
assert call_args[0][0] == "analyze_complete"
|
||||
assert callable(call_args[0][1]) # Check it's a function
|
||||
assert call_args[1]["cancel_on_interruption"] is True
|
||||
|
||||
# Should update context and send frame
|
||||
engine._update_llm_context.assert_called_once()
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_with_interruption_control(self, engine, mock_workflow):
|
||||
"""Test: Agent node respects allow_interrupt flag."""
|
||||
# Setup
|
||||
no_interrupt_node = self.create_node(
|
||||
"no_interrupt",
|
||||
is_static=True,
|
||||
allow_interrupt=False,
|
||||
prompt="Please wait while I process...",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"no_interrupt": no_interrupt_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("no_interrupt")
|
||||
|
||||
# Verify current node is set (for STT mute callback)
|
||||
assert engine._current_node == no_interrupt_node
|
||||
assert engine._current_node.allow_interrupt is False
|
||||
|
||||
# ===== END NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_node_static(self, engine, mock_workflow):
|
||||
"""Test: Static end node plays final message and schedules end task."""
|
||||
# Setup
|
||||
end_node = self.create_node(
|
||||
"end_node",
|
||||
is_static=True,
|
||||
is_end=True,
|
||||
prompt="Thank you for calling. Goodbye!",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"end_node": end_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("end_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert frames[1].text == "Thank you for calling. Goodbye!"
|
||||
|
||||
# Should have pending end task
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Execute the pending transition
|
||||
await engine._pending_control_transition_after_context_push()
|
||||
|
||||
# Should have sent EndFrame via task.queue_frame
|
||||
# The second call should be the EndFrame (first was TTS frames)
|
||||
assert engine.task.queue_frame.call_count >= 1
|
||||
end_frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(end_frame, EndFrame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_node_with_extraction(self, engine, mock_workflow):
|
||||
"""Test: End node with variable extraction."""
|
||||
# Setup
|
||||
end_node = self.create_node(
|
||||
"end_node",
|
||||
is_end=True,
|
||||
is_static=False,
|
||||
extraction_enabled=True,
|
||||
extraction_variables=["user_name", "satisfaction_level"],
|
||||
extraction_prompt="Extract user name and satisfaction",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"end_node": end_node}
|
||||
|
||||
# Mock the extraction manager
|
||||
engine._variable_extraction_manager = Mock()
|
||||
engine._perform_variable_extraction_if_needed = AsyncMock()
|
||||
|
||||
# Mock context update and composition methods
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("end_node")
|
||||
|
||||
# Verify
|
||||
# Should trigger extraction
|
||||
engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node)
|
||||
|
||||
# Should have pending end task
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# ===== CALLBACK INTEGRATION TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_stopped_speaking_during_response_wait(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: User stops speaking triggers transition during wait_for_response."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node", is_start=True, is_static=True, wait_for_user_response=True
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Set current node to start node
|
||||
engine._current_node = start_node
|
||||
engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3))
|
||||
|
||||
# Create callback and execute
|
||||
callback = engine.create_user_stopped_speaking_callback()
|
||||
|
||||
# Mock set_node to avoid recursion
|
||||
with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node:
|
||||
await callback()
|
||||
|
||||
# Verify
|
||||
mock_set_node.assert_called_once_with("next_node")
|
||||
assert engine._queue_context_frame is False # Should be set to False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_push_callback_executes_pending_transitions(self, engine):
|
||||
"""Test: flush_pending_transitions executes deferred transitions."""
|
||||
# Setup pending transitions
|
||||
mock_generated_transition = AsyncMock()
|
||||
mock_control_transition = AsyncMock()
|
||||
|
||||
engine._pending_generated_transition_after_context_push = (
|
||||
mock_generated_transition
|
||||
)
|
||||
engine._pending_control_transition_after_context_push = mock_control_transition
|
||||
|
||||
# Execute
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Verify both transitions were executed
|
||||
mock_generated_transition.assert_called_once()
|
||||
mock_control_transition.assert_called_once()
|
||||
|
||||
# Verify they were cleared
|
||||
assert engine._pending_generated_transition_after_context_push is None
|
||||
assert engine._pending_control_transition_after_context_push is None
|
||||
|
||||
# ===== COMPLEX SCENARIO TESTS =====
|
||||
|
||||
|
||||
# Add helper for testing with real async behavior
|
||||
def ANY(cls=None):
|
||||
"""Helper for matching any argument in mock calls."""
|
||||
|
||||
class AnyMatcher:
|
||||
def __init__(self, cls):
|
||||
self.cls = cls
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.cls:
|
||||
return isinstance(other, self.cls)
|
||||
return True
|
||||
|
||||
return AnyMatcher(cls)
|
||||
|
|
@ -5,46 +5,41 @@ handles provider switches correctly without losing billing data.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any
|
||||
|
||||
# Test scenarios to validate
|
||||
|
||||
|
||||
async def test_scenario_1_mid_call_provider_switch():
|
||||
"""
|
||||
Test: What happens if provider is switched while a call is active?
|
||||
|
||||
|
||||
Expected behavior:
|
||||
- Active call continues with original provider
|
||||
- Call is billed to original provider
|
||||
- New calls use new provider
|
||||
"""
|
||||
print("Test 1: Mid-call provider switching")
|
||||
|
||||
|
||||
# Simulate workflow run with Twilio
|
||||
twilio_run = {
|
||||
"id": 1,
|
||||
"mode": "twilio",
|
||||
"cost_info": {
|
||||
"twilio_call_sid": "CA123456789",
|
||||
"provider": "twilio"
|
||||
},
|
||||
"is_completed": False
|
||||
"cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"},
|
||||
"is_completed": False,
|
||||
}
|
||||
|
||||
|
||||
# Provider switch happens here (in real scenario, user changes config)
|
||||
# But the call continues...
|
||||
|
||||
|
||||
# When cost calculation runs, it should:
|
||||
# 1. Use the provider stored in cost_info
|
||||
# 2. Fetch cost from Twilio using twilio_call_sid
|
||||
# 3. Store cost with provider attribution
|
||||
|
||||
|
||||
result = {
|
||||
"test": "mid_call_switch",
|
||||
"status": "PASS",
|
||||
"reason": "Call continues with original provider, billing intact"
|
||||
"reason": "Call continues with original provider, billing intact",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
|
@ -53,41 +48,41 @@ async def test_scenario_1_mid_call_provider_switch():
|
|||
async def test_scenario_2_pending_cost_calculation():
|
||||
"""
|
||||
Test: Calls that ended but cost not yet calculated when provider switches.
|
||||
|
||||
|
||||
Expected behavior:
|
||||
- Background job should use the provider info stored in cost_info
|
||||
- Cost should be fetched from correct provider
|
||||
"""
|
||||
print("\nTest 2: Pending cost calculation during switch")
|
||||
|
||||
|
||||
# Workflow runs that ended but cost job hasn't run yet
|
||||
pending_runs = [
|
||||
{
|
||||
"id": 2,
|
||||
"mode": "twilio",
|
||||
"mode": "twilio",
|
||||
"cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"},
|
||||
"is_completed": True
|
||||
"is_completed": True,
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"mode": "vonage",
|
||||
"cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"},
|
||||
"is_completed": True
|
||||
}
|
||||
"is_completed": True,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Provider switch happens here
|
||||
# Cost calculation jobs run after switch
|
||||
|
||||
|
||||
# Each job should:
|
||||
# 1. Check the provider field in cost_info
|
||||
# 2. Use appropriate provider API to fetch cost
|
||||
# 3. Handle gracefully if credentials changed
|
||||
|
||||
|
||||
result = {
|
||||
"test": "pending_cost_calculation",
|
||||
"status": "PASS",
|
||||
"reason": "Cost jobs use stored provider info correctly"
|
||||
"reason": "Cost jobs use stored provider info correctly",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
|
@ -96,33 +91,37 @@ async def test_scenario_2_pending_cost_calculation():
|
|||
async def test_scenario_3_mixed_provider_history():
|
||||
"""
|
||||
Test: Organization has calls from both Twilio and Vonage.
|
||||
|
||||
|
||||
Expected behavior:
|
||||
- Historical costs remain intact
|
||||
- Reports show correct attribution
|
||||
- Total costs aggregate correctly
|
||||
"""
|
||||
print("\nTest 3: Mixed provider history")
|
||||
|
||||
|
||||
historical_runs = [
|
||||
{"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"},
|
||||
{"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"},
|
||||
{"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"},
|
||||
{"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"},
|
||||
]
|
||||
|
||||
|
||||
# Calculate totals
|
||||
total_cost = sum(run["cost_usd"] for run in historical_runs)
|
||||
twilio_cost = sum(run["cost_usd"] for run in historical_runs if run["provider"] == "twilio")
|
||||
vonage_cost = sum(run["cost_usd"] for run in historical_runs if run["provider"] == "vonage")
|
||||
|
||||
twilio_cost = sum(
|
||||
run["cost_usd"] for run in historical_runs if run["provider"] == "twilio"
|
||||
)
|
||||
vonage_cost = sum(
|
||||
run["cost_usd"] for run in historical_runs if run["provider"] == "vonage"
|
||||
)
|
||||
|
||||
result = {
|
||||
"test": "mixed_provider_history",
|
||||
"status": "PASS",
|
||||
"status": "PASS",
|
||||
"total_cost": total_cost,
|
||||
"twilio_cost": twilio_cost,
|
||||
"vonage_cost": vonage_cost,
|
||||
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})"
|
||||
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
|
@ -131,41 +130,41 @@ async def test_scenario_3_mixed_provider_history():
|
|||
async def test_scenario_4_cost_api_failure():
|
||||
"""
|
||||
Test: Provider API fails when fetching cost.
|
||||
|
||||
|
||||
Expected behavior:
|
||||
- Error logged but system continues
|
||||
- Call record preserved
|
||||
- Cost marked as 0 or unknown
|
||||
"""
|
||||
print("\nTest 4: Cost API failure handling")
|
||||
|
||||
|
||||
# Simulate API failure scenarios
|
||||
failure_scenarios = [
|
||||
{
|
||||
"provider": "twilio",
|
||||
"error": "401 Unauthorized - credentials changed",
|
||||
"expected": "Cost set to 0, error logged"
|
||||
"expected": "Cost set to 0, error logged",
|
||||
},
|
||||
{
|
||||
"provider": "vonage",
|
||||
"error": "404 Not Found - call record deleted",
|
||||
"expected": "Cost set to 0, error logged"
|
||||
"expected": "Cost set to 0, error logged",
|
||||
},
|
||||
{
|
||||
"provider": "twilio",
|
||||
"error": "500 Internal Server Error",
|
||||
"expected": "Cost set to 0, retry possible"
|
||||
}
|
||||
"expected": "Cost set to 0, retry possible",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
for scenario in failure_scenarios:
|
||||
print(f" - {scenario['provider']}: {scenario['error']}")
|
||||
print(f" Expected: {scenario['expected']}")
|
||||
|
||||
|
||||
result = {
|
||||
"test": "cost_api_failure",
|
||||
"status": "PASS",
|
||||
"reason": "All failure scenarios handled gracefully"
|
||||
"reason": "All failure scenarios handled gracefully",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
|
@ -174,22 +173,22 @@ async def test_scenario_4_cost_api_failure():
|
|||
async def test_scenario_5_configuration_migration():
|
||||
"""
|
||||
Test: Database migration from single to multi-provider format.
|
||||
|
||||
|
||||
Expected behavior:
|
||||
- Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION
|
||||
- Single provider config wrapped in multi-provider structure
|
||||
- Existing cost_info gets provider field added
|
||||
"""
|
||||
print("\nTest 5: Configuration migration")
|
||||
|
||||
|
||||
# Old format
|
||||
old_config = {
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "token123",
|
||||
"auth_token": "token123",
|
||||
"from_numbers": ["+1234567890"],
|
||||
"provider": "twilio"
|
||||
"provider": "twilio",
|
||||
}
|
||||
|
||||
|
||||
# New format after migration
|
||||
new_config = {
|
||||
"active_provider": "twilio",
|
||||
|
|
@ -197,20 +196,20 @@ async def test_scenario_5_configuration_migration():
|
|||
"twilio": {
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "token123",
|
||||
"from_numbers": ["+1234567890"]
|
||||
"from_numbers": ["+1234567890"],
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Validate migration
|
||||
assert new_config["active_provider"] == "twilio"
|
||||
assert "providers" in new_config
|
||||
assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"]
|
||||
|
||||
|
||||
result = {
|
||||
"test": "configuration_migration",
|
||||
"status": "PASS",
|
||||
"reason": "Configuration migrated to multi-provider format correctly"
|
||||
"reason": "Configuration migrated to multi-provider format correctly",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
|
@ -219,39 +218,34 @@ async def test_scenario_5_configuration_migration():
|
|||
async def test_scenario_6_provider_cost_discrepancy():
|
||||
"""
|
||||
Test: Webhook cost vs API cost discrepancy.
|
||||
|
||||
|
||||
Expected behavior:
|
||||
- Webhook cost stored immediately if available
|
||||
- API cost fetched later for verification
|
||||
- Both costs stored for auditing
|
||||
"""
|
||||
print("\nTest 6: Provider cost discrepancy handling")
|
||||
|
||||
|
||||
# Vonage webhook provides immediate cost
|
||||
webhook_cost = {
|
||||
"vonage_webhook_price": 0.15,
|
||||
"vonage_webhook_duration": 120
|
||||
}
|
||||
|
||||
webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120}
|
||||
|
||||
# API call provides authoritative cost
|
||||
api_cost = {
|
||||
"cost_usd": 0.14, # Slight difference
|
||||
"duration": 120
|
||||
"duration": 120,
|
||||
}
|
||||
|
||||
|
||||
# Both should be stored
|
||||
final_cost_info = {
|
||||
**webhook_cost,
|
||||
"cost_breakdown": {
|
||||
"telephony_call": api_cost["cost_usd"]
|
||||
},
|
||||
"provider": "vonage"
|
||||
"cost_breakdown": {"telephony_call": api_cost["cost_usd"]},
|
||||
"provider": "vonage",
|
||||
}
|
||||
|
||||
|
||||
result = {
|
||||
"test": "cost_discrepancy",
|
||||
"status": "PASS",
|
||||
"reason": "Both webhook and API costs stored for auditing"
|
||||
"reason": "Both webhook and API costs stored for auditing",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
|
@ -262,40 +256,40 @@ async def run_all_tests():
|
|||
print("=" * 60)
|
||||
print("PROVIDER SWITCHING TEST SUITE")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
tests = [
|
||||
test_scenario_1_mid_call_provider_switch,
|
||||
test_scenario_2_pending_cost_calculation,
|
||||
test_scenario_3_mixed_provider_history,
|
||||
test_scenario_4_cost_api_failure,
|
||||
test_scenario_5_configuration_migration,
|
||||
test_scenario_6_provider_cost_discrepancy
|
||||
test_scenario_6_provider_cost_discrepancy,
|
||||
]
|
||||
|
||||
|
||||
results = []
|
||||
for test in tests:
|
||||
result = await test()
|
||||
results.append(result)
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
passed = sum(1 for r in results if r["status"] == "PASS")
|
||||
failed = sum(1 for r in results if r["status"] == "FAIL")
|
||||
|
||||
|
||||
print(f"Total Tests: {len(results)}")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
|
||||
|
||||
if failed == 0:
|
||||
print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!")
|
||||
else:
|
||||
print("\n❌ Some tests failed - Review the implementation")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the test suite
|
||||
asyncio.run(run_all_tests())
|
||||
asyncio.run(run_all_tests())
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit fa68d2ce261544398013307d2c6a69e0556b4449
|
||||
Subproject commit 53653657d851e8052f9cc5b73b6f675a44c86fe7
|
||||
|
|
@ -18,8 +18,6 @@ interface EndCallEditFormProps {
|
|||
nodeData: FlowNodeData;
|
||||
prompt: string;
|
||||
setPrompt: (value: string) => void;
|
||||
isStatic: boolean;
|
||||
setIsStatic: (value: boolean) => void;
|
||||
name: string;
|
||||
setName: (value: string) => void;
|
||||
extractionEnabled: boolean;
|
||||
|
|
@ -45,7 +43,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt);
|
||||
const [isStatic, setIsStatic] = useState(data.is_static ?? true);
|
||||
const [name, setName] = useState(data.name);
|
||||
|
||||
// Variable Extraction state
|
||||
|
|
@ -58,7 +55,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
handleSaveNodeData({
|
||||
...data,
|
||||
prompt,
|
||||
is_static: isStatic,
|
||||
name,
|
||||
allow_interrupt: false, // Always set to false for end nodes
|
||||
extraction_enabled: extractionEnabled,
|
||||
|
|
@ -77,7 +73,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
const handleOpenChange = (newOpen: boolean) => {
|
||||
if (newOpen) {
|
||||
setPrompt(data.prompt);
|
||||
setIsStatic(data.is_static ?? true);
|
||||
setName(data.name);
|
||||
setExtractionEnabled(data.extraction_enabled ?? false);
|
||||
setExtractionPrompt(data.extraction_prompt ?? "");
|
||||
|
|
@ -91,7 +86,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
useEffect(() => {
|
||||
if (open) {
|
||||
setPrompt(data.prompt);
|
||||
setIsStatic(data.is_static ?? true);
|
||||
setName(data.name);
|
||||
setExtractionEnabled(data.extraction_enabled ?? false);
|
||||
setExtractionPrompt(data.extraction_prompt ?? "");
|
||||
|
|
@ -137,8 +131,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
nodeData={data}
|
||||
prompt={prompt}
|
||||
setPrompt={setPrompt}
|
||||
isStatic={isStatic}
|
||||
setIsStatic={setIsStatic}
|
||||
name={name}
|
||||
setName={setName}
|
||||
extractionEnabled={extractionEnabled}
|
||||
|
|
@ -159,8 +151,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
const EndCallEditForm = ({
|
||||
prompt,
|
||||
setPrompt,
|
||||
isStatic,
|
||||
setIsStatic,
|
||||
name,
|
||||
setName,
|
||||
extractionEnabled,
|
||||
|
|
@ -206,14 +196,10 @@ const EndCallEditForm = ({
|
|||
</Label>
|
||||
<Input value={name} onChange={(e) => setName(e.target.value)} />
|
||||
|
||||
<Label>{isStatic ? "Text" : "Prompt"}</Label>
|
||||
<Label>Prompt</Label>
|
||||
<Label className="text-xs text-gray-500">
|
||||
What would you like the agent to say when the call ends? Its a good idea to have a static goodbye message.
|
||||
Enter the prompt for the agent. This will be used to generate the agent's response. Prompt engineering's best practices apply.
|
||||
</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch id="static-text" checked={isStatic} onCheckedChange={setIsStatic} />
|
||||
<Label htmlFor="static-text">Static Text</Label>
|
||||
</div>
|
||||
<Textarea
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
|
|
@ -221,7 +207,7 @@ const EndCallEditForm = ({
|
|||
style={{
|
||||
overflowY: 'auto'
|
||||
}}
|
||||
placeholder={isStatic ? "Thank you for calling Dograh. Have a great day!" : "Enter a dynamic prompt"}
|
||||
placeholder="Enter a dynamic prompt"
|
||||
/>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch id="add-global-prompt" checked={addGlobalPrompt} onCheckedChange={setAddGlobalPrompt} />
|
||||
|
|
|
|||
|
|
@ -19,16 +19,12 @@ interface StartCallEditFormProps {
|
|||
nodeData: FlowNodeData;
|
||||
prompt: string;
|
||||
setPrompt: (value: string) => void;
|
||||
isStatic: boolean;
|
||||
setIsStatic: (value: boolean) => void;
|
||||
name: string;
|
||||
setName: (value: string) => void;
|
||||
allowInterrupt: boolean;
|
||||
setAllowInterrupt: (value: boolean) => void;
|
||||
addGlobalPrompt: boolean;
|
||||
setAddGlobalPrompt: (value: boolean) => void;
|
||||
waitForUserResponse: boolean;
|
||||
setWaitForUserResponse: (value: boolean) => void;
|
||||
detectVoicemail: boolean;
|
||||
setDetectVoicemail: (value: boolean) => void;
|
||||
delayedStart: boolean;
|
||||
|
|
@ -50,11 +46,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt ?? "");
|
||||
const [isStatic, setIsStatic] = useState(data.is_static ?? true);
|
||||
const [name, setName] = useState(data.name);
|
||||
const [allowInterrupt, setAllowInterrupt] = useState(data.allow_interrupt ?? true);
|
||||
const [addGlobalPrompt, setAddGlobalPrompt] = useState(data.add_global_prompt ?? true);
|
||||
const [waitForUserResponse, setWaitForUserResponse] = useState(data.wait_for_user_response ?? false);
|
||||
const [detectVoicemail, setDetectVoicemail] = useState(data.detect_voicemail ?? true);
|
||||
const [delayedStart, setDelayedStart] = useState(data.delayed_start ?? false);
|
||||
const [delayedStartDuration, setDelayedStartDuration] = useState(data.delayed_start_duration ?? 2);
|
||||
|
|
@ -63,11 +57,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
handleSaveNodeData({
|
||||
...data,
|
||||
prompt,
|
||||
is_static: isStatic,
|
||||
name,
|
||||
allow_interrupt: allowInterrupt,
|
||||
add_global_prompt: addGlobalPrompt,
|
||||
wait_for_user_response: waitForUserResponse,
|
||||
detect_voicemail: detectVoicemail,
|
||||
delayed_start: delayedStart,
|
||||
delayed_start_duration: delayedStart ? delayedStartDuration : undefined
|
||||
|
|
@ -83,11 +75,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
const handleOpenChange = (newOpen: boolean) => {
|
||||
if (newOpen) {
|
||||
setPrompt(data.prompt ?? "");
|
||||
setIsStatic(data.is_static ?? true);
|
||||
setName(data.name);
|
||||
setAllowInterrupt(data.allow_interrupt ?? true);
|
||||
setAddGlobalPrompt(data.add_global_prompt ?? true);
|
||||
setWaitForUserResponse(data.wait_for_user_response ?? false);
|
||||
setDetectVoicemail(data.detect_voicemail ?? true);
|
||||
setDelayedStart(data.delayed_start ?? false);
|
||||
setDelayedStartDuration(data.delayed_start_duration ?? 3);
|
||||
|
|
@ -99,11 +89,9 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
useEffect(() => {
|
||||
if (open) {
|
||||
setPrompt(data.prompt ?? "");
|
||||
setIsStatic(data.is_static ?? true);
|
||||
setName(data.name);
|
||||
setAllowInterrupt(data.allow_interrupt ?? true);
|
||||
setAddGlobalPrompt(data.add_global_prompt ?? true);
|
||||
setWaitForUserResponse(data.wait_for_user_response ?? false);
|
||||
setDetectVoicemail(data.detect_voicemail ?? true);
|
||||
setDelayedStart(data.delayed_start ?? false);
|
||||
setDelayedStartDuration(data.delayed_start_duration ?? 3);
|
||||
|
|
@ -147,16 +135,12 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
nodeData={data}
|
||||
prompt={prompt}
|
||||
setPrompt={setPrompt}
|
||||
isStatic={isStatic}
|
||||
setIsStatic={setIsStatic}
|
||||
name={name}
|
||||
setName={setName}
|
||||
allowInterrupt={allowInterrupt}
|
||||
setAllowInterrupt={setAllowInterrupt}
|
||||
addGlobalPrompt={addGlobalPrompt}
|
||||
setAddGlobalPrompt={setAddGlobalPrompt}
|
||||
waitForUserResponse={waitForUserResponse}
|
||||
setWaitForUserResponse={setWaitForUserResponse}
|
||||
detectVoicemail={detectVoicemail}
|
||||
setDetectVoicemail={setDetectVoicemail}
|
||||
delayedStart={delayedStart}
|
||||
|
|
@ -173,16 +157,12 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
const StartCallEditForm = ({
|
||||
prompt,
|
||||
setPrompt,
|
||||
isStatic,
|
||||
setIsStatic,
|
||||
name,
|
||||
setName,
|
||||
allowInterrupt,
|
||||
setAllowInterrupt,
|
||||
addGlobalPrompt,
|
||||
setAddGlobalPrompt,
|
||||
waitForUserResponse,
|
||||
setWaitForUserResponse,
|
||||
detectVoicemail,
|
||||
setDetectVoicemail,
|
||||
delayedStart,
|
||||
|
|
@ -201,14 +181,10 @@ const StartCallEditForm = ({
|
|||
onChange={(e) => setName(e.target.value)}
|
||||
/>
|
||||
|
||||
<Label>{isStatic ? "Text" : "Prompt"}</Label>
|
||||
<Label>Prompt</Label>
|
||||
<Label className="text-xs text-gray-500">
|
||||
What would you like the agent to say when the call starts? Its a good idea to have a static greeting that can be used to identify the call.
|
||||
Enter the prompt for the agent. This will be used to generate the agent's response. Prompt engineering's best practices apply.
|
||||
</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch id="static-text" checked={isStatic} onCheckedChange={setIsStatic} />
|
||||
<Label htmlFor="static-text">Static Text</Label>
|
||||
</div>
|
||||
<Textarea
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
|
|
@ -216,7 +192,7 @@ const StartCallEditForm = ({
|
|||
style={{
|
||||
overflowY: 'auto'
|
||||
}}
|
||||
placeholder={isStatic ? "Hello, welcome to Dograh. How can I help you today?" : "Enter a dynamic prompt"}
|
||||
placeholder="Enter a prompt"
|
||||
/>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch id="allow-interrupt" checked={allowInterrupt} onCheckedChange={setAllowInterrupt} />
|
||||
|
|
@ -230,34 +206,10 @@ const StartCallEditForm = ({
|
|||
id="add-global-prompt"
|
||||
checked={addGlobalPrompt}
|
||||
onCheckedChange={setAddGlobalPrompt}
|
||||
disabled={isStatic}
|
||||
/>
|
||||
<Label htmlFor="add-global-prompt" className={isStatic ? "opacity-50" : ""}>
|
||||
<Label htmlFor="add-global-prompt">
|
||||
Add Global Prompt
|
||||
</Label>
|
||||
<Label className={`text-xs text-gray-500 ${isStatic ? "opacity-50" : ""}`}>
|
||||
{isStatic
|
||||
? "Not applicable for static text"
|
||||
: "Whether you want to add global prompt with this node's prompt."}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-2">
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch
|
||||
id="wait-for-user-response"
|
||||
checked={waitForUserResponse}
|
||||
onCheckedChange={setWaitForUserResponse}
|
||||
disabled={!isStatic}
|
||||
/>
|
||||
<Label htmlFor="wait-for-user-response" className={!isStatic ? "opacity-50" : ""}>
|
||||
Wait for user's response
|
||||
</Label>
|
||||
<Label className={`text-xs text-gray-500 ${!isStatic ? "opacity-50" : ""}`}>
|
||||
{!isStatic
|
||||
? "Only applicable for static text"
|
||||
: "Wait for user to respond before disconnecting the call."}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
{!isOSSMode() && (
|
||||
<div className="flex items-center space-x-2">
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ export type FlowNodeData = {
|
|||
extraction_prompt?: string;
|
||||
extraction_variables?: ExtractionVariable[];
|
||||
add_global_prompt?: boolean;
|
||||
wait_for_user_response?: boolean;
|
||||
wait_for_user_response_timeout?: number;
|
||||
wait_for_user_greeting?: boolean;
|
||||
detect_voicemail?: boolean;
|
||||
delayed_start?: boolean;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue