From cc05f363ff9f887fcf824657131e36e365ec5f47 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Sat, 15 Nov 2025 17:22:15 +0530 Subject: [PATCH] feat: simplify pipecat engine execution --- ...7d25b75117_add_vonage_and_rename_config.py | 38 +- api/enums.py | 8 +- api/requirements.txt | 2 +- api/routes/organization.py | 48 +- api/routes/telephony.py | 228 ++++---- api/routes/webrtc_signaling.py | 4 +- api/services/campaign/call_dispatcher.py | 10 +- .../looptalk/core/pipeline_builder.py | 7 +- api/services/looptalk/orchestrator.py | 8 - .../engine_pre_aggregator_processor.py | 69 --- api/services/pipecat/pipeline_builder.py | 11 +- .../pipeline_engine_callbacks_processor.py | 10 - api/services/pipecat/run_pipeline.py | 27 +- api/services/pipecat/transport_setup.py | 9 +- api/services/telephony/base.py | 49 +- api/services/telephony/factory.py | 46 +- api/services/telephony/providers/__init__.py | 2 +- .../telephony/providers/twilio_provider.py | 121 ++-- .../telephony/providers/vonage_provider.py | 202 ++++--- .../telephony/stasis_rtp_transport.py | 4 +- api/services/workflow/pipecat_engine.py | 244 +------- .../workflow/pipecat_engine_callbacks.py | 106 +--- api/services/workflow/pipecat_engine_utils.py | 27 +- .../pipecat_engine_variable_extractor.py | 8 +- api/services/workflow/workflow.py | 2 - api/tasks/run_integrations.py | 4 +- api/tasks/workflow_run_cost.py | 30 +- api/tests/test_base_openai_llm_service.py | 179 ------ api/tests/test_llm_generated_text_signal.py | 143 ----- api/tests/test_pipecat_engine_set_node.py | 536 ------------------ api/tests/test_provider_switching.py | 144 +++-- pipecat | 2 +- ui/src/components/flow/nodes/EndCall.tsx | 20 +- ui/src/components/flow/nodes/StartCall.tsx | 56 +- ui/src/components/flow/types.ts | 2 - 35 files changed, 545 insertions(+), 1861 deletions(-) delete mode 100644 api/services/pipecat/engine_pre_aggregator_processor.py delete mode 100644 api/tests/test_base_openai_llm_service.py delete mode 100644 api/tests/test_llm_generated_text_signal.py delete mode 100644 api/tests/test_pipecat_engine_set_node.py diff --git a/api/alembic/versions/a57d25b75117_add_vonage_and_rename_config.py b/api/alembic/versions/a57d25b75117_add_vonage_and_rename_config.py index d37b292..02643c0 100644 --- a/api/alembic/versions/a57d25b75117_add_vonage_and_rename_config.py +++ b/api/alembic/versions/a57d25b75117_add_vonage_and_rename_config.py @@ -5,15 +5,15 @@ Revises: 982ec8e434be Create Date: 2025-10-21 12:28:06.053318 """ + from typing import Sequence, Union from alembic import op from alembic_postgresql_enum import TableReference - # revision identifiers, used by Alembic. -revision: str = 'a57d25b75117' -down_revision: Union[str, None] = '982ec8e434be' +revision: str = "a57d25b75117" +down_revision: Union[str, None] = "982ec8e434be" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,12 +26,20 @@ def upgrade() -> None: 2. Migrates TWILIO_CONFIGURATION key to TELEPHONY_CONFIGURATION 3. Renames twilio_status_callbacks to telephony_status_callbacks in workflow_run logs """ - + # Add 'vonage' to the workflow_run_mode enum op.sync_enum_values( enum_schema="public", enum_name="workflow_run_mode", - new_values=["twilio", "stasis", "webrtc", "smallwebrtc", "VOICE", "CHAT", "vonage"], + new_values=[ + "twilio", + "stasis", + "webrtc", + "smallwebrtc", + "VOICE", + "CHAT", + "vonage", + ], affected_columns=[ TableReference( table_schema="public", table_name="workflow_runs", column_name="mode" @@ -39,14 +47,14 @@ def upgrade() -> None: ], enum_values_to_rename=[], ) - + # Rename the key from TWILIO_CONFIGURATION to TELEPHONY_CONFIGURATION op.execute(""" UPDATE organization_configurations SET key = 'TELEPHONY_CONFIGURATION' WHERE key = 'TWILIO_CONFIGURATION'; """) - + # Rename twilio_status_callbacks to telephony_status_callbacks in workflow_run logs op.execute(""" UPDATE workflow_runs @@ -57,15 +65,17 @@ def upgrade() -> None: ) WHERE logs::jsonb ? 'twilio_status_callbacks'; """) - - print("Migration complete: Added vonage to enum, renamed configuration key, and updated status callback keys") + + print( + "Migration complete: Added vonage to enum, renamed configuration key, and updated status callback keys" + ) def downgrade() -> None: """ Revert configuration key names and enum. """ - + # Revert telephony_status_callbacks to twilio_status_callbacks in workflow_run logs op.execute(""" UPDATE workflow_runs @@ -76,14 +86,14 @@ def downgrade() -> None: ) WHERE logs::jsonb ? 'telephony_status_callbacks'; """) - + # Revert key name op.execute(""" UPDATE organization_configurations SET key = 'TWILIO_CONFIGURATION' WHERE key = 'TELEPHONY_CONFIGURATION'; """) - + # Revert enum to previous state op.sync_enum_values( enum_schema="public", @@ -96,5 +106,5 @@ def downgrade() -> None: ], enum_values_to_rename=[], ) - - print("Downgrade complete: Reverted configuration key names and enum") \ No newline at end of file + + print("Downgrade complete: Reverted configuration key names and enum") diff --git a/api/enums.py b/api/enums.py index 7175e78..4462696 100644 --- a/api/enums.py +++ b/api/enums.py @@ -63,8 +63,12 @@ class OrganizationConfigurationKey(Enum): DISPOSITION_CODE_MAPPING = "DISPOSITION_CODE_MAPPING" DISPOSITION_MESSAGE_TEMPLATE = "DISPOSITION_MESSAGE_TEMPLATE" CONCURRENT_CALL_LIMIT = "CONCURRENT_CALL_LIMIT" - TELEPHONY_CONFIGURATION = "TELEPHONY_CONFIGURATION" # Stores all providers + active one - TWILIO_CONFIGURATION = "TWILIO_CONFIGURATION" # Deprecated - for backward compatibility + TELEPHONY_CONFIGURATION = ( + "TELEPHONY_CONFIGURATION" # Stores all providers + active one + ) + TWILIO_CONFIGURATION = ( + "TWILIO_CONFIGURATION" # Deprecated - for backward compatibility + ) class WorkflowStatus(Enum): diff --git a/api/requirements.txt b/api/requirements.txt index 9c3f1d4..17cdaab 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,4 +1,4 @@ -langfuse==3.4.0 +langfuse==3.9.3 fastapi==0.116.2 asyncpg==0.30.0 alembic==1.16.5 diff --git a/api/routes/organization.py b/api/routes/organization.py index 15e2e5e..a47966b 100644 --- a/api/routes/organization.py +++ b/api/routes/organization.py @@ -1,9 +1,10 @@ +from typing import Union + from fastapi import APIRouter, Depends, HTTPException from api.db import db_client from api.db.models import UserModel from api.enums import OrganizationConfigurationKey -from typing import Union from api.schemas.telephony_config import ( TelephonyConfigurationResponse, TwilioConfigurationRequest, @@ -19,14 +20,13 @@ router = APIRouter(prefix="/organizations", tags=["organizations"]) # Provider configuration constants PROVIDER_MASKED_FIELDS = { "twilio": ["account_sid", "auth_token"], - "vonage": ["private_key", "api_key", "api_secret"] + "vonage": ["private_key", "api_key", "api_secret"], } + # TODO: Make endpoints provider-agnostic @router.get("/telephony-config", response_model=TelephonyConfigurationResponse) -async def get_telephony_configuration( - user: UserModel = Depends(get_user) -): +async def get_telephony_configuration(user: UserModel = Depends(get_user)): """Get telephony configuration for the user's organization with masked sensitive fields.""" if not user.selected_organization_id: raise HTTPException(status_code=400, detail="No organization selected") @@ -40,11 +40,13 @@ async def get_telephony_configuration( return TelephonyConfigurationResponse() stored_provider = config.value.get("provider", "twilio") - + if stored_provider == "twilio": account_sid = config.value.get("account_sid", "") auth_token = config.value.get("auth_token", "") - from_numbers = config.value.get("from_numbers", []) if account_sid and auth_token else [] + from_numbers = ( + config.value.get("from_numbers", []) if account_sid and auth_token else [] + ) return TelephonyConfigurationResponse( twilio=TwilioConfigurationResponse( @@ -53,15 +55,19 @@ async def get_telephony_configuration( auth_token=mask_key(auth_token) if auth_token else "", from_numbers=from_numbers, ), - vonage=None + vonage=None, ) elif stored_provider == "vonage": application_id = config.value.get("application_id", "") private_key = config.value.get("private_key", "") api_key = config.value.get("api_key", "") api_secret = config.value.get("api_secret", "") - from_numbers = config.value.get("from_numbers", []) if application_id and private_key else [] - + from_numbers = ( + config.value.get("from_numbers", []) + if application_id and private_key + else [] + ) + return TelephonyConfigurationResponse( twilio=None, vonage=VonageConfigurationResponse( @@ -71,7 +77,7 @@ async def get_telephony_configuration( api_key=mask_key(api_key) if api_key else None, api_secret=mask_key(api_secret) if api_secret else None, from_numbers=from_numbers, - ) + ), ) else: return TelephonyConfigurationResponse() @@ -79,8 +85,8 @@ async def get_telephony_configuration( @router.post("/telephony-config") async def save_telephony_configuration( - request: Union[TwilioConfigurationRequest, VonageConfigurationRequest], - user: UserModel = Depends(get_user) + request: Union[TwilioConfigurationRequest, VonageConfigurationRequest], + user: UserModel = Depends(get_user), ): """Save telephony configuration for the user's organization.""" if not user.selected_organization_id: @@ -105,12 +111,14 @@ async def save_telephony_configuration( "provider": "vonage", "application_id": request.application_id, "private_key": request.private_key, - "api_key": getattr(request, 'api_key', None), - "api_secret": getattr(request, 'api_secret', None), + "api_key": getattr(request, "api_key", None), + "api_secret": getattr(request, "api_secret", None), "from_numbers": request.from_numbers, } else: - raise HTTPException(status_code=400, detail=f"Unsupported provider: {request.provider}") + raise HTTPException( + status_code=400, detail=f"Unsupported provider: {request.provider}" + ) if existing_config and existing_config.value: existing_provider = existing_config.value.get("provider") @@ -126,14 +134,16 @@ async def save_telephony_configuration( return {"message": "Telephony configuration saved successfully"} -def preserve_masked_fields(request, existing_config, config_value): +def preserve_masked_fields(request, existing_config, config_value): provider = request.provider masked_fields = PROVIDER_MASKED_FIELDS.get(provider, []) - + for field_name in masked_fields: if hasattr(request, field_name): field_value = getattr(request, field_name) # Check if field has a value and is a masked version of the existing value - if field_value and is_mask_of(field_value, existing_config.value.get(field_name, "")): + if field_value and is_mask_of( + field_value, existing_config.value.get(field_name, "") + ): config_value[field_name] = existing_config.value[field_name] diff --git a/api/routes/telephony.py b/api/routes/telephony.py index 1bdadcb..bcc3155 100644 --- a/api/routes/telephony.py +++ b/api/routes/telephony.py @@ -1,19 +1,19 @@ """ Generic telephony routes that work with any telephony provider. """ + import json import random from datetime import UTC, datetime -from typing import Annotated, Optional +from typing import Optional -from fastapi import APIRouter, Depends, Form, Header, HTTPException, Request, WebSocket +from fastapi import APIRouter, Depends, Header, HTTPException, Request, WebSocket from loguru import logger from pydantic import BaseModel from starlette.responses import HTMLResponse from api.db import db_client from api.db.models import UserModel -from api.enums import WorkflowRunMode from api.services.auth.depends import get_user from api.services.campaign.call_dispatcher import campaign_call_dispatcher from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher @@ -32,6 +32,7 @@ class InitiateCallRequest(BaseModel): class StatusCallbackRequest(BaseModel): """Generic status callback that can handle different providers""" + # Common fields call_id: str status: str @@ -39,10 +40,10 @@ class StatusCallbackRequest(BaseModel): to_number: Optional[str] = None direction: Optional[str] = None duration: Optional[str] = None - + # Provider-specific fields stored as extra extra: dict = {} - + @classmethod def from_twilio(cls, data: dict): """Convert Twilio callback to generic format""" @@ -53,9 +54,9 @@ class StatusCallbackRequest(BaseModel): to_number=data.get("To"), direction=data.get("Direction"), duration=data.get("CallDuration") or data.get("Duration"), - extra=data + extra=data, ) - + @classmethod def from_vonage(cls, data: dict): """Convert Vonage event to generic format""" @@ -63,14 +64,14 @@ class StatusCallbackRequest(BaseModel): status_map = { "started": "initiated", "ringing": "ringing", - "answered": "answered", + "answered": "answered", "complete": "completed", "failed": "failed", "busy": "busy", "timeout": "no-answer", - "rejected": "busy" + "rejected": "busy", } - + return cls( call_id=data.get("uuid", ""), status=status_map.get(data.get("status", ""), data.get("status", "")), @@ -78,7 +79,7 @@ class StatusCallbackRequest(BaseModel): to_number=data.get("to"), direction=data.get("direction"), duration=data.get("duration"), - extra=data + extra=data, ) @@ -87,32 +88,32 @@ async def initiate_call( request: InitiateCallRequest, user: UserModel = Depends(get_user) ): """Initiate a call using the configured telephony provider.""" - + # Get the telephony provider for the organization provider = await get_telephony_provider(user.selected_organization_id) - + # Validate provider is configured if not provider.validate_config(): raise HTTPException( status_code=400, detail="telephony_not_configured", ) - + # Determine the workflow run mode based on provider type workflow_run_mode = provider.PROVIDER_NAME - + user_configuration = await db_client.get_user_configurations(user.id) - + phone_number = request.phone_number or user_configuration.test_phone_number - + if not phone_number: raise HTTPException( - status_code=400, - detail="Phone number must be provided in request or set in user configuration" + status_code=400, + detail="Phone number must be provided in request or set in user configuration", ) - + workflow_run_id = request.workflow_run_id - + if not workflow_run_id: workflow_run_name = f"WR-TEL-{random.randint(1000, 9999)}" workflow_run = await db_client.create_workflow_run( @@ -130,12 +131,12 @@ async def initiate_call( if not workflow_run: raise HTTPException(status_code=400, detail="Workflow run not found") workflow_run_name = workflow_run.name - + # Construct webhook URL based on provider type backend_endpoint = await TunnelURLProvider.get_tunnel_url() - + webhook_endpoint = provider.WEBHOOK_ENDPOINT - + webhook_url = ( f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" f"?workflow_id={request.workflow_id}" @@ -143,35 +144,29 @@ async def initiate_call( f"&workflow_run_id={workflow_run_id}" f"&organization_id={user.selected_organization_id}" ) - + # Initiate call via provider result = await provider.initiate_call( to_number=phone_number, webhook_url=webhook_url, workflow_run_id=workflow_run_id, ) - + # Store provider type and any provider-specific metadata in workflow run context gathered_context = { "provider": provider.PROVIDER_NAME, - **(result.provider_metadata or {}) + **(result.provider_metadata or {}), } await db_client.update_workflow_run( - run_id=workflow_run_id, - gathered_context=gathered_context + run_id=workflow_run_id, gathered_context=gathered_context ) - - return { - "message": f"Call initiated successfully with run name {workflow_run_name}" - } + + return {"message": f"Call initiated successfully with run name {workflow_run_name}"} @router.post("/twiml", include_in_schema=False) async def handle_twiml_webhook( - workflow_id: int, - user_id: int, - workflow_run_id: int, - organization_id: int + workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int ): """ Handle initial webhook from telephony provider. @@ -179,32 +174,32 @@ async def handle_twiml_webhook( """ provider = await get_telephony_provider(organization_id) - + response_content = await provider.get_webhook_response( workflow_id, user_id, workflow_run_id ) - + return HTMLResponse(content=response_content, media_type="application/xml") @router.get("/ncco", include_in_schema=False) async def handle_ncco_webhook( - workflow_id: int, - user_id: int, + workflow_id: int, + user_id: int, workflow_run_id: int, - organization_id: Optional[int] = None + organization_id: Optional[int] = None, ): """Handle NCCO (Nexmo Call Control Objects) webhook for Vonage. - + Returns JSON response instead of XML like TwiML. """ provider = await get_telephony_provider(organization_id or user_id) - + response_content = await provider.get_webhook_response( workflow_id, user_id, workflow_run_id ) - + return json.loads(response_content) @@ -218,36 +213,38 @@ async def websocket_endpoint( try: # Set the run context set_current_run_id(workflow_run_id) - + # Get workflow run to determine provider type workflow_run = await db_client.get_workflow_run(workflow_run_id) if not workflow_run: logger.error(f"Workflow run {workflow_run_id} not found") await websocket.close(code=4404, reason="Workflow run not found") return - + # Get workflow for organization info workflow = await db_client.get_workflow(workflow_id) if not workflow: logger.error(f"Workflow {workflow_id} not found") await websocket.close(code=4404, reason="Workflow not found") return - + # Extract provider type from workflow run context provider_type = None if workflow_run.gathered_context: provider_type = workflow_run.gathered_context.get("provider") - + if not provider_type: logger.error(f"No provider type found in workflow run {workflow_run_id}") await websocket.close(code=4400, reason="Provider type not found") return - - logger.info(f"WebSocket connected for {provider_type} provider, workflow_run {workflow_run_id}") - + + logger.info( + f"WebSocket connected for {provider_type} provider, workflow_run {workflow_run_id}" + ) + # Get the telephony provider instance provider = await get_telephony_provider(workflow.organization_id) - + # Verify the provider matches what was stored if provider.PROVIDER_NAME != provider_type: logger.error( @@ -255,10 +252,12 @@ async def websocket_endpoint( ) await websocket.close(code=4400, reason="Provider mismatch") return - + # Delegate to provider-specific handler - await provider.handle_websocket(websocket, workflow_id, user_id, workflow_run_id) - + await provider.handle_websocket( + websocket, workflow_id, user_id, workflow_run_id + ) + except Exception as e: logger.error(f"Error in WebSocket connection: {e}") await websocket.close(1011, "Internal server error") @@ -271,44 +270,46 @@ async def handle_twilio_status_callback( x_webhook_signature: Optional[str] = Header(None), ): """Handle Twilio-specific status callbacks.""" - + # Parse form data form_data = await request.form() callback_data = dict(form_data) - + logger.info( f"[run {workflow_run_id}] Received status callback: {json.dumps(callback_data)}" ) - + # Get workflow run to find organization workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) if not workflow_run: logger.warning(f"Workflow run {workflow_run_id} not found for status callback") return {"status": "ignored", "reason": "workflow_run_not_found"} - + # Get workflow and provider workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) if not workflow: logger.warning(f"Workflow {workflow_run.workflow_id} not found") return {"status": "ignored", "reason": "workflow_not_found"} - + provider = await get_telephony_provider(workflow.organization_id) - + if x_webhook_signature: backend_endpoint = await TunnelURLProvider.get_tunnel_url() full_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" - + is_valid = await provider.verify_webhook_signature( full_url, callback_data, x_webhook_signature ) - + if not is_valid: - logger.warning(f"Invalid webhook signature for workflow run {workflow_run_id}") + logger.warning( + f"Invalid webhook signature for workflow run {workflow_run_id}" + ) return {"status": "error", "reason": "invalid_signature"} - + # Parse the callback data into generic format parsed_data = provider.parse_status_callback(callback_data) - + # Create StatusCallbackRequest from parsed data status_update = StatusCallbackRequest( call_id=parsed_data["call_id"], @@ -317,22 +318,20 @@ async def handle_twilio_status_callback( to_number=parsed_data.get("to_number"), direction=parsed_data.get("direction"), duration=parsed_data.get("duration"), - extra=parsed_data.get("extra", {}) + extra=parsed_data.get("extra", {}), ) - + # Process the status update await _process_status_update(workflow_run_id, status_update, workflow_run) - + return {"status": "success"} async def _process_status_update( - workflow_run_id: int, - status: StatusCallbackRequest, - workflow_run: any + workflow_run_id: int, status: StatusCallbackRequest, workflow_run: any ): """Process status updates from telephony providers.""" - + # Log the status callback telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", []) telephony_callback_log = { @@ -340,31 +339,29 @@ async def _process_status_update( "timestamp": datetime.now(UTC).isoformat(), "call_id": status.call_id, "duration": status.duration, - **status.extra # Include provider-specific data + **status.extra, # Include provider-specific data } telephony_callback_logs.append(telephony_callback_log) - + # Update workflow run logs await db_client.update_workflow_run( run_id=workflow_run_id, logs={"telephony_status_callbacks": telephony_callback_logs}, ) - + # Handle call completion if status.status == "completed": logger.info( f"[run {workflow_run_id}] Call completed with duration: {status.duration}s" ) - + # Release concurrent slot if this was a campaign call if workflow_run.campaign_id: await campaign_call_dispatcher.release_call_slot(workflow_run_id) - + # Mark workflow run as completed - await db_client.update_workflow_run( - run_id=workflow_run_id, is_completed=True - ) - + await db_client.update_workflow_run(run_id=workflow_run_id, is_completed=True) + # Publish campaign event if applicable if workflow_run.campaign_id: publisher = await get_campaign_event_publisher() @@ -374,32 +371,40 @@ async def _process_status_update( queued_run_id=workflow_run.queued_run_id, call_duration=int(status.duration) if status.duration else 0, ) - + elif status.status in ["failed", "busy", "no-answer", "canceled"]: - logger.warning(f"[run {workflow_run_id}] Call failed with status: {status.status}") - + logger.warning( + f"[run {workflow_run_id}] Call failed with status: {status.status}" + ) + # Release concurrent slot for terminal statuses if this was a campaign call if workflow_run.campaign_id: await campaign_call_dispatcher.release_call_slot(workflow_run_id) - + # Check if retry is needed for campaign calls (busy/no-answer) if status.status in ["busy", "no-answer"] and workflow_run.campaign_id: publisher = await get_campaign_event_publisher() await publisher.publish_retry_needed( workflow_run_id=workflow_run_id, - reason=status.status.replace("-", "_"), # Convert no-answer to no_answer + reason=status.status.replace( + "-", "_" + ), # Convert no-answer to no_answer campaign_id=workflow_run.campaign_id, queued_run_id=workflow_run.queued_run_id, ) - + # Mark workflow run as completed with failure tags - call_tags = workflow_run.gathered_context.get("call_tags", []) if workflow_run.gathered_context else [] + call_tags = ( + workflow_run.gathered_context.get("call_tags", []) + if workflow_run.gathered_context + else [] + ) call_tags.extend(["not_connected", f"telephony_{status.status.lower()}"]) - + await db_client.update_workflow_run( run_id=workflow_run_id, is_completed=True, - gathered_context={"call_tags": call_tags} + gathered_context={"call_tags": call_tags}, ) @@ -409,20 +414,20 @@ async def handle_vonage_events( workflow_run_id: int, ): """Handle Vonage-specific event webhooks. - + Vonage sends all call events to a single endpoint. Events include: started, ringing, answered, complete, failed, etc. """ # Parse the event data event_data = await request.json() logger.info(f"[run {workflow_run_id}] Received Vonage event: {event_data}") - + # Get workflow run for processing workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) if not workflow_run: logger.error(f"[run {workflow_run_id}] Workflow run not found") return {"status": "error", "message": "Workflow run not found"} - + # For a completed call that includes cost info, capture it immediately if event_data.get("status") == "completed": # Vonage sometimes includes price info in the webhook @@ -436,27 +441,32 @@ async def handle_vonage_events( if "rate" in event_data: cost_info["vonage_webhook_rate"] = float(event_data["rate"]) if "duration" in event_data: - cost_info["vonage_webhook_duration"] = int(event_data["duration"]) - + cost_info["vonage_webhook_duration"] = int( + event_data["duration"] + ) + await db_client.update_workflow_run( - run_id=workflow_run_id, - cost_info=cost_info + run_id=workflow_run_id, cost_info=cost_info + ) + logger.info( + f"[run {workflow_run_id}] Captured Vonage cost info from webhook" ) - logger.info(f"[run {workflow_run_id}] Captured Vonage cost info from webhook") except Exception as e: - logger.error(f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}") - + logger.error( + f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}" + ) + # Get workflow and provider workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) if not workflow: logger.error(f"[run {workflow_run_id}] Workflow not found") return {"status": "error", "message": "Workflow not found"} - + provider = await get_telephony_provider(workflow.organization_id) - + # Parse the event data into generic format parsed_data = provider.parse_status_callback(event_data) - + # Create StatusCallbackRequest from parsed data status_update = StatusCallbackRequest( call_id=parsed_data["call_id"], @@ -465,11 +475,11 @@ async def handle_vonage_events( to_number=parsed_data.get("to_number"), direction=parsed_data.get("direction"), duration=parsed_data.get("duration"), - extra=parsed_data.get("extra", {}) + extra=parsed_data.get("extra", {}), ) - + # Process the status update await _process_status_update(workflow_run_id, status_update, workflow_run) - + # Return 204 No Content as expected by Vonage - return {"status": "ok"} \ No newline at end of file + return {"status": "ok"} diff --git a/api/routes/webrtc_signaling.py b/api/routes/webrtc_signaling.py index fedee8b..9c5f229 100644 --- a/api/routes/webrtc_signaling.py +++ b/api/routes/webrtc_signaling.py @@ -124,7 +124,9 @@ class SignalingManager: ) else: # Create new connection using correct SmallWebRTC API - pc = SmallWebRTCConnection(ice_servers=ice_servers, connection_timeout_secs=60) + pc = SmallWebRTCConnection( + ice_servers=ice_servers, connection_timeout_secs=60 + ) # Set the pc_id before initialization so it's available in get_answer() pc._pc_id = pc_id diff --git a/api/services/campaign/call_dispatcher.py b/api/services/campaign/call_dispatcher.py index ba091e9..fd1d610 100644 --- a/api/services/campaign/call_dispatcher.py +++ b/api/services/campaign/call_dispatcher.py @@ -7,10 +7,10 @@ from loguru import logger from api.db import db_client from api.db.models import QueuedRunModel, WorkflowRunModel -from api.enums import OrganizationConfigurationKey, WorkflowRunMode +from api.enums import OrganizationConfigurationKey from api.services.campaign.rate_limiter import rate_limiter -from api.services.telephony.factory import get_telephony_provider from api.services.telephony.base import TelephonyProvider +from api.services.telephony.factory import get_telephony_provider from api.utils.tunnel import TunnelURLProvider @@ -238,7 +238,7 @@ class CampaignCallDispatcher: f"&campaign_id={campaign.id}" f"&organization_id={campaign.organization_id}" ) - + call_result = await provider.initiate_call( to_number=phone_number, webhook_url=webhook_url, @@ -255,7 +255,9 @@ class CampaignCallDispatcher: ) # Update workflow run as failed - telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", []) + telephony_callback_logs = workflow_run.logs.get( + "telephony_status_callbacks", [] + ) telephony_callback_log = { "status": "failed", "timestamp": datetime.now(UTC).isoformat(), diff --git a/api/services/looptalk/core/pipeline_builder.py b/api/services/looptalk/core/pipeline_builder.py index 6bbc62e..95ece81 100644 --- a/api/services/looptalk/core/pipeline_builder.py +++ b/api/services/looptalk/core/pipeline_builder.py @@ -24,6 +24,9 @@ from api.services.workflow.dto import ReactFlowDTO from api.services.workflow.pipecat_engine import PipecatEngine from api.services.workflow.workflow import WorkflowGraph from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, +) from pipecat.processors.filters.stt_mute_filter import ( STTMuteConfig, STTMuteFilter, @@ -83,7 +86,8 @@ class LoopTalkPipelineBuilder: audio_buffer, audio_synchronizer, transcript, context = ( create_pipeline_components(audio_config) ) - context_aggregator = llm.create_context_aggregator(context) + + context_aggregator = LLMContextAggregatorPair(context) # Get workflow graph workflow_graph = WorkflowGraph( @@ -113,7 +117,6 @@ class LoopTalkPipelineBuilder: pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor( max_call_duration_seconds=300, max_duration_end_task_callback=engine.create_max_duration_callback(), - llm_generated_text_callback=engine.create_llm_generated_text_callback(), generation_started_callback=engine.create_generation_started_callback(), ) diff --git a/api/services/looptalk/orchestrator.py b/api/services/looptalk/orchestrator.py index eac0778..c88e8b2 100644 --- a/api/services/looptalk/orchestrator.py +++ b/api/services/looptalk/orchestrator.py @@ -272,14 +272,6 @@ class LoopTalkTestOrchestrator: await task.cancel() - # Connect the context aggregator events to engine - @assistant_context_aggregator.event_handler("on_push_aggregation") - async def on_assistant_aggregator_push_context(_aggregator): - logger.debug( - "Assistant aggregator push context – flushing pending transitions" - ) - await engine.flush_pending_transitions() - # Register custom audio and transcript handlers for LoopTalk await self._register_looptalk_handlers( audio_synchronizer, transcript, test_session_id, role diff --git a/api/services/pipecat/engine_pre_aggregator_processor.py b/api/services/pipecat/engine_pre_aggregator_processor.py deleted file mode 100644 index d8f0f29..0000000 --- a/api/services/pipecat/engine_pre_aggregator_processor.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Engine Pre-Aggregator Processor - -This processor sits before the user context aggregator in the pipeline and handles -engine-specific callbacks for frames that need to be processed before aggregation. -This ensures the engine can update context before the aggregator generates LLM frames. -""" - -from typing import Awaitable, Callable, Optional - -from loguru import logger - -from api.services.pipecat.exceptions import VoicemailDetectedException -from pipecat.frames.frames import ( - Frame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, -) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor - - -class EnginePreAggregatorProcessor(FrameProcessor): - """ - Processor that handles engine callbacks before user context aggregation. - - This processor is positioned before the user context aggregator to ensure - the engine can update LLM context before aggregation occurs. - """ - - def __init__( - self, - user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None, - user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None, - **kwargs, - ): - super().__init__(**kwargs) - self._user_started_speaking_callback = user_started_speaking_callback - self._user_stopped_speaking_callback = user_stopped_speaking_callback - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - # Handle frames that need engine processing before aggregation - if isinstance(frame, UserStartedSpeakingFrame): - await self._handle_user_started_speaking() - elif isinstance(frame, UserStoppedSpeakingFrame): - try: - await self._handle_user_stopped_speaking() - except VoicemailDetectedException: - # We have detected voicemail, lets not - # forward the UserStoppedSpeakingFrame, so that - # we don't issue an llm call from user context - # aggregator - logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame") - return - - # Always push the frame downstream - await self.push_frame(frame, direction) - - async def _handle_user_started_speaking(self): - """Handle UserStartedSpeakingFrame before aggregation.""" - if self._user_started_speaking_callback: - # logger.debug("Engine pre-aggregator: User started speaking") - await self._user_started_speaking_callback() - - async def _handle_user_stopped_speaking(self): - """Handle UserStoppedSpeakingFrame before aggregation.""" - if self._user_stopped_speaking_callback: - # logger.debug("Engine pre-aggregator: User stopped speaking") - await self._user_stopped_speaking_callback() diff --git a/api/services/pipecat/pipeline_builder.py b/api/services/pipecat/pipeline_builder.py index d4562f5..40d6619 100644 --- a/api/services/pipecat/pipeline_builder.py +++ b/api/services/pipecat/pipeline_builder.py @@ -9,7 +9,7 @@ from api.constants import ( from api.services.pipecat.audio_config import AudioConfig from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.audio.audio_buffer_processor import AudioBuffer from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer from pipecat.processors.transcript_processor import TranscriptProcessor @@ -39,7 +39,7 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback() ) - context = OpenAILLMContext() + context = LLMContext() return audio_buffer, audio_synchronizer, transcript, context @@ -58,7 +58,6 @@ def build_pipeline( stt_mute_filter, pipeline_metrics_aggregator, user_idle_disconnect, - engine_pre_aggregator_processor=None, ): """Build the main pipeline with all components""" # Register processors with synchronizer for merged audio @@ -69,16 +68,12 @@ def build_pipeline( processors = [ transport.input(), # Transport user input audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame) - stt_mute_filter, stt, # STT can now have audio_passthrough=False + stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted user_idle_disconnect, transcript.user(), ] - # Insert engine pre-aggregator processor if provided (before user aggregator) - if engine_pre_aggregator_processor: - processors.append(engine_pre_aggregator_processor) - processors.extend( [ user_context_aggregator, diff --git a/api/services/pipecat/pipeline_engine_callbacks_processor.py b/api/services/pipecat/pipeline_engine_callbacks_processor.py index 89aca49..e2a95d1 100644 --- a/api/services/pipecat/pipeline_engine_callbacks_processor.py +++ b/api/services/pipecat/pipeline_engine_callbacks_processor.py @@ -7,7 +7,6 @@ from pipecat.frames.frames import ( Frame, HeartbeatFrame, LLMFullResponseStartFrame, - LLMGeneratedTextFrame, LLMTextFrame, StartFrame, TTSSpeakFrame, @@ -26,7 +25,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor): self, max_call_duration_seconds: int = 300, max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None, - llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None, generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None, llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None, ): @@ -34,7 +32,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor): self._start_time = None self._max_call_duration_seconds = max_call_duration_seconds self._max_duration_end_task_callback = max_duration_end_task_callback - self._llm_generated_text_callback = llm_generated_text_callback self._generation_started_callback = generation_started_callback self._llm_text_frame_callback = llm_text_frame_callback self._end_task_frame_pushed = False @@ -46,8 +43,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor): await self._start(frame) elif isinstance(frame, HeartbeatFrame): await self._check_call_duration() - elif isinstance(frame, LLMGeneratedTextFrame): - await self._generated_text_frame(frame) elif isinstance(frame, LLMFullResponseStartFrame): await self._generation_started() elif ( @@ -74,11 +69,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor): "Max call duration exceeded. Skipping EndTaskFrame since already sent" ) - async def _generated_text_frame(self, _: LLMGeneratedTextFrame): - """Handle LLMGeneratedTextFrame.""" - if self._llm_generated_text_callback is not None: - await self._llm_generated_text_callback() - async def _generation_started(self): if self._generation_started_callback: await self._generation_started_callback() diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 6a92969..bf41477 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -7,9 +7,6 @@ from api.db import db_client from api.db.models import WorkflowModel from api.enums import WorkflowRunMode from api.services.pipecat.audio_config import AudioConfig, create_audio_config -from api.services.pipecat.engine_pre_aggregator_processor import ( - EnginePreAggregatorProcessor, -) from api.services.pipecat.event_handlers import ( register_audio_data_handler, register_task_event_handler, @@ -43,6 +40,9 @@ from api.services.workflow.pipecat_engine import PipecatEngine from api.services.workflow.workflow import WorkflowGraph from pipecat.pipeline.runner import PipelineRunner from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, +) from pipecat.processors.filters.stt_mute_filter import ( STTMuteConfig, STTMuteFilter, @@ -119,7 +119,7 @@ async def run_pipeline_vonage( user_id: int, ): """Run pipeline for Vonage WebSocket connections. - + Vonage uses raw PCM audio over WebSocket instead of base64-encoded μ-law. The audio is transmitted as binary frames at 16kHz by default. """ @@ -137,7 +137,9 @@ async def run_pipeline_vonage( if "vad_configuration" in workflow.workflow_configurations: vad_config = workflow.workflow_configurations["vad_configuration"] if "ambient_noise_configuration" in workflow.workflow_configurations: - ambient_noise_config = workflow.workflow_configurations["ambient_noise_configuration"] + ambient_noise_config = workflow.workflow_configurations[ + "ambient_noise_configuration" + ] try: # Setup audio config for Vonage using the centralized config @@ -355,21 +357,14 @@ async def _run_pipeline( expect_stripped_words=True, correct_aggregation_callback=engine.create_aggregation_correction_callback(), ) - context_aggregator = llm.create_context_aggregator( + context_aggregator = LLMContextAggregatorPair( context, assistant_params=assistant_params ) - # Create engine pre-aggregator processor for speaking events - engine_pre_aggregator_processor = EnginePreAggregatorProcessor( - user_started_speaking_callback=engine.create_user_started_speaking_callback(), - user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(), - ) - # Create usage metrics aggregator with engine's callback pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor( max_call_duration_seconds=max_call_duration_seconds, max_duration_end_task_callback=engine.create_max_duration_callback(), - llm_generated_text_callback=engine.create_llm_generated_text_callback(), generation_started_callback=engine.create_generation_started_callback(), llm_text_frame_callback=engine.handle_llm_text_frame, # Note: speaking event callbacks are now handled by pre-aggregator processor @@ -396,11 +391,6 @@ async def _run_pipeline( user_context_aggregator = context_aggregator.user() assistant_context_aggregator = context_aggregator.assistant() - @assistant_context_aggregator.event_handler("on_push_aggregation") - async def on_assistant_aggregator_push_context(_aggregator): - logger.debug("Assistant aggregator push context – flushing pending transitions") - await engine.flush_pending_transitions(source="context_push") - # Build the pipeline with the STT mute filter and context controller pipeline = build_pipeline( transport, @@ -416,7 +406,6 @@ async def _run_pipeline( stt_mute_filter, pipeline_metrics_aggregator, user_idle_disconnect, - engine_pre_aggregator_processor=engine_pre_aggregator_processor, ) # Create pipeline task with audio configuration diff --git a/api/services/pipecat/transport_setup.py b/api/services/pipecat/transport_setup.py index 480cf67..a674ce4 100644 --- a/api/services/pipecat/transport_setup.py +++ b/api/services/pipecat/transport_setup.py @@ -165,14 +165,15 @@ async def create_vonage_transport( # Use the factory to load config from database from api.services.telephony.factory import load_telephony_config + config = await load_telephony_config(organization_id) - + if config.get("provider") != "vonage": raise ValueError(f"Expected Vonage provider, got {config.get('provider')}") application_id = config.get("application_id") private_key = config.get("private_key") - + if not application_id or not private_key: raise ValueError( f"Incomplete Vonage configuration for organization {organization_id}" @@ -186,8 +187,8 @@ async def create_vonage_transport( private_key=private_key, params=VonageFrameSerializer.InputParams( vonage_sample_rate=audio_config.transport_in_sample_rate, - sample_rate=audio_config.pipeline_sample_rate - ) + sample_rate=audio_config.pipeline_sample_rate, + ), ) # Important: Vonage uses binary WebSocket mode, not text diff --git a/api/services/telephony/base.py b/api/services/telephony/base.py index 598081e..8668872 100644 --- a/api/services/telephony/base.py +++ b/api/services/telephony/base.py @@ -3,6 +3,7 @@ Base telephony provider interface for abstracting telephony services. This allows easy switching between different providers (Twilio, Vonage, etc.) while keeping business logic decoupled from specific implementations. """ + from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -14,10 +15,15 @@ if TYPE_CHECKING: @dataclass class CallInitiationResult: """Standardized response from initiate_call across all providers.""" - call_id: str # Provider's call identifier (SID for Twilio, UUID for Vonage) - status: str # Initial status (e.g., "queued", "initiated", "started") - provider_metadata: Dict[str, Any] = field(default_factory=dict) # Data that needs to be persisted - raw_response: Dict[str, Any] = field(default_factory=dict) # Full provider response for debugging + + call_id: str # Provider's call identifier (SID for Twilio, UUID for Vonage) + status: str # Initial status (e.g., "queued", "initiated", "started") + provider_metadata: Dict[str, Any] = field( + default_factory=dict + ) # Data that needs to be persisted + raw_response: Dict[str, Any] = field( + default_factory=dict + ) # Full provider response for debugging class TelephonyProvider(ABC): @@ -25,6 +31,7 @@ class TelephonyProvider(ABC): Abstract base class for telephony providers. All telephony providers must implement these core methods. """ + PROVIDER_NAME = None WEBHOOK_ENDPOINT = None @@ -38,13 +45,13 @@ class TelephonyProvider(ABC): ) -> CallInitiationResult: """ Initiate an outbound call. - + Args: to_number: The destination phone number webhook_url: The URL to receive call events workflow_run_id: Optional workflow run ID for tracking **kwargs: Provider-specific additional parameters - + Returns: CallInitiationResult with standardized call details """ @@ -54,10 +61,10 @@ class TelephonyProvider(ABC): async def get_call_status(self, call_id: str) -> Dict[str, Any]: """ Get the current status of a call. - + Args: call_id: The provider-specific call identifier - + Returns: Dict containing call status information """ @@ -67,7 +74,7 @@ class TelephonyProvider(ABC): async def get_available_phone_numbers(self) -> List[str]: """ Get list of available phone numbers for this provider. - + Returns: List of phone numbers that can be used for outbound calls """ @@ -77,7 +84,7 @@ class TelephonyProvider(ABC): def validate_config(self) -> bool: """ Validate that the provider is properly configured. - + Returns: True if configuration is valid, False otherwise """ @@ -89,12 +96,12 @@ class TelephonyProvider(ABC): ) -> bool: """ Verify webhook signature for security. - + Args: url: The webhook URL params: The webhook parameters signature: The signature to verify - + Returns: True if signature is valid, False otherwise """ @@ -106,12 +113,12 @@ class TelephonyProvider(ABC): ) -> str: """ Generate the initial webhook response for starting a call session. - + Args: workflow_id: The workflow ID user_id: The user ID workflow_run_id: The workflow run ID - + Returns: Provider-specific response (e.g., TwiML for Twilio) """ @@ -121,10 +128,10 @@ class TelephonyProvider(ABC): async def get_call_cost(self, call_id: str) -> Dict[str, Any]: """ Get cost information for a completed call. - + Args: call_id: Provider-specific call identifier (SID for Twilio, UUID for Vonage) - + Returns: Dict containing: - cost_usd: The cost in USD as float @@ -138,10 +145,10 @@ class TelephonyProvider(ABC): def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Parse provider-specific status callback data into generic format. - + Args: data: Raw callback data from the provider - + Returns: Dict with standardized fields: - call_id: Provider's call identifier @@ -163,14 +170,14 @@ class TelephonyProvider(ABC): ) -> None: """ Handle provider-specific WebSocket connection for real-time call audio. - + This method encapsulates all provider-specific WebSocket handshake and message routing logic, keeping the main websocket endpoint clean. - + Args: websocket: The WebSocket connection workflow_id: The workflow ID user_id: The user ID workflow_run_id: The workflow run ID """ - pass \ No newline at end of file + pass diff --git a/api/services/telephony/factory.py b/api/services/telephony/factory.py index 4762572..f89f8b5 100644 --- a/api/services/telephony/factory.py +++ b/api/services/telephony/factory.py @@ -3,8 +3,8 @@ Factory for creating telephony providers. Handles configuration loading from environment (OSS) or database (SaaS). The providers themselves don't know or care where config comes from. """ -import os -from typing import Any, Dict, Optional + +from typing import Any, Dict from loguru import logger @@ -18,36 +18,36 @@ from api.services.telephony.providers.vonage_provider import VonageProvider async def load_telephony_config(organization_id: int) -> Dict[str, Any]: """ Load telephony configuration from database. - + Args: organization_id: Organization ID for database config - + Returns: Configuration dictionary with provider type and credentials - + Raises: ValueError: If no configuration found for the organization """ if not organization_id: raise ValueError("Organization ID is required to load telephony configuration") - + logger.debug(f"Loading telephony config from database for org {organization_id}") - + config = await db_client.get_configuration( organization_id, OrganizationConfigurationKey.TELEPHONY_CONFIGURATION.value, ) - + if config and config.value: # Simple single-provider format provider = config.value.get("provider", "twilio") - + if provider == "twilio": return { "provider": "twilio", "account_sid": config.value.get("account_sid"), "auth_token": config.value.get("auth_token"), - "from_numbers": config.value.get("from_numbers", []) + "from_numbers": config.value.get("from_numbers", []), } elif provider == "vonage": return { @@ -56,41 +56,41 @@ async def load_telephony_config(organization_id: int) -> Dict[str, Any]: "private_key": config.value.get("private_key"), "api_key": config.value.get("api_key"), "api_secret": config.value.get("api_secret"), - "from_numbers": config.value.get("from_numbers", []) + "from_numbers": config.value.get("from_numbers", []), } else: raise ValueError(f"Unknown provider in config: {provider}") - - raise ValueError(f"No telephony configuration found for organization {organization_id}") + + raise ValueError( + f"No telephony configuration found for organization {organization_id}" + ) -async def get_telephony_provider( - organization_id: int -) -> TelephonyProvider: +async def get_telephony_provider(organization_id: int) -> TelephonyProvider: """ Factory function to create telephony providers. - + Args: organization_id: Organization ID (required) - + Returns: Configured telephony provider instance - + Raises: ValueError: If provider type is unknown or configuration is invalid """ # Load configuration config = await load_telephony_config(organization_id) - + provider_type = config.get("provider", "twilio") logger.info(f"Creating {provider_type} telephony provider") - + # Create provider instance with configuration if provider_type == "twilio": return TwilioProvider(config) - + elif provider_type == "vonage": return VonageProvider(config) - + else: raise ValueError(f"Unknown telephony provider: {provider_type}") diff --git a/api/services/telephony/providers/__init__.py b/api/services/telephony/providers/__init__.py index 5c8985e..16d28ea 100644 --- a/api/services/telephony/providers/__init__.py +++ b/api/services/telephony/providers/__init__.py @@ -1 +1 @@ -# Telephony provider implementations \ No newline at end of file +# Telephony provider implementations diff --git a/api/services/telephony/providers/twilio_provider.py b/api/services/telephony/providers/twilio_provider.py index cb901e2..5e42ecf 100644 --- a/api/services/telephony/providers/twilio_provider.py +++ b/api/services/telephony/providers/twilio_provider.py @@ -1,6 +1,7 @@ """ Twilio implementation of the TelephonyProvider interface. """ + import json import random from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -9,9 +10,9 @@ import aiohttp from loguru import logger from twilio.request_validator import RequestValidator +from api.enums import WorkflowRunMode from api.services.telephony.base import CallInitiationResult, TelephonyProvider from api.utils.tunnel import TunnelURLProvider -from api.enums import WorkflowRunMode if TYPE_CHECKING: from fastapi import WebSocket @@ -22,14 +23,14 @@ class TwilioProvider(TelephonyProvider): Twilio implementation of TelephonyProvider. Accepts configuration and works the same regardless of OSS/SaaS mode. """ - + PROVIDER_NAME = WorkflowRunMode.TWILIO.value WEBHOOK_ENDPOINT = "twiml" def __init__(self, config: Dict[str, Any]): """ Initialize TwilioProvider with configuration. - + Args: config: Dictionary containing: - account_sid: Twilio Account SID @@ -39,11 +40,11 @@ class TwilioProvider(TelephonyProvider): self.account_sid = config.get("account_sid") self.auth_token = config.get("auth_token") self.from_numbers = config.get("from_numbers", []) - + # Handle both single number (string) and multiple numbers (list) if isinstance(self.from_numbers, str): self.from_numbers = [self.from_numbers] - + self.base_url = f"https://api.twilio.com/2010-04-01/Accounts/{self.account_sid}" async def initiate_call( @@ -58,32 +59,35 @@ class TwilioProvider(TelephonyProvider): """ if not self.validate_config(): raise ValueError("Twilio provider not properly configured") - + endpoint = f"{self.base_url}/Calls.json" - + # Select a random phone number from_number = random.choice(self.from_numbers) logger.info(f"Selected phone number {from_number} for outbound call") - + # Prepare call data - data = { - "To": to_number, - "From": from_number, - "Url": webhook_url - } - + data = {"To": to_number, "From": from_number, "Url": webhook_url} + # Add status callback if workflow_run_id provided if workflow_run_id: backend_endpoint = await TunnelURLProvider.get_tunnel_url() callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" - data.update({ - "StatusCallback": callback_url, - "StatusCallbackEvent": ["initiated", "ringing", "answered", "completed"], - "StatusCallbackMethod": "POST" - }) - + data.update( + { + "StatusCallback": callback_url, + "StatusCallbackEvent": [ + "initiated", + "ringing", + "answered", + "completed", + ], + "StatusCallbackMethod": "POST", + } + ) + data.update(kwargs) - + # Make the API request async with aiohttp.ClientSession() as session: auth = aiohttp.BasicAuth(self.account_sid, self.auth_token) @@ -91,14 +95,14 @@ class TwilioProvider(TelephonyProvider): if response.status != 201: error_data = await response.json() raise Exception(f"Failed to initiate call: {error_data}") - + response_data = await response.json() - + return CallInitiationResult( call_id=response_data["sid"], status=response_data.get("status", "queued"), provider_metadata={}, # Twilio doesn't need to persist extra data - raw_response=response_data + raw_response=response_data, ) async def get_call_status(self, call_id: str) -> Dict[str, Any]: @@ -107,16 +111,16 @@ class TwilioProvider(TelephonyProvider): """ if not self.validate_config(): raise ValueError("Twilio provider not properly configured") - + endpoint = f"{self.base_url}/Calls/{call_id}.json" - + async with aiohttp.ClientSession() as session: auth = aiohttp.BasicAuth(self.account_sid, self.auth_token) async with session.get(endpoint, auth=auth) as response: if response.status != 200: error_data = await response.json() raise Exception(f"Failed to get call status: {error_data}") - + return await response.json() async def get_available_phone_numbers(self) -> List[str]: @@ -129,11 +133,7 @@ class TwilioProvider(TelephonyProvider): """ Validate Twilio configuration. """ - return bool( - self.account_sid and - self.auth_token and - self.from_numbers - ) + return bool(self.account_sid and self.auth_token and self.from_numbers) async def verify_webhook_signature( self, url: str, params: Dict[str, Any], signature: str @@ -144,7 +144,7 @@ class TwilioProvider(TelephonyProvider): if not self.auth_token: logger.error("No auth token available for webhook signature verification") return False - + validator = RequestValidator(self.auth_token) return validator.validate(url, params, signature) @@ -155,7 +155,7 @@ class TwilioProvider(TelephonyProvider): Generate TwiML response for starting a call session. """ backend_endpoint = await TunnelURLProvider.get_tunnel_url() - + twiml_content = f""" @@ -168,15 +168,15 @@ class TwilioProvider(TelephonyProvider): async def get_call_cost(self, call_id: str) -> Dict[str, Any]: """ Get cost information for a completed Twilio call. - + Args: call_id: The Twilio Call SID - + Returns: Dict containing cost information """ endpoint = f"{self.base_url}/Calls/{call_id}.json" - + try: async with aiohttp.ClientSession() as session: auth = aiohttp.BasicAuth(self.account_sid, self.auth_token) @@ -188,34 +188,29 @@ class TwilioProvider(TelephonyProvider): "cost_usd": 0.0, "duration": 0, "status": "error", - "error": str(error_data) + "error": str(error_data), } - + call_data = await response.json() - + # Twilio returns price as a negative string (e.g., "-0.0085") price_str = call_data.get("price", "0") cost_usd = abs(float(price_str)) if price_str else 0.0 - + # Duration is in seconds as a string duration = int(call_data.get("duration", "0")) - + return { "cost_usd": cost_usd, "duration": duration, "status": call_data.get("status", "unknown"), "price_unit": call_data.get("price_unit", "USD"), - "raw_response": call_data + "raw_response": call_data, } - + except Exception as e: logger.error(f"Exception fetching Twilio call cost: {e}") - return { - "cost_usd": 0.0, - "duration": 0, - "status": "error", - "error": str(e) - } + return {"cost_usd": 0.0, "duration": 0, "status": "error", "error": str(e)} def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -228,7 +223,7 @@ class TwilioProvider(TelephonyProvider): "to_number": data.get("To"), "direction": data.get("Direction"), "duration": data.get("CallDuration") or data.get("Duration"), - "extra": data # Include all original data + "extra": data, # Include all original data } async def handle_websocket( @@ -240,36 +235,38 @@ class TwilioProvider(TelephonyProvider): ) -> None: """ Handle Twilio-specific WebSocket connection. - + Twilio sends: 1. "connected" event first 2. "start" event with streamSid and callSid 3. Then audio messages """ from api.services.pipecat.run_pipeline import run_pipeline_twilio - + try: # Wait for "connected" event first_msg = await websocket.receive_text() msg = json.loads(first_msg) - + if msg.get("event") != "connected": logger.error(f"Expected 'connected' event, got: {msg.get('event')}") await websocket.close(code=4400, reason="Expected connected event") return - - logger.debug(f"Twilio WebSocket connected for workflow_run {workflow_run_id}") - + + logger.debug( + f"Twilio WebSocket connected for workflow_run {workflow_run_id}" + ) + # Wait for "start" event with stream details start_msg = await websocket.receive_text() logger.debug(f"Received start message: {start_msg}") - + start_msg = json.loads(start_msg) if start_msg.get("event") != "start": logger.error("Expected 'start' event second") await websocket.close(code=4400, reason="Expected start event") return - + # Extract Twilio-specific identifiers try: stream_sid = start_msg["start"]["streamSid"] @@ -278,12 +275,12 @@ class TwilioProvider(TelephonyProvider): logger.error("Missing streamSid or callSid in start message") await websocket.close(code=4400, reason="Missing stream identifiers") return - + # Run the Twilio pipeline await run_pipeline_twilio( websocket, stream_sid, call_sid, workflow_id, workflow_run_id, user_id ) - + except Exception as e: logger.error(f"Error in Twilio WebSocket handler: {e}") - raise \ No newline at end of file + raise diff --git a/api/services/telephony/providers/vonage_provider.py b/api/services/telephony/providers/vonage_provider.py index 35e1e2d..a69e0a0 100644 --- a/api/services/telephony/providers/vonage_provider.py +++ b/api/services/telephony/providers/vonage_provider.py @@ -1,6 +1,7 @@ """ Vonage (Nexmo) implementation of the TelephonyProvider interface. """ + import json import random import time @@ -10,9 +11,9 @@ import aiohttp import jwt from loguru import logger +from api.enums import WorkflowRunMode from api.services.telephony.base import CallInitiationResult, TelephonyProvider from api.utils.tunnel import TunnelURLProvider -from api.enums import WorkflowRunMode if TYPE_CHECKING: from fastapi import WebSocket @@ -23,14 +24,14 @@ class VonageProvider(TelephonyProvider): Vonage implementation of TelephonyProvider. Uses JWT authentication and NCCO for call control. """ - + PROVIDER_NAME = WorkflowRunMode.VONAGE.value WEBHOOK_ENDPOINT = "ncco" - + def __init__(self, config: Dict[str, Any]): """ Initialize VonageProvider with configuration. - + Args: config: Dictionary containing: - api_key: Vonage API Key @@ -44,25 +45,27 @@ class VonageProvider(TelephonyProvider): self.application_id = config.get("application_id") self.private_key = config.get("private_key") self.from_numbers = config.get("from_numbers", []) - + # Handle both single number (string) and multiple numbers (list) if isinstance(self.from_numbers, str): self.from_numbers = [self.from_numbers] - + self.base_url = "https://api.nexmo.com" def _generate_jwt(self) -> str: """Generate JWT token for Vonage API authentication.""" if not self.application_id or not self.private_key: - raise ValueError("Application ID and private key required for JWT generation") - + raise ValueError( + "Application ID and private key required for JWT generation" + ) + claims = { "application_id": self.application_id, "iat": int(time.time()), "exp": int(time.time()) + 3600, - "jti": str(time.time()) + "jti": str(time.time()), } - + return jwt.encode(claims, self.private_key, algorithm="RS256") async def initiate_call( @@ -77,68 +80,57 @@ class VonageProvider(TelephonyProvider): """ if not self.validate_config(): raise ValueError("Vonage provider not properly configured") - + endpoint = f"{self.base_url}/v1/calls" - + # Select a random phone number from_number = random.choice(self.from_numbers) # Remove '+' prefix for Vonage from_number = from_number.replace("+", "") to_number = to_number.replace("+", "") - + logger.info(f"Selected phone number {from_number} for outbound call") - + # Prepare call data data = { - "to": [{ - "type": "phone", - "number": to_number - }], - "from": { - "type": "phone", - "number": from_number - }, + "to": [{"type": "phone", "number": to_number}], + "from": {"type": "phone", "number": from_number}, "answer_url": [webhook_url], - "answer_method": "GET" + "answer_method": "GET", } - + # Add event webhook if workflow_run_id provided if workflow_run_id: backend_endpoint = await TunnelURLProvider.get_tunnel_url() event_url = f"https://{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}" - data.update({ - "event_url": [event_url], - "event_method": "POST" - }) - + data.update({"event_url": [event_url], "event_method": "POST"}) + data.update(kwargs) - + # Generate JWT token token = self._generate_jwt() headers = { "Authorization": f"Bearer {token}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + # Make the API request async with aiohttp.ClientSession() as session: - async with session.post( - endpoint, - json=data, - headers=headers - ) as response: + async with session.post(endpoint, json=data, headers=headers) as response: response_data = await response.json() - + if response.status != 201: raise Exception(f"Failed to initiate call: {response_data}") - + return CallInitiationResult( call_id=response_data["uuid"], status=response_data.get("status", "started"), provider_metadata={ - "call_uuid": response_data["uuid"] # Vonage needs UUID persisted for WebSocket + "call_uuid": response_data[ + "uuid" + ] # Vonage needs UUID persisted for WebSocket }, - raw_response=response_data + raw_response=response_data, ) async def get_call_status(self, call_id: str) -> Dict[str, Any]: @@ -147,21 +139,19 @@ class VonageProvider(TelephonyProvider): """ if not self.validate_config(): raise ValueError("Vonage provider not properly configured") - + endpoint = f"{self.base_url}/v1/calls/{call_id}" - + # Generate JWT token token = self._generate_jwt() - headers = { - "Authorization": f"Bearer {token}" - } - + headers = {"Authorization": f"Bearer {token}"} + async with aiohttp.ClientSession() as session: async with session.get(endpoint, headers=headers) as response: if response.status != 200: error_data = await response.json() raise Exception(f"Failed to get call status: {error_data}") - + return await response.json() async def get_available_phone_numbers(self) -> List[str]: @@ -174,11 +164,7 @@ class VonageProvider(TelephonyProvider): """ Validate Vonage configuration. """ - return bool( - self.application_id and - self.private_key and - self.from_numbers - ) + return bool(self.application_id and self.private_key and self.from_numbers) async def verify_webhook_signature( self, url: str, params: Dict[str, Any], signature: str @@ -190,14 +176,14 @@ class VonageProvider(TelephonyProvider): if not self.api_secret: logger.error("No API secret available for webhook signature verification") return False - + try: # Vonage sends JWT in Authorization header. Verify the JWT signature decoded = jwt.decode( - signature, - self.api_secret, + signature, + self.api_secret, algorithms=["HS256"], - options={"verify_signature": True} + options={"verify_signature": True}, ) return True except jwt.InvalidTokenError: @@ -211,43 +197,42 @@ class VonageProvider(TelephonyProvider): NCCO (Nexmo Call Control Objects) is JSON-based, unlike TwiML which is XML. """ backend_endpoint = await TunnelURLProvider.get_tunnel_url() - + # NCCO for WebSocket connection ncco = [ { "action": "connect", - "endpoint": [{ - "type": "websocket", - "uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}", - "content-type": "audio/l16;rate=16000", # 16kHz Linear PCM - "headers": {} - }] + "endpoint": [ + { + "type": "websocket", + "uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}", + "content-type": "audio/l16;rate=16000", # 16kHz Linear PCM + "headers": {}, + } + ], } ] - + return json.dumps(ncco) def _get_auth_headers(self) -> Dict[str, str]: """Generate authorization headers for Vonage API.""" token = self._generate_jwt() - return { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json" - } + return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} async def get_call_cost(self, call_id: str) -> Dict[str, Any]: """ Get cost information for a completed Vonage call. - + Args: call_id: The Vonage Call UUID - + Returns: Dict containing cost information """ headers = self._get_auth_headers() endpoint = f"https://api.nexmo.com/v1/calls/{call_id}" - + try: async with aiohttp.ClientSession() as session: async with session.get(endpoint, headers=headers) as response: @@ -258,39 +243,34 @@ class VonageProvider(TelephonyProvider): "cost_usd": 0.0, "duration": 0, "status": "error", - "error": str(error_data) + "error": str(error_data), } - + call_data = await response.json() - + # Vonage returns price and rate # Price is the total cost, rate is the per-minute rate price = float(call_data.get("price", 0)) cost_usd = price # Vonage returns positive values - + # Duration is in seconds duration = int(call_data.get("duration", 0)) - + # Get the call status status = call_data.get("status", "unknown") - + return { "cost_usd": cost_usd, "duration": duration, "status": status, "price_unit": "USD", # Vonage uses USD by default "rate": call_data.get("rate", 0), # Per-minute rate - "raw_response": call_data + "raw_response": call_data, } - + except Exception as e: logger.error(f"Exception fetching Vonage call cost: {e}") - return { - "cost_usd": 0.0, - "duration": 0, - "status": "error", - "error": str(e) - } + return {"cost_usd": 0.0, "duration": 0, "status": "error", "error": str(e)} def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -300,14 +280,14 @@ class VonageProvider(TelephonyProvider): status_map = { "started": "initiated", "ringing": "ringing", - "answered": "answered", + "answered": "answered", "complete": "completed", "failed": "failed", "busy": "busy", "timeout": "no-answer", - "rejected": "busy" + "rejected": "busy", } - + return { "call_id": data.get("uuid", ""), "status": status_map.get(data.get("status", ""), data.get("status", "")), @@ -315,7 +295,7 @@ class VonageProvider(TelephonyProvider): "to_number": data.get("to"), "direction": data.get("direction"), "duration": data.get("duration"), - "extra": data # Include all original data + "extra": data, # Include all original data } async def handle_websocket( @@ -327,14 +307,14 @@ class VonageProvider(TelephonyProvider): ) -> None: """ Handle Vonage-specific WebSocket connection. - + Vonage can send: 1. JSON metadata first (websocket:connected event) 2. Or directly start with binary audio """ from api.db import db_client from api.services.pipecat.run_pipeline import run_pipeline_vonage - + try: # Get workflow run to extract call UUID workflow_run = await db_client.get_workflow_run(workflow_run_id) @@ -342,38 +322,48 @@ class VonageProvider(TelephonyProvider): logger.error(f"Workflow run {workflow_run_id} not found") await websocket.close(code=4404, reason="Workflow run not found") return - + # Get workflow for organization info workflow = await db_client.get_workflow(workflow_id, user_id) if not workflow: logger.error(f"Workflow {workflow_id} not found") await websocket.close(code=4404, reason="Workflow not found") return - + # Extract call UUID from workflow run context - call_uuid = workflow_run.gathered_context.get("call_uuid") if workflow_run.gathered_context else None - + call_uuid = ( + workflow_run.gathered_context.get("call_uuid") + if workflow_run.gathered_context + else None + ) + if not call_uuid: - logger.error(f"No call UUID found for Vonage connection in workflow run {workflow_run_id}") + logger.error( + f"No call UUID found for Vonage connection in workflow run {workflow_run_id}" + ) await websocket.close(code=4400, reason="Missing call UUID") return - - logger.info(f"Vonage WebSocket connected for workflow_run {workflow_run_id}, call_uuid: {call_uuid}") - + + logger.info( + f"Vonage WebSocket connected for workflow_run {workflow_run_id}, call_uuid: {call_uuid}" + ) + # Peek at first message to see if it's metadata or audio first_msg = await websocket.receive() - + if "text" in first_msg: # JSON metadata - check if it's the connection event msg = json.loads(first_msg["text"]) if msg.get("event") == "websocket:connected": - logger.debug(f"Received Vonage connection confirmation for {workflow_run_id}") + logger.debug( + f"Received Vonage connection confirmation for {workflow_run_id}" + ) # Continue to pipeline regardless of message type elif "bytes" in first_msg: # Binary audio - Vonage started with audio immediately logger.debug(f"Vonage started with binary audio for {workflow_run_id}") # The pipeline will handle this first audio chunk - + # Run the Vonage pipeline await run_pipeline_vonage( websocket, @@ -382,9 +372,9 @@ class VonageProvider(TelephonyProvider): workflow.organization_id, workflow_id, workflow_run_id, - user_id + user_id, ) - + except Exception as e: logger.error(f"Error in Vonage WebSocket handler: {e}") - raise \ No newline at end of file + raise diff --git a/api/services/telephony/stasis_rtp_transport.py b/api/services/telephony/stasis_rtp_transport.py index cd3716d..20a6b64 100644 --- a/api/services/telephony/stasis_rtp_transport.py +++ b/api/services/telephony/stasis_rtp_transport.py @@ -22,9 +22,7 @@ from pipecat.frames.frames import ( ) from pipecat.serializers.base_serializer import FrameSerializer from pipecat.transports.base_input import BaseInputTransport -from pipecat.transports.base_output import ( - BaseOutputTransport -) +from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index 7ea4bea..08e840a 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -14,14 +14,14 @@ from pipecat.frames.frames import ( CancelFrame, EndFrame, FunctionCallResultProperties, + LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, TTSSpeakFrame, ) from pipecat.pipeline.task import PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame +from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.services.llm_service import FunctionCallParams -from pipecat.services.openai.llm import OpenAILLMContext from pipecat.transports.base_transport import BaseTransport from pipecat.utils.enums import EndTaskReason @@ -63,7 +63,7 @@ class PipecatEngine: *, task: Optional[PipelineTask] = None, llm: Optional["LLMService"] = None, - context: Optional[OpenAILLMContext] = None, + context: Optional[LLMContext] = None, tts: Optional[Any] = None, transport: Optional[BaseTransport] = None, workflow: WorkflowGraph, @@ -82,7 +82,6 @@ class PipecatEngine: self._workflow_run_id = workflow_run_id self._initialized = False self._client_disconnected = False - self._pending_function_calls = 0 self._current_node: Optional[Node] = None self._gathered_context: dict = {} self._user_response_timeout_task: Optional[asyncio.Task] = None @@ -102,29 +101,9 @@ class PipecatEngine: self._voicemail_detector = None self._voicemail_detection_task: Optional[asyncio.Task] = None - # This transition is generated by the llm as part of tool call. This can - # also be accompanied with some content which can be played using TTS. If the - # bot is interrupted, we would cancel this transition (we do cancel this currently when - # the next generation starts in handle_generation_started callback handler.) - self._pending_generated_transition_after_context_push: Optional[ - Callable[[], Awaitable[None]] - ] = None - - # This is the transtion which is typically programmatic transition, and not goes as - # tool call to LLM. This is not interrupted by the user and is done on context push - self._pending_control_transition_after_context_push: Optional[ - Callable[[], Awaitable[None]] - ] = None - - # Flag to determine if the current llm generation has a text completion - self._defer_context_push: bool = False - # Lazy loaded built-in function schemas self._builtin_function_schemas: Optional[list[dict]] = None - # Flag to control whether to queue context frame - self._queue_context_frame: bool = True - # Track current LLM reference text for TTS aggregation correction self._current_llm_reference_text: str = "" @@ -211,23 +190,15 @@ class PipecatEngine: async def _create_transition_func(self, name: str, transition_to_node: str): async def transition_func(function_call_params: FunctionCallParams) -> None: - """Inner function that handles the actual tool invocation.""" + """Inner function that handles the node change tool calls""" try: - # Track pending function call - self._pending_function_calls += 1 - logger.debug( - f"Function call pending: {function_call_params.function_name} (total: {self._pending_function_calls})" - ) - # For edge functions, prevent LLM completion until transition (run_llm=False) - # For node functions, allow immediate completion (run_llm=True) async def on_context_updated() -> None: """ - Framework will run this function after the function call result has been updated in the context. + pipecat framework will run this function after the function call result has been updated in the context. This way, when we do set_node from within this function, and go for LLM completion with updated system prompts, the context is updated with function call result. """ - self._pending_function_calls -= 1 # Perform variable extraction before transitioning to new node await self._perform_variable_extraction_if_needed( self._current_node @@ -241,41 +212,14 @@ class PipecatEngine: on_context_updated=on_context_updated, ) - async def _invoke_result_callback(): - """ - Functions are executed immediately when they come from LLM as part of text completion. - But, if the LLM completion also has some text, we would want to not call the function if the user interrupts the speech. - We would also not want the function to be added to context, so that the LLM can call the function again. Hence, we - defer the function invocation until we receive on_context_updated callback, i.e the bot has finished speaking - the text that was generated. - """ - await function_call_params.result_callback( - result, properties=properties - ) - - if self._defer_context_push: - """ - We set the flag to _defer_context_push when we receive text in the current generation from LLM. - This is set in the handle_llm_generated_text callback handler. - """ - logger.debug( - "Deferring transition function result until context push" - ) - # Only one deferred transition should exist at any time. - # Overwrite if one is somehow already set (unexpected). - self._pending_generated_transition_after_context_push = ( - _invoke_result_callback - ) - else: - """ - If there was no text in the current generation, and we only had function call, - lets invoke the result callback, so that framework can call on_context_updated and - we can do switch node. - """ - await _invoke_result_callback() + # Call results callback from the pipecat framework + # so that a new llm generation can be triggred if + # required + await function_call_params.result_callback( + result, properties=properties + ) except Exception as e: logger.error(f"Error in transition function {name}: {str(e)}") - self._pending_function_calls = 0 error_result = {"status": "error", "error": str(e)} await function_call_params.result_callback(error_result) @@ -362,27 +306,6 @@ class PipecatEngine: ] ) - async def _setup_static_start_node_transition(self, node: Node) -> None: - """Set up the deferred transition for static start nodes.""" - if not node.out_edges: - return - - next_node_id = node.out_edges[0].target - - if not node.wait_for_user_response: - # Normal static start node - transition immediately after context push - async def _deferred_static_transition(): - try: - await self.set_node(next_node_id) - except Exception as exc: - logger.error( - f"Error executing deferred static node transition to {next_node_id}: {exc}" - ) - - self._pending_control_transition_after_context_push = ( - _deferred_static_transition - ) - async def _perform_variable_extraction_if_needed( self, previous_node: Optional[Node] ) -> None: @@ -441,17 +364,7 @@ class PipecatEngine: functions, ) = await self._compose_system_message_functions_for_node(node) await self._update_llm_context(system_message, functions) - - # Queue context frame if needed - if self._queue_context_frame: - await self.task.queue_frame(OpenAILLMContextFrame(self.context)) - else: - logger.debug( - f"Not queueing context frame for node: {node.name} as _queue_context_frame is False" - ) - - # Reset _queue_context_frame as default behavior - self._queue_context_frame = True + await self.task.queue_frame(LLMContextFrame(self.context)) async def set_node(self, node_id: str): """ @@ -525,12 +438,7 @@ class PipecatEngine: await asyncio.sleep(delay_duration) if node.is_static: - # Queue TTS for static start node - formatted_prompt = self._format_prompt(node.prompt) - await self._queue_tts_response(formatted_prompt) - - # Set up deferred transition for static start nodes - await self._setup_static_start_node_transition(node) + raise ValueError("Static nodes are not supported!") else: # Start generation for non-static start node await self._setup_llm_context_and_start_generation(node) @@ -538,66 +446,24 @@ class PipecatEngine: async def _handle_end_node(self, node: Node) -> None: """Handle end node execution.""" if node.is_static: - # Queue TTS for static end node - formatted_prompt = self._format_prompt(node.prompt) - await self._queue_tts_response(formatted_prompt) + raise ValueError("Static nodes are not supported!") else: - # Start generation for non-static end node await self._setup_llm_context_and_start_generation(node) # If this end node has extraction enabled, perform extraction immediately if node.extraction_enabled and node.extraction_variables: await self._perform_variable_extraction_if_needed(node) - # TODO: Extract disposition code from extracted variables - # Defer send_end_task_frame using _pending_control_transition_after_context_push - - # Decide the end-task reason dynamically depending on call_disposition. - async def _deferred_end_task(): - # call_disposition is the disposition which is generated from - # llm call based on the conversation so far. - # TODO: Make this more generic based on configuration or llm prompting - disposition = self._gathered_context.get("call_disposition") - if disposition == "XFER": - reason = EndTaskReason.USER_QUALIFIED.value - else: - reason = EndTaskReason.USER_DISQUALIFIED.value - await self.send_end_task_frame(reason) - - self._pending_control_transition_after_context_push = _deferred_end_task + await self.send_end_task_frame(EndTaskReason.USER_QUALIFIED.value) async def _handle_agent_node(self, node: Node) -> None: """Handle agent node execution.""" if node.is_static: - # Queue TTS for static agent node - formatted_prompt = self._format_prompt(node.prompt) - await self._queue_tts_response(formatted_prompt) - - # Set up deferred transition for static agent nodes - await self._setup_agent_node_transition(node) + raise ValueError("Static nodes are not supported!") else: # Set context and functions for non-static agent node await self._setup_llm_context_and_start_generation(node) - async def _setup_agent_node_transition(self, node: Node) -> None: - """Set up the deferred transition for static agent nodes.""" - if not node.out_edges: - return - - next_node_id = node.out_edges[0].target - - async def _deferred_static_transition(): - try: - await self.set_node(next_node_id) - except Exception as exc: - logger.error( - f"Error executing deferred static node transition to {next_node_id}: {exc}" - ) - - self._pending_control_transition_after_context_push = ( - _deferred_static_transition - ) - async def send_end_task_frame( self, reason: str, @@ -640,7 +506,7 @@ class PipecatEngine: # Store the mapped disconnect reason self._gathered_context["call_disposition"] = mapped_disposition - # TODO: Generalise this, currently tailored to Kapil's use case + # TODO: Generalise this self._gathered_context["address"] = ", ".join( [ self._call_context_vars.get("address1", ""), @@ -759,55 +625,6 @@ class PipecatEngine: return system_message, functions - # ------------------------------------------------------------------ - # Pending transition handling - # ------------------------------------------------------------------ - - async def flush_pending_transitions(self, *, source: str = "context_push"): - """Execute and clear any pending transitions. - - Args: - source: Indicates the trigger that caused this flush: - - "context_push": the assistant context aggregator completed a push. - """ - - if source != "context_push": - raise ValueError("Invalid flush source – expected 'context_push'") - - len_pending_functions = 0 - - if self._pending_generated_transition_after_context_push is not None: - len_pending_functions += 1 - if self._pending_control_transition_after_context_push is not None: - len_pending_functions += 1 - - # Nothing to do - if len_pending_functions == 0: - return - - logger.debug( - f"Flushing {len_pending_functions} pending transition(s) after {source.replace('_', ' ')}" - ) - - # Generated transition - if self._pending_generated_transition_after_context_push is not None: - pending_cb = self._pending_generated_transition_after_context_push - self._pending_generated_transition_after_context_push = None - try: - await pending_cb() - except Exception as exc: # pragma: no cover - logger.error(f"Error executing deferred transition: {exc}") - - # Control transition (context push) - if self._pending_control_transition_after_context_push is not None: - logger.debug("Executing control transition after context push") - static_cb = self._pending_control_transition_after_context_push - self._pending_control_transition_after_context_push = None - try: - await static_cb() - except Exception as exc: # pragma: no cover - logger.error(f"Error executing deferred static node transition: {exc}") - def create_should_mute_callback(self) -> Callable[[STTMuteFilter], Awaitable[bool]]: """ This callback is called by STTMuteFilter to determine if the STT should be muted. @@ -828,15 +645,6 @@ class PipecatEngine: """ return engine_callbacks.create_max_duration_callback(self) - def create_llm_generated_text_callback(self): - """ - This callback is called when some text is generated by the LLM. - We use this to defer the result_callback of the node transition functions if - there is set_node called along with some text generated. This way, we will - have the context sent in the next generation from new node. - """ - return engine_callbacks.create_llm_generated_text_callback(self) - def create_generation_started_callback(self): """ This callback is called when a new generation starts. @@ -844,26 +652,12 @@ class PipecatEngine: """ return engine_callbacks.create_generation_started_callback(self) - def create_user_stopped_speaking_callback(self): - """ - This callback is called when the user stops speaking. - We use this to handle transitions when wait_for_user_response is enabled. - """ - return engine_callbacks.create_user_stopped_speaking_callback(self) - - def create_user_started_speaking_callback(self): - """ - This callback is called when the user starts speaking. - We use this to handle wait_for_user_greeting functionality. - """ - return engine_callbacks.create_user_started_speaking_callback(self) - def create_aggregation_correction_callback(self) -> Callable[[str], str]: """Create a callback that corrects corrupted aggregation using reference text.""" return engine_callbacks.create_aggregation_correction_callback(self) - def set_context(self, context: OpenAILLMContext) -> None: - """Set the OpenAI LLM context. + def set_context(self, context: LLMContext) -> None: + """Set the LLM context. This allows setting the context after the engine has been created, which is useful when the context needs to be created after the engine. diff --git a/api/services/workflow/pipecat_engine_callbacks.py b/api/services/workflow/pipecat_engine_callbacks.py index 13f2433..d4ba2a4 100644 --- a/api/services/workflow/pipecat_engine_callbacks.py +++ b/api/services/workflow/pipecat_engine_callbacks.py @@ -14,6 +14,7 @@ import re from typing import TYPE_CHECKING, Awaitable, Callable from loguru import logger + from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -23,9 +24,8 @@ from pipecat.processors.filters.stt_mute_filter import STTMuteFilter from pipecat.utils.enums import EndTaskReason if TYPE_CHECKING: - from pipecat.processors.user_idle_processor import UserIdleProcessor - from api.services.workflow.pipecat_engine import PipecatEngine + from pipecat.processors.user_idle_processor import UserIdleProcessor # --------------------------------------------------------------------------- @@ -114,23 +114,6 @@ def create_max_duration_callback(engine: "PipecatEngine"): return handle_max_duration -# --------------------------------------------------------------------------- -# LLM-generated-text handling -# --------------------------------------------------------------------------- - - -def create_llm_generated_text_callback(engine: "PipecatEngine"): - """Return a callback invoked when the LLM emits text (not only tool calls).""" - - async def handle_llm_generated_text(): # noqa: D401 - logger.debug( - "Generation has text content in current response - deferring context push from set_node" - ) - engine._defer_context_push = True - - return handle_llm_generated_text - - # --------------------------------------------------------------------------- # Generation-started handling # --------------------------------------------------------------------------- @@ -140,96 +123,13 @@ def create_generation_started_callback(engine: "PipecatEngine"): """Return a callback that resets flags at the start of each LLM generation.""" async def handle_generation_started(): # noqa: D401 - logger.debug("LLM generation started - resetting defer flags and tool counters") - engine._defer_context_push = False - engine._pending_function_calls = 0 - engine._pending_generated_transition_after_context_push = None + logger.debug("LLM generation started in callback processor") # Clear reference text from previous generation engine._current_llm_reference_text = "" return handle_generation_started -# --------------------------------------------------------------------------- -# User-stopped-speaking handling -# --------------------------------------------------------------------------- - - -def create_user_stopped_speaking_callback(engine: "PipecatEngine"): - """Return a callback that handles when the user stops speaking. - - According to simplified flow: - - For start nodes with wait_for_user_response=True: - - Cancel timeout task if still active - - Transition to next node with _queue_context_frame=False - """ - - async def handle_user_stopped_speaking(): - # Only handle if current node is a start node with wait_for_user_response - if ( - engine._current_node - and engine._current_node.is_start - and engine._current_node.wait_for_user_response - and engine._current_node.out_edges - ): - # Cancel timeout task if it's still active - if ( - engine._user_response_timeout_task - and not engine._user_response_timeout_task.done() - ): - logger.debug("Cancelling user response timeout - user responded") - engine._user_response_timeout_task.cancel() - engine._user_response_timeout_task = None - - # Transition to next node - next_node_id = engine._current_node.out_edges[0].target - logger.debug( - f"User stopped speaking after wait_for_user_response - transitioning to: {next_node_id}" - ) - - # Set flag to not queue context frame since - # it will be pushed by user context aggregator - # we are just setting the context with next node's - # functions and prompts - engine._queue_context_frame = False - - # Transition to next node - await engine.set_node(next_node_id) - - return handle_user_stopped_speaking - - -# --------------------------------------------------------------------------- -# User-started-speaking handling -# --------------------------------------------------------------------------- - - -def create_user_started_speaking_callback(engine: "PipecatEngine"): - """Return a callback that handles when the user starts speaking. - - According to simplified flow: - - For start nodes with wait_for_user_response=True: - - Cancel the timeout timer if it exists (but don't set to None) - """ - - async def handle_user_started_speaking(): - # Only handle if current node is a start node with wait_for_user_response - if ( - engine._current_node - and engine._current_node.is_start - and engine._current_node.wait_for_user_response - and engine._user_response_timeout_task - and not engine._user_response_timeout_task.done() - ): - logger.debug( - "User started speaking during wait_for_user_response - cancelling timeout timer" - ) - engine._user_response_timeout_task.cancel() - # Don't set to None here - let user_stopped_speaking handle the transition - - return handle_user_started_speaking - - def create_aggregation_correction_callback(engine: "PipecatEngine"): """Create a callback that uses engine's reference text to correct corrupted aggregation.""" diff --git a/api/services/workflow/pipecat_engine_utils.py b/api/services/workflow/pipecat_engine_utils.py index b41c497..5c4300c 100644 --- a/api/services/workflow/pipecat_engine_utils.py +++ b/api/services/workflow/pipecat_engine_utils.py @@ -2,16 +2,10 @@ from __future__ import annotations from typing import Any, Dict, List -from google.genai.types import ( - Content, - Part, -) +from api.utils.template_renderer import render_template from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.services.google.llm import GoogleLLMContext -from pipecat.services.openai.llm import OpenAILLMContext - -from api.utils.template_renderer import render_template +from pipecat.processors.aggregators.llm_context import LLMContext __all__ = [ "get_function_schema", @@ -44,7 +38,7 @@ def get_function_schema( def update_llm_context( - context: OpenAILLMContext, + context: LLMContext, system_message: Dict[str, Any], functions: List[FunctionSchema], ) -> None: @@ -59,21 +53,6 @@ def update_llm_context( # associated with the current LLM service can convert them to the correct # provider-specific representation when required. tools_schema = ToolsSchema(standard_tools=functions) - - if isinstance(context, GoogleLLMContext): - context.system_message = system_message["content"] - - if functions: - # Lets only call set_tools if we have functions, else Gemini will - # throw an exception - context.set_tools(tools_schema) - - if context.messages[-1].role != "user": - # Google expects the last message should end with user message - context.add_message(Content(role="user", parts=[Part(text="...")])) - return - - # In case of OpenAILLMContext, replace the system message with incoming system message previous_interactions = context.messages # Filter out old system messages but keep user/assistant/function content. diff --git a/api/services/workflow/pipecat_engine_variable_extractor.py b/api/services/workflow/pipecat_engine_variable_extractor.py index 798a82f..7b1eed6 100644 --- a/api/services/workflow/pipecat_engine_variable_extractor.py +++ b/api/services/workflow/pipecat_engine_variable_extractor.py @@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Any, List from loguru import logger from openai import AsyncOpenAI from opentelemetry import trace -from pipecat.services.openai.llm import OpenAILLMContext -from pipecat.utils.tracing.service_attributes import add_llm_span_attributes from api.services.pipecat.tracing_config import is_tracing_enabled from api.services.workflow.dto import ExtractionVariableDTO +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.utils.tracing.service_attributes import add_llm_span_attributes if TYPE_CHECKING: from api.services.workflow.pipecat_engine import PipecatEngine @@ -139,7 +139,7 @@ class VariableExtractionManager: f"{conversation_history}" ) - extraction_context = OpenAILLMContext() + extraction_context = LLMContext() extraction_messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, @@ -171,7 +171,7 @@ class VariableExtractionManager: service_name="OpenAILLMService", model=self._model, operation_name="variable_extraction", - messages=json.dumps(extraction_messages), + messages=extraction_messages, output=llm_response, stream=False, parameters={"temperature": 0.0, "response_format": "json_object"}, diff --git a/api/services/workflow/workflow.py b/api/services/workflow/workflow.py index 49a85a5..82fa82d 100644 --- a/api/services/workflow/workflow.py +++ b/api/services/workflow/workflow.py @@ -44,8 +44,6 @@ class Node: self.extraction_prompt = data.extraction_prompt self.extraction_variables = data.extraction_variables self.add_global_prompt = data.add_global_prompt - self.wait_for_user_response = data.wait_for_user_response - self.wait_for_user_response_timeout = data.wait_for_user_response_timeout self.detect_voicemail = data.detect_voicemail self.delayed_start = data.delayed_start self.delayed_start_duration = data.delayed_start_duration diff --git a/api/tasks/run_integrations.py b/api/tasks/run_integrations.py index bf0ec78..febaecd 100644 --- a/api/tasks/run_integrations.py +++ b/api/tasks/run_integrations.py @@ -3,12 +3,12 @@ import os import aiohttp import httpx from loguru import logger -from pipecat.utils.context import set_current_run_id from api.db import db_client from api.db.models import IntegrationModel from api.enums import OrganizationConfigurationKey, WorkflowRunMode from api.utils.template_renderer import render_template +from pipecat.utils.context import set_current_run_id async def run_integrations_post_workflow_run(ctx, workflow_run_id: int): @@ -162,7 +162,7 @@ async def _process_slack_integration( """ logger.info(f"Processing Slack integration {integration.id}") - # TODO: Generalise this, currently tailored to Kapil's use case + # TODO: Generalise this if gathered_context.get("mapped_call_disposition") != "XFER": logger.debug( f"Not sending message on slack since not XFER: {gathered_context.get('mapped_call_disposition')}" diff --git a/api/tasks/workflow_run_cost.py b/api/tasks/workflow_run_cost.py index 4b3dc1b..25400fc 100644 --- a/api/tasks/workflow_run_cost.py +++ b/api/tasks/workflow_run_cost.py @@ -28,20 +28,24 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): # Fetch telephony call cost for both Twilio and Vonage telephony_cost_usd = 0.0 - if workflow_run.mode in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] and workflow_run.cost_info: + if ( + workflow_run.mode + in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] + and workflow_run.cost_info + ): # Get the call ID (provider-agnostic approach with backward compatibility) call_id = workflow_run.cost_info.get("call_id") - + # Fallback to legacy provider-specific fields if needed if not call_id: if workflow_run.mode == WorkflowRunMode.TWILIO.value: call_id = workflow_run.cost_info.get("twilio_call_sid") elif workflow_run.mode == WorkflowRunMode.VONAGE.value: call_id = workflow_run.cost_info.get("vonage_call_uuid") - + # Provider name is derived from workflow run mode provider_name = workflow_run.mode.lower() if workflow_run.mode else "" - + if call_id: try: # Get workflow to access organization_id @@ -55,12 +59,14 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): # Use telephony provider abstraction provider = await get_telephony_provider(workflow.organization_id) call_cost_info = await provider.get_call_cost(call_id) - + if call_cost_info.get("status") != "error": telephony_cost_usd = call_cost_info.get("cost_usd", 0.0) cost_breakdown["telephony_call"] = telephony_cost_usd - cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd # Keep backward compatibility - + cost_breakdown[f"{provider_name}_call"] = ( + telephony_cost_usd # Keep backward compatibility + ) + # Add telephony cost to the total cost_breakdown["total"] = ( float(cost_breakdown["total"]) + telephony_cost_usd @@ -69,8 +75,10 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): f"{provider_name.title()} call cost: ${telephony_cost_usd:.6f} USD for call {call_id}" ) else: - logger.error(f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}") - + logger.error( + f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}" + ) + except Exception as e: logger.error(f"Failed to fetch telephony call cost: {e}") # Don't fail the whole cost calculation if telephony API fails @@ -119,7 +127,9 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): elif "twilio_call_sid" in workflow_run.cost_info: cost_info["twilio_call_sid"] = workflow_run.cost_info["twilio_call_sid"] elif "vonage_call_uuid" in workflow_run.cost_info: - cost_info["vonage_call_uuid"] = workflow_run.cost_info["vonage_call_uuid"] + cost_info["vonage_call_uuid"] = workflow_run.cost_info[ + "vonage_call_uuid" + ] # Update workflow run with cost information await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info) diff --git a/api/tests/test_base_openai_llm_service.py b/api/tests/test_base_openai_llm_service.py deleted file mode 100644 index 6bb5f25..0000000 --- a/api/tests/test_base_openai_llm_service.py +++ /dev/null @@ -1,179 +0,0 @@ -### - This test has some weird loop which keeps on increasing the context size - -# import asyncio -# import json -# import unittest -# from types import SimpleNamespace -# from unittest import mock - -# from loguru import logger - -# from pipecat.frames.frames import ( -# FunctionCallInProgressFrame, -# FunctionCallResultFrame, -# FunctionCallsStartedFrame, -# LLMFullResponseEndFrame, -# LLMFullResponseStartFrame, -# LLMGeneratedTextFrame, -# LLMTextFrame, -# ) -# from pipecat.pipeline.pipeline import Pipeline -# from pipecat.processors.aggregators.openai_llm_context import ( -# OpenAILLMContext, -# OpenAILLMContextFrame, -# ) -# from pipecat.services.llm_service import ( -# FunctionCallParams, -# FunctionCallResultProperties, -# ) -# from pipecat.services.openai.llm import OpenAILLMService -# from pipecat.tests.utils import run_test - - -# class _MockAsyncStream: -# """A minimal async-stream wrapper that mimics ``openai.AsyncStream``.""" - -# def __init__(self, chunks): -# self._chunks = chunks - -# def __aiter__(self): -# self._idx = 0 -# return self - -# async def __anext__(self): -# if self._idx >= len(self._chunks): -# raise StopAsyncIteration -# item = self._chunks[self._idx] -# self._idx += 1 -# await asyncio.sleep(0) # Yield control -# return item - - -# # ------------------------------------------------------------------ -# # Factories for mock chunks -# # ------------------------------------------------------------------ - - -# def _make_tool_call(tool_name: str, args_json: str, *, idx: int = 0): -# function = SimpleNamespace(name=tool_name, arguments=args_json) -# return SimpleNamespace(index=idx, id=f"call-{idx}", function=function) - - -# def _make_chunk(*, content: str | None = None, tool_calls=None, usage=None): -# delta = SimpleNamespace() -# # When we are asked to simulate multiple tool calls in parallel, OpenAI -# # sends *separate* chunks for every tool-call index. To mimic that behaviour -# # in tests we split a list of tool calls (>1) into individual chunks – one -# # for each tool call – while keeping the original single-chunk behaviour -# # when zero or one tool calls are supplied. This enables us to write -# # concise tests such as ``_make_chunk(tool_calls=[call_1, call_2])`` that -# # accurately reflect the streaming protocol. - -# # No special handling needed if there is textual content or 0/1 tool calls. -# if content is not None or tool_calls is None or len(tool_calls) <= 1: -# if content is not None: -# delta.content = content -# # Always set tool_calls so downstream code can safely access it -# delta.tool_calls = tool_calls if tool_calls is not None else None -# return SimpleNamespace(choices=[SimpleNamespace(delta=delta)], usage=usage) - -# # --- Multiple tool calls (len(tool_calls) > 1) --- -# # Create a list of chunks, each containing a single tool call. This is the -# # format produced by the OpenAI client when several tools are invoked in a -# # single assistant response. -# chunks = [] -# for tc in tool_calls: -# delta_tc = SimpleNamespace(tool_calls=[tc]) -# chunks.append(SimpleNamespace(choices=[SimpleNamespace(delta=delta_tc)], usage=usage)) - -# return chunks - - -# class TestBaseOpenAILLMService(unittest.IsolatedAsyncioTestCase): -# async def test_process_context_with_patch(self): -# streamed_text = "Hello from OpenAI!" -# tool_name = "echo" -# tool_name_2 = "echo_2" -# tool_args = {"text": "hello"} -# tool_args_2 = {"text": "hello_2"} - -# # Build mocked stream (tool call first, then text) -# chunks = [ -# _make_chunk(content=streamed_text), -# _make_chunk(tool_calls=[_make_tool_call(tool_name, json.dumps(tool_args))]), -# _make_chunk(tool_calls=[_make_tool_call(tool_name_2, json.dumps(tool_args_2), idx=1)]), -# ] - -# # Instantiate real OpenAILLMService (no need for actual API key) -# llm = OpenAILLMService(model="gpt-4o-mini", api_key="test") - -# # Patch get_chat_completions to return our mocked async stream -# async def fake_get_chat_completions(self, context, messages): # noqa: D401 -# return _MockAsyncStream(chunks) - -# with mock.patch.object(llm.__class__, "get_chat_completions", fake_get_chat_completions): -# # Register echo tool -# executed = False - -# async def echo_handler(params: FunctionCallParams): -# nonlocal executed -# executed = True -# # sleep for 1 second -# logger.info("echo_handler: sleeping for 5 second") -# await asyncio.sleep(5) -# await params.result_callback( -# {"ok": True}, -# properties=FunctionCallResultProperties(run_llm=True), -# ) - -# async def echo_2_handler(params: FunctionCallParams): -# nonlocal executed -# executed = True -# # sleep for 1 second -# logger.info("echo_2_handler: sleeping for 5 second") -# await asyncio.sleep(5) -# await params.result_callback( -# {"ok": True}, -# properties=FunctionCallResultProperties(run_llm=True), -# ) - -# llm.register_function(tool_name, echo_handler) -# llm.register_function(tool_name_2, echo_2_handler) - -# # Prepare context and send -# context = OpenAILLMContext() -# context.add_message({"role": "user", "content": "Hi"}) -# frames_to_send = [OpenAILLMContextFrame(context)] - -# expected_down_frames = [ -# LLMFullResponseStartFrame, -# FunctionCallsStartedFrame, -# FunctionCallInProgressFrame, -# FunctionCallResultFrame, -# LLMGeneratedTextFrame, -# LLMTextFrame, -# LLMFullResponseEndFrame, -# ] - -# context_aggregator = llm.create_context_aggregator(context) - -# pipeline = Pipeline([llm, context_aggregator.assistant()]) - -# down_frames, _ = await run_test( -# pipeline, -# frames_to_send=frames_to_send, -# expected_down_frames=expected_down_frames, -# send_end_frame=False, -# ) - -# # Assertions -# self.assertTrue(executed) -# for fr in down_frames: -# if isinstance(fr, FunctionCallResultFrame): -# self.assertTrue(fr.run_llm) -# if isinstance(fr, LLMTextFrame): -# self.assertEqual(fr.text, streamed_text) - - -# if __name__ == "__main__": -# unittest.main() diff --git a/api/tests/test_llm_generated_text_signal.py b/api/tests/test_llm_generated_text_signal.py deleted file mode 100644 index 6a3d6b2..0000000 --- a/api/tests/test_llm_generated_text_signal.py +++ /dev/null @@ -1,143 +0,0 @@ -#!/usr/bin/env python3 - -""" -Test script to verify that LLMGeneratedTextFrame signaling works correctly -with the new local variable approach. -""" - - -def test_local_variable_logic(): - """Test the core logic using the same pattern as the implementation""" - - print("=== Testing Local Variable Logic ===") - - # Simulate the logic from _process_context - text_generation_signaled = False - frames_sent = [] - - # Simulate chunks with text content - chunks_with_content = ["Hello", " world", "!"] - - for content in chunks_with_content: - # This is the exact logic from our implementation - if content: # equivalent to chunk.choices[0].delta.content - if not text_generation_signaled: - frames_sent.append("LLMGeneratedTextFrame") - text_generation_signaled = True - frames_sent.append(f"LLMTextFrame({content})") - - print(f"Frames sent: {frames_sent}") - - # Verify behavior - generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"] - text_frames = [f for f in frames_sent if f.startswith("LLMTextFrame")] - - assert len(generated_signals) == 1, ( - f"Expected 1 signal, got {len(generated_signals)}" - ) - assert len(text_frames) == 3, f"Expected 3 text frames, got {len(text_frames)}" - assert frames_sent[0] == "LLMGeneratedTextFrame", "Signal should be first" - - print("✅ Local variable logic works correctly") - return True - - -def test_no_text_logic(): - """Test that no signal is sent when there's no text""" - - print("\n=== Testing No Text Logic ===") - - text_generation_signaled = False - frames_sent = [] - - # Simulate chunks with no text content (function calls only) - chunks_with_content = [None, None, None] # No text content - - for content in chunks_with_content: - if content: # This will be False for all chunks - if not text_generation_signaled: - frames_sent.append("LLMGeneratedTextFrame") - text_generation_signaled = True - frames_sent.append(f"LLMTextFrame({content})") - - print(f"Frames sent: {frames_sent}") - - assert len(frames_sent) == 0, f"Expected no frames, got {frames_sent}" - - print("✅ No signal sent when no text content") - return True - - -def test_mixed_content_logic(): - """Test behavior with mixed function calls and text""" - - print("\n=== Testing Mixed Content Logic ===") - - text_generation_signaled = False - frames_sent = [] - - # Simulate chunks: function call, text, function call, text - chunks = [ - {"type": "function", "content": None}, - {"type": "text", "content": "Hello"}, - {"type": "function", "content": None}, - {"type": "text", "content": " world"}, - ] - - for chunk in chunks: - if chunk["type"] == "function": - frames_sent.append("FunctionCallFrame") - elif chunk["content"]: # text content - if not text_generation_signaled: - frames_sent.append("LLMGeneratedTextFrame") - text_generation_signaled = True - frames_sent.append(f"LLMTextFrame({chunk['content']})") - - print(f"Frames sent: {frames_sent}") - - generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"] - - assert len(generated_signals) == 1, ( - f"Expected 1 signal, got {len(generated_signals)}" - ) - # Signal should come before first text frame but after any function frames - signal_index = frames_sent.index("LLMGeneratedTextFrame") - first_text_index = next( - i for i, f in enumerate(frames_sent) if f.startswith("LLMTextFrame") - ) - assert signal_index == first_text_index - 1, ( - "Signal should come right before first text" - ) - - print("✅ Mixed content logic works correctly") - return True - - -def main(): - try: - test1_result = test_local_variable_logic() - test2_result = test_no_text_logic() - test3_result = test_mixed_content_logic() - - print(f"\n=== Test Results ===") - print(f"Local variable test: {'✅ PASS' if test1_result else '❌ FAIL'}") - print(f"No text test: {'✅ PASS' if test2_result else '❌ FAIL'}") - print(f"Mixed content test: {'✅ PASS' if test3_result else '❌ FAIL'}") - - if test1_result and test2_result and test3_result: - print("\n🎉 All LLMGeneratedTextFrame signaling logic tests passed!") - print( - "✅ Implementation correctly signals text generation once, as early as possible" - ) - else: - print("\n❌ Some tests failed.") - - except Exception as e: - print(f"❌ Test failed with error: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/api/tests/test_pipecat_engine_set_node.py b/api/tests/test_pipecat_engine_set_node.py deleted file mode 100644 index a0de71f..0000000 --- a/api/tests/test_pipecat_engine_set_node.py +++ /dev/null @@ -1,536 +0,0 @@ -import asyncio -from unittest.mock import AsyncMock, Mock, patch - -import pytest -from pipecat.frames.frames import ( - EndFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - TTSSpeakFrame, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame -from pipecat.services.openai.llm import OpenAILLMContext - -from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO -from api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.workflow import Edge, Node, WorkflowGraph - - -class TestPipecatEngineSetNode: - """Test cases for PipecatEngine.set_node method refactoring.""" - - @pytest.fixture - def mock_workflow(self): - """Create a mock workflow with various node types.""" - workflow = Mock(spec=WorkflowGraph) - workflow.nodes = {} - workflow.start_node_id = "start_node" - workflow.global_node_id = None - return workflow - - @pytest.fixture - def mock_dependencies(self, mock_workflow): - """Create mock dependencies for PipecatEngine initialization.""" - task = AsyncMock() - task.queue_frames = AsyncMock() - task.queue_frame = AsyncMock() - - llm = AsyncMock() - llm.register_function = Mock() - llm.push_frame = AsyncMock() - - context = Mock(spec=OpenAILLMContext) - context.set_node_name = Mock() - - return { - "task": task, - "llm": llm, - "context": context, - "tts": Mock(), - "transport": Mock(), - "workflow": mock_workflow, - "call_context_vars": {"test_var": "test_value"}, - } - - @pytest.fixture - def engine(self, mock_dependencies): - """Create a PipecatEngine instance.""" - # Add audio_buffer and workflow_run_id to dependencies - mock_dependencies["audio_buffer"] = None - mock_dependencies["workflow_run_id"] = 123 - engine = PipecatEngine(**mock_dependencies) - # Mock the builtin function registration - engine._register_builtin_functions = AsyncMock() - return engine - - def create_node(self, node_id, **kwargs): - """Helper to create a node with default values.""" - defaults = { - "name": f"Node {node_id}", - "prompt": f"Prompt for {node_id}", - "is_static": False, - "is_start": False, - "is_end": False, - "allow_interrupt": True, - "extraction_enabled": False, - "extraction_prompt": "", - "extraction_variables": [], - "add_global_prompt": True, - "wait_for_user_response": False, - "detect_voicemail": False, - } - defaults.update(kwargs) - - data = Mock(spec=NodeDataDTO) - for key, value in defaults.items(): - setattr(data, key, value) - - node = Mock(spec=Node) - node.id = node_id - node.data = data - node.out_edges = [] - - # Copy attributes from data to node - for key, value in defaults.items(): - setattr(node, key, value) - - return node - - def create_edge( - self, source, target, label="Continue", condition="Always continue" - ): - """Helper to create an edge.""" - data = Mock(spec=EdgeDataDTO) - data.label = label - data.condition = condition - - edge = Mock(spec=Edge) - edge.source = source - edge.target = target - edge.data = data - edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_")) - - return edge - - # ===== START NODE TESTS ===== - - @pytest.mark.asyncio - async def test_start_node_static_immediate_execution(self, engine, mock_workflow): - """Test: Basic static start node executes immediately.""" - # Setup - start_node = self.create_node( - "start_node", - is_start=True, - is_static=True, - prompt="Welcome to our service!", - ) - next_node = self.create_node("next_node", is_static=False) - - edge = self.create_edge("start_node", "next_node") - start_node.out_edges = [edge] - - mock_workflow.nodes = {"start_node": start_node, "next_node": next_node} - - # Execute - await engine.set_node("start_node") - - # Verify - # Should queue TTS immediately - engine.task.queue_frames.assert_called_once() - frames = engine.task.queue_frames.call_args[0][0] - assert len(frames) == 3 - assert isinstance(frames[0], LLMFullResponseStartFrame) - assert isinstance(frames[1], TTSSpeakFrame) - assert frames[1].text == "Welcome to our service!" - assert isinstance(frames[2], LLMFullResponseEndFrame) - - # Static start nodes now set pending transition after context push - assert engine._pending_control_transition_after_context_push is not None - - # Should not have set detect_voicemail for static start without it - assert not engine._detect_voicemail - - @pytest.mark.asyncio - async def test_start_node_with_detect_voicemail_no_audio_buffer( - self, engine, mock_workflow - ): - """Test: Start node with voicemail detection but no audio buffer logs warning.""" - # Setup - start_node = self.create_node( - "start_node", - is_start=True, - is_static=True, - detect_voicemail=True, - prompt="Hello, this is a business call.", - ) - - mock_workflow.nodes = {"start_node": start_node} - - # Engine has no audio buffer (None) - assert engine._audio_buffer is None - - # Execute - await engine.set_node("start_node") - - # Verify - # Should NOT set voicemail detection flag since no audio buffer - assert engine._detect_voicemail is False - assert engine._voicemail_detector is None - - # Should queue TTS immediately - engine.task.queue_frames.assert_called_once() - frames = engine.task.queue_frames.call_args[0][0] - assert isinstance(frames[1], TTSSpeakFrame) - assert frames[1].text == "Hello, this is a business call." - - @pytest.mark.asyncio - async def test_start_node_non_static_with_detect_voicemail( - self, engine, mock_workflow - ): - """Test: Non-static start node with voicemail detection without audio buffer.""" - # Setup - start_node = self.create_node( - "start_node", - is_start=True, - is_static=False, # Non-static - detect_voicemail=True, - prompt="You are an AI assistant. Start the conversation.", - ) - - mock_workflow.nodes = {"start_node": start_node} - - # Mock the context update method - engine._update_llm_context = AsyncMock() - engine._compose_system_message_functions_for_node = AsyncMock( - return_value=({"role": "system", "content": "Test prompt"}, []) - ) - - # Execute - await engine.set_node("start_node") - - # Verify - # Should NOT set voicemail detection flags (no audio buffer) - assert engine._detect_voicemail is False - assert engine._voicemail_detector is None - - # Should update LLM context for non-static node - engine._update_llm_context.assert_called_once() - - # Should queue context frame - engine.task.queue_frame.assert_called_once() - frame = engine.task.queue_frame.call_args[0][0] - assert isinstance(frame, OpenAILLMContextFrame) - - @pytest.mark.asyncio - async def test_start_node_static_with_wait_for_user_response( - self, engine, mock_workflow - ): - """Test: Static start node with wait_for_user_response.""" - # Setup - start_node = self.create_node( - "start_node", - is_start=True, - is_static=True, - wait_for_user_response=True, - prompt="Please tell me your name.", - ) - next_node = self.create_node("next_node") - - edge = self.create_edge("start_node", "next_node") - start_node.out_edges = [edge] - - mock_workflow.nodes = {"start_node": start_node, "next_node": next_node} - - # Execute - await engine.set_node("start_node") - - # Verify - # Should queue TTS immediately - engine.task.queue_frames.assert_called_once() - - # Should have a pending control transition that will start the timer - assert engine._pending_control_transition_after_context_push is not None - - # Timer task should not exist yet - assert ( - not hasattr(engine, "_user_response_timeout_task") - or engine._user_response_timeout_task is None - ) - - # Simulate context push to start the timer - await engine.flush_pending_transitions(source="context_push") - - # Now the timeout task should be created - assert engine._user_response_timeout_task is not None - assert not engine._user_response_timeout_task.done() - - # Clean up the task - engine._user_response_timeout_task.cancel() - - @pytest.mark.asyncio - async def test_start_node_non_static(self, engine, mock_workflow): - """Test: Non-static start node sends context to LLM.""" - # Setup - start_node = self.create_node( - "start_node", - is_start=True, - is_static=False, - prompt="You are a helpful assistant. Greet the user.", - ) - - mock_workflow.nodes = {"start_node": start_node} - - # Mock the context update method - engine._update_llm_context = AsyncMock() - engine._compose_system_message_functions_for_node = AsyncMock( - return_value=({"role": "system", "content": "Test prompt"}, []) - ) - - # Execute - await engine.set_node("start_node") - - # Verify - # Should set context name - engine.context.set_node_name.assert_called_once_with("Node start_node") - - # Should update LLM context - engine._update_llm_context.assert_called_once() - - # Should queue context frame - engine.task.queue_frame.assert_called_once() - frame = engine.task.queue_frame.call_args[0][0] - assert isinstance(frame, OpenAILLMContextFrame) - - # ===== AGENT NODE TESTS ===== - - @pytest.mark.asyncio - async def test_agent_node_static(self, engine, mock_workflow): - """Test: Static agent node plays TTS and transitions.""" - # Setup - agent_node = self.create_node( - "agent_node", is_static=True, prompt="Processing your request..." - ) - next_node = self.create_node("next_node") - - edge = self.create_edge("agent_node", "next_node") - agent_node.out_edges = [edge] - - mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node} - - # Execute - await engine.set_node("agent_node") - - # Verify - # Should queue TTS - engine.task.queue_frames.assert_called_once() - frames = engine.task.queue_frames.call_args[0][0] - assert isinstance(frames[1], TTSSpeakFrame) - assert frames[1].text == "Processing your request..." - - # Should have pending transition - assert engine._pending_control_transition_after_context_push is not None - - @pytest.mark.asyncio - async def test_agent_node_non_static(self, engine, mock_workflow): - """Test: Non-static agent node sends context to LLM.""" - # Setup - agent_node = self.create_node( - "agent_node", - is_static=False, - prompt="Analyze the user's request and respond appropriately.", - ) - decision_node = self.create_node("decision_node") - - edge = self.create_edge("agent_node", "decision_node", "analyze_complete") - agent_node.out_edges = [edge] - - mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node} - - # Mock methods - engine._update_llm_context = AsyncMock() - engine._compose_system_message_functions_for_node = AsyncMock( - return_value=( - {"role": "system", "content": "Test"}, - [{"name": "test_func"}], - ) - ) - - # Execute - await engine.set_node("agent_node") - - # Verify - # Should register transition function - engine.llm.register_function.assert_called_once() - call_args = engine.llm.register_function.call_args - assert call_args[0][0] == "analyze_complete" - assert callable(call_args[0][1]) # Check it's a function - assert call_args[1]["cancel_on_interruption"] is True - - # Should update context and send frame - engine._update_llm_context.assert_called_once() - engine.task.queue_frame.assert_called_once() - - @pytest.mark.asyncio - async def test_agent_node_with_interruption_control(self, engine, mock_workflow): - """Test: Agent node respects allow_interrupt flag.""" - # Setup - no_interrupt_node = self.create_node( - "no_interrupt", - is_static=True, - allow_interrupt=False, - prompt="Please wait while I process...", - ) - - mock_workflow.nodes = {"no_interrupt": no_interrupt_node} - - # Execute - await engine.set_node("no_interrupt") - - # Verify current node is set (for STT mute callback) - assert engine._current_node == no_interrupt_node - assert engine._current_node.allow_interrupt is False - - # ===== END NODE TESTS ===== - - @pytest.mark.asyncio - async def test_end_node_static(self, engine, mock_workflow): - """Test: Static end node plays final message and schedules end task.""" - # Setup - end_node = self.create_node( - "end_node", - is_static=True, - is_end=True, - prompt="Thank you for calling. Goodbye!", - ) - - mock_workflow.nodes = {"end_node": end_node} - - # Execute - await engine.set_node("end_node") - - # Verify - # Should queue TTS - engine.task.queue_frames.assert_called_once() - frames = engine.task.queue_frames.call_args[0][0] - assert frames[1].text == "Thank you for calling. Goodbye!" - - # Should have pending end task - assert engine._pending_control_transition_after_context_push is not None - - # Execute the pending transition - await engine._pending_control_transition_after_context_push() - - # Should have sent EndFrame via task.queue_frame - # The second call should be the EndFrame (first was TTS frames) - assert engine.task.queue_frame.call_count >= 1 - end_frame = engine.task.queue_frame.call_args[0][0] - assert isinstance(end_frame, EndFrame) - - @pytest.mark.asyncio - async def test_end_node_with_extraction(self, engine, mock_workflow): - """Test: End node with variable extraction.""" - # Setup - end_node = self.create_node( - "end_node", - is_end=True, - is_static=False, - extraction_enabled=True, - extraction_variables=["user_name", "satisfaction_level"], - extraction_prompt="Extract user name and satisfaction", - ) - - mock_workflow.nodes = {"end_node": end_node} - - # Mock the extraction manager - engine._variable_extraction_manager = Mock() - engine._perform_variable_extraction_if_needed = AsyncMock() - - # Mock context update and composition methods - engine._update_llm_context = AsyncMock() - engine._compose_system_message_functions_for_node = AsyncMock( - return_value=({"role": "system", "content": "Test"}, []) - ) - - # Execute - await engine.set_node("end_node") - - # Verify - # Should trigger extraction - engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node) - - # Should have pending end task - assert engine._pending_control_transition_after_context_push is not None - - # ===== CALLBACK INTEGRATION TESTS ===== - - @pytest.mark.asyncio - async def test_user_stopped_speaking_during_response_wait( - self, engine, mock_workflow - ): - """Test: User stops speaking triggers transition during wait_for_response.""" - # Setup - start_node = self.create_node( - "start_node", is_start=True, is_static=True, wait_for_user_response=True - ) - next_node = self.create_node("next_node") - - edge = self.create_edge("start_node", "next_node") - start_node.out_edges = [edge] - - mock_workflow.nodes = {"start_node": start_node, "next_node": next_node} - - # Set current node to start node - engine._current_node = start_node - engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3)) - - # Create callback and execute - callback = engine.create_user_stopped_speaking_callback() - - # Mock set_node to avoid recursion - with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node: - await callback() - - # Verify - mock_set_node.assert_called_once_with("next_node") - assert engine._queue_context_frame is False # Should be set to False - - @pytest.mark.asyncio - async def test_context_push_callback_executes_pending_transitions(self, engine): - """Test: flush_pending_transitions executes deferred transitions.""" - # Setup pending transitions - mock_generated_transition = AsyncMock() - mock_control_transition = AsyncMock() - - engine._pending_generated_transition_after_context_push = ( - mock_generated_transition - ) - engine._pending_control_transition_after_context_push = mock_control_transition - - # Execute - await engine.flush_pending_transitions(source="context_push") - - # Verify both transitions were executed - mock_generated_transition.assert_called_once() - mock_control_transition.assert_called_once() - - # Verify they were cleared - assert engine._pending_generated_transition_after_context_push is None - assert engine._pending_control_transition_after_context_push is None - - # ===== COMPLEX SCENARIO TESTS ===== - - -# Add helper for testing with real async behavior -def ANY(cls=None): - """Helper for matching any argument in mock calls.""" - - class AnyMatcher: - def __init__(self, cls): - self.cls = cls - - def __eq__(self, other): - if self.cls: - return isinstance(other, self.cls) - return True - - return AnyMatcher(cls) diff --git a/api/tests/test_provider_switching.py b/api/tests/test_provider_switching.py index c2a5f01..9f11db7 100644 --- a/api/tests/test_provider_switching.py +++ b/api/tests/test_provider_switching.py @@ -5,46 +5,41 @@ handles provider switches correctly without losing billing data. """ import asyncio -import json -from datetime import datetime, timezone -from typing import Dict, Any # Test scenarios to validate + async def test_scenario_1_mid_call_provider_switch(): """ Test: What happens if provider is switched while a call is active? - + Expected behavior: - Active call continues with original provider - Call is billed to original provider - New calls use new provider """ print("Test 1: Mid-call provider switching") - + # Simulate workflow run with Twilio twilio_run = { "id": 1, "mode": "twilio", - "cost_info": { - "twilio_call_sid": "CA123456789", - "provider": "twilio" - }, - "is_completed": False + "cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"}, + "is_completed": False, } - + # Provider switch happens here (in real scenario, user changes config) # But the call continues... - + # When cost calculation runs, it should: # 1. Use the provider stored in cost_info # 2. Fetch cost from Twilio using twilio_call_sid # 3. Store cost with provider attribution - + result = { "test": "mid_call_switch", "status": "PASS", - "reason": "Call continues with original provider, billing intact" + "reason": "Call continues with original provider, billing intact", } print(f" ✓ {result['reason']}") return result @@ -53,41 +48,41 @@ async def test_scenario_1_mid_call_provider_switch(): async def test_scenario_2_pending_cost_calculation(): """ Test: Calls that ended but cost not yet calculated when provider switches. - + Expected behavior: - Background job should use the provider info stored in cost_info - Cost should be fetched from correct provider """ print("\nTest 2: Pending cost calculation during switch") - + # Workflow runs that ended but cost job hasn't run yet pending_runs = [ { "id": 2, - "mode": "twilio", + "mode": "twilio", "cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"}, - "is_completed": True + "is_completed": True, }, { "id": 3, "mode": "vonage", "cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"}, - "is_completed": True - } + "is_completed": True, + }, ] - + # Provider switch happens here # Cost calculation jobs run after switch - + # Each job should: # 1. Check the provider field in cost_info # 2. Use appropriate provider API to fetch cost # 3. Handle gracefully if credentials changed - + result = { "test": "pending_cost_calculation", "status": "PASS", - "reason": "Cost jobs use stored provider info correctly" + "reason": "Cost jobs use stored provider info correctly", } print(f" ✓ {result['reason']}") return result @@ -96,33 +91,37 @@ async def test_scenario_2_pending_cost_calculation(): async def test_scenario_3_mixed_provider_history(): """ Test: Organization has calls from both Twilio and Vonage. - + Expected behavior: - Historical costs remain intact - Reports show correct attribution - Total costs aggregate correctly """ print("\nTest 3: Mixed provider history") - + historical_runs = [ {"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"}, {"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"}, {"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"}, {"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"}, ] - + # Calculate totals total_cost = sum(run["cost_usd"] for run in historical_runs) - twilio_cost = sum(run["cost_usd"] for run in historical_runs if run["provider"] == "twilio") - vonage_cost = sum(run["cost_usd"] for run in historical_runs if run["provider"] == "vonage") - + twilio_cost = sum( + run["cost_usd"] for run in historical_runs if run["provider"] == "twilio" + ) + vonage_cost = sum( + run["cost_usd"] for run in historical_runs if run["provider"] == "vonage" + ) + result = { "test": "mixed_provider_history", - "status": "PASS", + "status": "PASS", "total_cost": total_cost, "twilio_cost": twilio_cost, "vonage_cost": vonage_cost, - "reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})" + "reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})", } print(f" ✓ {result['reason']}") return result @@ -131,41 +130,41 @@ async def test_scenario_3_mixed_provider_history(): async def test_scenario_4_cost_api_failure(): """ Test: Provider API fails when fetching cost. - + Expected behavior: - Error logged but system continues - Call record preserved - Cost marked as 0 or unknown """ print("\nTest 4: Cost API failure handling") - + # Simulate API failure scenarios failure_scenarios = [ { "provider": "twilio", "error": "401 Unauthorized - credentials changed", - "expected": "Cost set to 0, error logged" + "expected": "Cost set to 0, error logged", }, { "provider": "vonage", "error": "404 Not Found - call record deleted", - "expected": "Cost set to 0, error logged" + "expected": "Cost set to 0, error logged", }, { "provider": "twilio", "error": "500 Internal Server Error", - "expected": "Cost set to 0, retry possible" - } + "expected": "Cost set to 0, retry possible", + }, ] - + for scenario in failure_scenarios: print(f" - {scenario['provider']}: {scenario['error']}") print(f" Expected: {scenario['expected']}") - + result = { "test": "cost_api_failure", "status": "PASS", - "reason": "All failure scenarios handled gracefully" + "reason": "All failure scenarios handled gracefully", } print(f" ✓ {result['reason']}") return result @@ -174,22 +173,22 @@ async def test_scenario_4_cost_api_failure(): async def test_scenario_5_configuration_migration(): """ Test: Database migration from single to multi-provider format. - + Expected behavior: - Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION - Single provider config wrapped in multi-provider structure - Existing cost_info gets provider field added """ print("\nTest 5: Configuration migration") - + # Old format old_config = { "account_sid": "AC123", - "auth_token": "token123", + "auth_token": "token123", "from_numbers": ["+1234567890"], - "provider": "twilio" + "provider": "twilio", } - + # New format after migration new_config = { "active_provider": "twilio", @@ -197,20 +196,20 @@ async def test_scenario_5_configuration_migration(): "twilio": { "account_sid": "AC123", "auth_token": "token123", - "from_numbers": ["+1234567890"] + "from_numbers": ["+1234567890"], } - } + }, } - + # Validate migration assert new_config["active_provider"] == "twilio" assert "providers" in new_config assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"] - + result = { "test": "configuration_migration", "status": "PASS", - "reason": "Configuration migrated to multi-provider format correctly" + "reason": "Configuration migrated to multi-provider format correctly", } print(f" ✓ {result['reason']}") return result @@ -219,39 +218,34 @@ async def test_scenario_5_configuration_migration(): async def test_scenario_6_provider_cost_discrepancy(): """ Test: Webhook cost vs API cost discrepancy. - + Expected behavior: - Webhook cost stored immediately if available - API cost fetched later for verification - Both costs stored for auditing """ print("\nTest 6: Provider cost discrepancy handling") - + # Vonage webhook provides immediate cost - webhook_cost = { - "vonage_webhook_price": 0.15, - "vonage_webhook_duration": 120 - } - + webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120} + # API call provides authoritative cost api_cost = { "cost_usd": 0.14, # Slight difference - "duration": 120 + "duration": 120, } - + # Both should be stored final_cost_info = { **webhook_cost, - "cost_breakdown": { - "telephony_call": api_cost["cost_usd"] - }, - "provider": "vonage" + "cost_breakdown": {"telephony_call": api_cost["cost_usd"]}, + "provider": "vonage", } - + result = { "test": "cost_discrepancy", "status": "PASS", - "reason": "Both webhook and API costs stored for auditing" + "reason": "Both webhook and API costs stored for auditing", } print(f" ✓ {result['reason']}") return result @@ -262,40 +256,40 @@ async def run_all_tests(): print("=" * 60) print("PROVIDER SWITCHING TEST SUITE") print("=" * 60) - + tests = [ test_scenario_1_mid_call_provider_switch, test_scenario_2_pending_cost_calculation, test_scenario_3_mixed_provider_history, test_scenario_4_cost_api_failure, test_scenario_5_configuration_migration, - test_scenario_6_provider_cost_discrepancy + test_scenario_6_provider_cost_discrepancy, ] - + results = [] for test in tests: result = await test() results.append(result) - + print("\n" + "=" * 60) print("TEST SUMMARY") print("=" * 60) - + passed = sum(1 for r in results if r["status"] == "PASS") failed = sum(1 for r in results if r["status"] == "FAIL") - + print(f"Total Tests: {len(results)}") print(f"Passed: {passed}") print(f"Failed: {failed}") - + if failed == 0: print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!") else: print("\n❌ Some tests failed - Review the implementation") - + return results if __name__ == "__main__": # Run the test suite - asyncio.run(run_all_tests()) \ No newline at end of file + asyncio.run(run_all_tests()) diff --git a/pipecat b/pipecat index fa68d2c..5365365 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit fa68d2ce261544398013307d2c6a69e0556b4449 +Subproject commit 53653657d851e8052f9cc5b73b6f675a44c86fe7 diff --git a/ui/src/components/flow/nodes/EndCall.tsx b/ui/src/components/flow/nodes/EndCall.tsx index 7efa4d7..0ceee0d 100644 --- a/ui/src/components/flow/nodes/EndCall.tsx +++ b/ui/src/components/flow/nodes/EndCall.tsx @@ -18,8 +18,6 @@ interface EndCallEditFormProps { nodeData: FlowNodeData; prompt: string; setPrompt: (value: string) => void; - isStatic: boolean; - setIsStatic: (value: boolean) => void; name: string; setName: (value: string) => void; extractionEnabled: boolean; @@ -45,7 +43,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => { // Form state const [prompt, setPrompt] = useState(data.prompt); - const [isStatic, setIsStatic] = useState(data.is_static ?? true); const [name, setName] = useState(data.name); // Variable Extraction state @@ -58,7 +55,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => { handleSaveNodeData({ ...data, prompt, - is_static: isStatic, name, allow_interrupt: false, // Always set to false for end nodes extraction_enabled: extractionEnabled, @@ -77,7 +73,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => { const handleOpenChange = (newOpen: boolean) => { if (newOpen) { setPrompt(data.prompt); - setIsStatic(data.is_static ?? true); setName(data.name); setExtractionEnabled(data.extraction_enabled ?? false); setExtractionPrompt(data.extraction_prompt ?? ""); @@ -91,7 +86,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => { useEffect(() => { if (open) { setPrompt(data.prompt); - setIsStatic(data.is_static ?? true); setName(data.name); setExtractionEnabled(data.extraction_enabled ?? false); setExtractionPrompt(data.extraction_prompt ?? ""); @@ -137,8 +131,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => { nodeData={data} prompt={prompt} setPrompt={setPrompt} - isStatic={isStatic} - setIsStatic={setIsStatic} name={name} setName={setName} extractionEnabled={extractionEnabled} @@ -159,8 +151,6 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => { const EndCallEditForm = ({ prompt, setPrompt, - isStatic, - setIsStatic, name, setName, extractionEnabled, @@ -206,14 +196,10 @@ const EndCallEditForm = ({ setName(e.target.value)} /> - + -
- - -