diff --git a/api/routes/integration.py b/api/routes/integration.py index 71eb727..ae4d98c 100644 --- a/api/routes/integration.py +++ b/api/routes/integration.py @@ -1,3 +1,7 @@ +""" +Route for 3rd party integrations. Currently being backed by nango. +""" + from dataclasses import dataclass from typing import Any, Dict, List, Optional, TypedDict diff --git a/api/routes/webrtc_signaling.py b/api/routes/webrtc_signaling.py index d33aa5d..2dadd83 100644 --- a/api/routes/webrtc_signaling.py +++ b/api/routes/webrtc_signaling.py @@ -22,7 +22,7 @@ from typing import Dict, List, Optional from aiortc import RTCIceServer from aiortc.sdp import candidate_from_sdp -from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect from loguru import logger from starlette.websockets import WebSocketState @@ -390,6 +390,11 @@ async def signaling_websocket( user: UserModel = Depends(get_user_ws), ): """WebSocket endpoint for WebRTC signaling with ICE trickling.""" + workflow_run = await db_client.get_workflow_run(workflow_run_id, user.id) + if not workflow_run: + logger.warning(f"workflow run {workflow_run_id} not found for user {user.id}") + raise HTTPException(status_code=400, detail="Bad workflow_run_id") + await signaling_manager.handle_websocket( websocket, workflow_id, workflow_run_id, user ) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 41a1882..f0107e9 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -266,7 +266,6 @@ async def run_pipeline_vobiz( async def run_pipeline_cloudonix( websocket_client: WebSocket, stream_sid: str, - call_sid: str, workflow_id: int, workflow_run_id: int, user_id: int, @@ -275,10 +274,15 @@ async def run_pipeline_cloudonix( logger.debug( f"Running pipeline for Cloudonix connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}" ) - set_current_run_id(workflow_run_id) + + workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) + call_id = workflow_run.gathered_context.get("call_id") + if not call_id: + logger.warning("call_id not found in gathered_context") + raise Exception() # Store call ID in cost_info for later cost calculation (provider-agnostic) - cost_info = {"call_id": call_sid} + cost_info = {"call_id": call_id} await db_client.update_workflow_run(workflow_run_id, cost_info=cost_info) # Get workflow to extract all pipeline configurations @@ -293,26 +297,18 @@ async def run_pipeline_cloudonix( "ambient_noise_configuration" ] - # Retrieve session_token from workflow_run gathered_context - workflow_run = await db_client.get_workflow_run(workflow_run_id) - session_token = None - if workflow_run and workflow_run.gathered_context: - session_token = workflow_run.gathered_context.get("session_token") - logger.debug(f"Retrieved session_token from workflow_run: {session_token}") - # Create audio configuration for Cloudonix audio_config = create_audio_config(WorkflowRunMode.CLOUDONIX.value) transport = await create_cloudonix_transport( websocket_client, + call_id, stream_sid, - call_sid, workflow_run_id, audio_config, workflow.organization_id, vad_config, ambient_noise_config, - session_token, ) await _run_pipeline( transport, diff --git a/api/services/pipecat/transport_setup.py b/api/services/pipecat/transport_setup.py index 0290062..6cee7fb 100644 --- a/api/services/pipecat/transport_setup.py +++ b/api/services/pipecat/transport_setup.py @@ -94,14 +94,13 @@ async def create_twilio_transport( async def create_cloudonix_transport( websocket_client: WebSocket, + call_id: str, stream_sid: str, - call_sid: str, workflow_run_id: int, audio_config: AudioConfig, organization_id: int, vad_config: dict | None = None, ambient_noise_config: dict | None = None, - session_token: str | None = None, ): """Create a transport for Cloudonix connections""" @@ -125,11 +124,10 @@ async def create_cloudonix_transport( from pipecat.serializers.cloudonix import CloudonixFrameSerializer serializer = CloudonixFrameSerializer( + call_id=call_id, stream_sid=stream_sid, - call_sid=call_sid, domain_id=domain_id, bearer_token=bearer_token, - session_token=session_token, ) return FastAPIWebsocketTransport( diff --git a/api/services/telephony/providers/cloudonix_provider.py b/api/services/telephony/providers/cloudonix_provider.py index 8449499..c6b2a55 100644 --- a/api/services/telephony/providers/cloudonix_provider.py +++ b/api/services/telephony/providers/cloudonix_provider.py @@ -395,10 +395,6 @@ class CloudonixProvider(TelephonyProvider): await websocket.close(code=4400, reason="Expected connected event") return - logger.debug( - f"Cloudonix WebSocket connected for workflow_run {workflow_run_id}" - ) - # Wait for "start" event with stream details start_msg = await websocket.receive_text() logger.debug(f"Received start message: {start_msg}") @@ -418,9 +414,14 @@ class CloudonixProvider(TelephonyProvider): await websocket.close(code=4400, reason="Missing stream identifiers") return + logger.debug( + f"Cloudonix WebSocket connected for workflow_run {workflow_run_id} " + f"stream_sid: {stream_sid} call_sid: {call_sid}" + ) + # Run the Cloudonix pipeline await run_pipeline_cloudonix( - websocket, stream_sid, call_sid, workflow_id, workflow_run_id, user_id + websocket, stream_sid, workflow_id, workflow_run_id, user_id ) except Exception as e: diff --git a/api/services/telephony/providers/twilio_provider.py b/api/services/telephony/providers/twilio_provider.py index 713e282..3c020b0 100644 --- a/api/services/telephony/providers/twilio_provider.py +++ b/api/services/telephony/providers/twilio_provider.py @@ -110,7 +110,7 @@ class TwilioProvider(TelephonyProvider): return CallInitiationResult( call_id=response_data["sid"], status=response_data.get("status", "queued"), - provider_metadata={}, # Twilio doesn't need to persist extra data + provider_metadata={"call_id": response_data["sid"]}, raw_response=response_data, ) diff --git a/api/services/telephony/providers/vobiz_provider.py b/api/services/telephony/providers/vobiz_provider.py index ddaf9c2..0666d85 100644 --- a/api/services/telephony/providers/vobiz_provider.py +++ b/api/services/telephony/providers/vobiz_provider.py @@ -150,7 +150,7 @@ class VobizProvider(TelephonyProvider): return CallInitiationResult( call_id=call_id, status="queued", # Vobiz returns "message": "call fired" - provider_metadata={}, + provider_metadata={"call_id": call_id}, raw_response=response_data, ) diff --git a/api/services/telephony/providers/vonage_provider.py b/api/services/telephony/providers/vonage_provider.py index 315f04d..ee25d78 100644 --- a/api/services/telephony/providers/vonage_provider.py +++ b/api/services/telephony/providers/vonage_provider.py @@ -138,10 +138,8 @@ class VonageProvider(TelephonyProvider): call_id=response_data["uuid"], status=response_data.get("status", "started"), provider_metadata={ - "call_uuid": response_data[ - "uuid" - ] # Vonage needs UUID persisted for WebSocket - }, + "call_uuid": response_data["uuid"] + }, # Vonage needs UUID persisted for WebSocket raw_response=response_data, ) diff --git a/api/tasks/workflow_run_cost.py b/api/tasks/workflow_run_cost.py index a0c06a1..176e3a8 100644 --- a/api/tasks/workflow_run_cost.py +++ b/api/tasks/workflow_run_cost.py @@ -7,6 +7,61 @@ from api.services.telephony.factory import get_telephony_provider from pipecat.utils.run_context import set_current_run_id +async def _fetch_telephony_cost(workflow_run) -> dict | None: + """Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None.""" + if ( + workflow_run.mode + not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] + or not workflow_run.cost_info + ): + return None + + call_id = workflow_run.cost_info.get("call_id") + if not call_id: + logger.warning(f"call_id not found in cost_info") + return None + + provider_name = workflow_run.mode.lower() if workflow_run.mode else "" + + workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) + if not workflow: + logger.warning("Workflow not found for workflow run") + raise Exception("Workflow not found") + + provider = await get_telephony_provider(workflow.organization_id) + call_cost_info = await provider.get_call_cost(call_id) + + if call_cost_info.get("status") == "error": + logger.error( + f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}" + ) + return None + + cost_usd = call_cost_info.get("cost_usd", 0.0) + logger.info( + f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}" + ) + return {"cost_usd": cost_usd, "provider_name": provider_name} + + +async def _update_organization_usage( + org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None +) -> None: + """Update organization usage after a workflow run.""" + org_id = org.id + await db_client.update_usage_after_run( + org_id, dograh_tokens, duration_seconds, charge_usd + ) + if charge_usd is not None: + logger.info( + f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}" + ) + else: + logger.info( + f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}" + ) + + async def calculate_workflow_run_cost(ctx, workflow_run_id: int): # Set the run_id in context variable for consistent logging format set_current_run_id(workflow_run_id) @@ -26,62 +81,20 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): # Calculate cost breakdown cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info) - # Fetch telephony call cost for both Twilio and Vonage - telephony_cost_usd = 0.0 - if ( - workflow_run.mode - in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] - and workflow_run.cost_info - ): - # Get the call ID (provider-agnostic approach with backward compatibility) - call_id = workflow_run.cost_info.get("call_id") - - # Fallback to legacy provider-specific fields if needed - if not call_id: - if workflow_run.mode == WorkflowRunMode.TWILIO.value: - call_id = workflow_run.cost_info.get("twilio_call_sid") - elif workflow_run.mode == WorkflowRunMode.VONAGE.value: - call_id = workflow_run.cost_info.get("vonage_call_uuid") - - # Provider name is derived from workflow run mode - provider_name = workflow_run.mode.lower() if workflow_run.mode else "" - - if call_id: - try: - # Get workflow to access organization_id - workflow = await db_client.get_workflow_by_id( - workflow_run.workflow_id - ) - if not workflow: - logger.warning("Workflow not found for workflow run") - raise Exception("Workflow not found") - - # Use telephony provider abstraction - provider = await get_telephony_provider(workflow.organization_id) - call_cost_info = await provider.get_call_cost(call_id) - - if call_cost_info.get("status") != "error": - telephony_cost_usd = call_cost_info.get("cost_usd", 0.0) - cost_breakdown["telephony_call"] = telephony_cost_usd - cost_breakdown[f"{provider_name}_call"] = ( - telephony_cost_usd # Keep backward compatibility - ) - - # Add telephony cost to the total - cost_breakdown["total"] = ( - float(cost_breakdown["total"]) + telephony_cost_usd - ) - logger.info( - f"{provider_name.title()} call cost: ${telephony_cost_usd:.6f} USD for call {call_id}" - ) - else: - logger.error( - f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}" - ) - - except Exception as e: - logger.error(f"Failed to fetch telephony call cost: {e}") - # Don't fail the whole cost calculation if telephony API fails + # Fetch telephony call cost + try: + telephony_cost = await _fetch_telephony_cost(workflow_run) + if telephony_cost: + telephony_cost_usd = telephony_cost["cost_usd"] + provider_name = telephony_cost["provider_name"] + cost_breakdown["telephony_call"] = telephony_cost_usd + cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd + cost_breakdown["total"] = ( + float(cost_breakdown["total"]) + telephony_cost_usd + ) + except Exception as e: + logger.error(f"Failed to fetch telephony call cost: {e}") + # Don't fail the whole cost calculation if telephony API fails # Store cost information back to the workflow run # We'll add the cost breakdown to the workflow run @@ -106,6 +119,7 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): charge_usd = duration_seconds * org.price_per_second_usd cost_info = { + **workflow_run.cost_info, "cost_breakdown": cost_breakdown, "total_cost_usd": float(cost_breakdown["total"]), "dograh_token_usage": dograh_tokens, @@ -118,42 +132,19 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int): cost_info["charge_usd"] = charge_usd cost_info["price_per_second_usd"] = org.price_per_second_usd - # Preserve call ID (provider-agnostic with backward compatibility) - if workflow_run.cost_info: - # Preserve generic call_id if it exists - if "call_id" in workflow_run.cost_info: - cost_info["call_id"] = workflow_run.cost_info["call_id"] - # Also preserve legacy fields for backward compatibility - elif "twilio_call_sid" in workflow_run.cost_info: - cost_info["twilio_call_sid"] = workflow_run.cost_info["twilio_call_sid"] - elif "vonage_call_uuid" in workflow_run.cost_info: - cost_info["vonage_call_uuid"] = workflow_run.cost_info[ - "vonage_call_uuid" - ] - # Update workflow run with cost information await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info) # Update organization usage if applicable if org: - org_id = org.id try: duration_seconds = workflow_usage_info.get("call_duration_seconds", 0) - # Pass USD amount if organization has pricing - await db_client.update_usage_after_run( - org_id, dograh_tokens, duration_seconds, charge_usd + await _update_organization_usage( + org, dograh_tokens, duration_seconds, charge_usd ) - if charge_usd is not None: - logger.info( - f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}" - ) - else: - logger.info( - f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}" - ) except Exception as e: logger.error( - f"Failed to update organization usage for org {org_id}: {e}" + f"Failed to update organization usage for org {org.id}: {e}" ) # Don't fail the whole task if usage update fails diff --git a/pipecat b/pipecat index 1bd0ea6..e180bd3 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit 1bd0ea6b44518040c87074c6086e4fedc0864ca9 +Subproject commit e180bd3c2abc3cebbdf5e2d7955d9928cca5d219