mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: fix cloudonix call hangup (#154)
This commit is contained in:
parent
a75bc72cb5
commit
b9ddd30813
10 changed files with 104 additions and 111 deletions
|
|
@ -1,3 +1,7 @@
|
|||
"""
|
||||
Route for 3rd party integrations. Currently being backed by nango.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from typing import Dict, List, Optional
|
|||
|
||||
from aiortc import RTCIceServer
|
||||
from aiortc.sdp import candidate_from_sdp
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from loguru import logger
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
|
|
@ -390,6 +390,11 @@ async def signaling_websocket(
|
|||
user: UserModel = Depends(get_user_ws),
|
||||
):
|
||||
"""WebSocket endpoint for WebRTC signaling with ICE trickling."""
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user.id)
|
||||
if not workflow_run:
|
||||
logger.warning(f"workflow run {workflow_run_id} not found for user {user.id}")
|
||||
raise HTTPException(status_code=400, detail="Bad workflow_run_id")
|
||||
|
||||
await signaling_manager.handle_websocket(
|
||||
websocket, workflow_id, workflow_run_id, user
|
||||
)
|
||||
|
|
|
|||
|
|
@ -266,7 +266,6 @@ async def run_pipeline_vobiz(
|
|||
async def run_pipeline_cloudonix(
|
||||
websocket_client: WebSocket,
|
||||
stream_sid: str,
|
||||
call_sid: str,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
|
|
@ -275,10 +274,15 @@ async def run_pipeline_cloudonix(
|
|||
logger.debug(
|
||||
f"Running pipeline for Cloudonix connection with workflow_id: {workflow_id} and workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
call_id = workflow_run.gathered_context.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning("call_id not found in gathered_context")
|
||||
raise Exception()
|
||||
|
||||
# Store call ID in cost_info for later cost calculation (provider-agnostic)
|
||||
cost_info = {"call_id": call_sid}
|
||||
cost_info = {"call_id": call_id}
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info=cost_info)
|
||||
|
||||
# Get workflow to extract all pipeline configurations
|
||||
|
|
@ -293,26 +297,18 @@ async def run_pipeline_cloudonix(
|
|||
"ambient_noise_configuration"
|
||||
]
|
||||
|
||||
# Retrieve session_token from workflow_run gathered_context
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id)
|
||||
session_token = None
|
||||
if workflow_run and workflow_run.gathered_context:
|
||||
session_token = workflow_run.gathered_context.get("session_token")
|
||||
logger.debug(f"Retrieved session_token from workflow_run: {session_token}")
|
||||
|
||||
# Create audio configuration for Cloudonix
|
||||
audio_config = create_audio_config(WorkflowRunMode.CLOUDONIX.value)
|
||||
|
||||
transport = await create_cloudonix_transport(
|
||||
websocket_client,
|
||||
call_id,
|
||||
stream_sid,
|
||||
call_sid,
|
||||
workflow_run_id,
|
||||
audio_config,
|
||||
workflow.organization_id,
|
||||
vad_config,
|
||||
ambient_noise_config,
|
||||
session_token,
|
||||
)
|
||||
await _run_pipeline(
|
||||
transport,
|
||||
|
|
|
|||
|
|
@ -94,14 +94,13 @@ async def create_twilio_transport(
|
|||
|
||||
async def create_cloudonix_transport(
|
||||
websocket_client: WebSocket,
|
||||
call_id: str,
|
||||
stream_sid: str,
|
||||
call_sid: str,
|
||||
workflow_run_id: int,
|
||||
audio_config: AudioConfig,
|
||||
organization_id: int,
|
||||
vad_config: dict | None = None,
|
||||
ambient_noise_config: dict | None = None,
|
||||
session_token: str | None = None,
|
||||
):
|
||||
"""Create a transport for Cloudonix connections"""
|
||||
|
||||
|
|
@ -125,11 +124,10 @@ async def create_cloudonix_transport(
|
|||
from pipecat.serializers.cloudonix import CloudonixFrameSerializer
|
||||
|
||||
serializer = CloudonixFrameSerializer(
|
||||
call_id=call_id,
|
||||
stream_sid=stream_sid,
|
||||
call_sid=call_sid,
|
||||
domain_id=domain_id,
|
||||
bearer_token=bearer_token,
|
||||
session_token=session_token,
|
||||
)
|
||||
|
||||
return FastAPIWebsocketTransport(
|
||||
|
|
|
|||
|
|
@ -395,10 +395,6 @@ class CloudonixProvider(TelephonyProvider):
|
|||
await websocket.close(code=4400, reason="Expected connected event")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Cloudonix WebSocket connected for workflow_run {workflow_run_id}"
|
||||
)
|
||||
|
||||
# Wait for "start" event with stream details
|
||||
start_msg = await websocket.receive_text()
|
||||
logger.debug(f"Received start message: {start_msg}")
|
||||
|
|
@ -418,9 +414,14 @@ class CloudonixProvider(TelephonyProvider):
|
|||
await websocket.close(code=4400, reason="Missing stream identifiers")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Cloudonix WebSocket connected for workflow_run {workflow_run_id} "
|
||||
f"stream_sid: {stream_sid} call_sid: {call_sid}"
|
||||
)
|
||||
|
||||
# Run the Cloudonix pipeline
|
||||
await run_pipeline_cloudonix(
|
||||
websocket, stream_sid, call_sid, workflow_id, workflow_run_id, user_id
|
||||
websocket, stream_sid, workflow_id, workflow_run_id, user_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class TwilioProvider(TelephonyProvider):
|
|||
return CallInitiationResult(
|
||||
call_id=response_data["sid"],
|
||||
status=response_data.get("status", "queued"),
|
||||
provider_metadata={}, # Twilio doesn't need to persist extra data
|
||||
provider_metadata={"call_id": response_data["sid"]},
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ class VobizProvider(TelephonyProvider):
|
|||
return CallInitiationResult(
|
||||
call_id=call_id,
|
||||
status="queued", # Vobiz returns "message": "call fired"
|
||||
provider_metadata={},
|
||||
provider_metadata={"call_id": call_id},
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,10 +138,8 @@ class VonageProvider(TelephonyProvider):
|
|||
call_id=response_data["uuid"],
|
||||
status=response_data.get("status", "started"),
|
||||
provider_metadata={
|
||||
"call_uuid": response_data[
|
||||
"uuid"
|
||||
] # Vonage needs UUID persisted for WebSocket
|
||||
},
|
||||
"call_uuid": response_data["uuid"]
|
||||
}, # Vonage needs UUID persisted for WebSocket
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,61 @@ from api.services.telephony.factory import get_telephony_provider
|
|||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
||||
# Set the run_id in context variable for consistent logging format
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
|
@ -26,62 +81,20 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
# Calculate cost breakdown
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info)
|
||||
|
||||
# Fetch telephony call cost for both Twilio and Vonage
|
||||
telephony_cost_usd = 0.0
|
||||
if (
|
||||
workflow_run.mode
|
||||
in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
and workflow_run.cost_info
|
||||
):
|
||||
# Get the call ID (provider-agnostic approach with backward compatibility)
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
|
||||
# Fallback to legacy provider-specific fields if needed
|
||||
if not call_id:
|
||||
if workflow_run.mode == WorkflowRunMode.TWILIO.value:
|
||||
call_id = workflow_run.cost_info.get("twilio_call_sid")
|
||||
elif workflow_run.mode == WorkflowRunMode.VONAGE.value:
|
||||
call_id = workflow_run.cost_info.get("vonage_call_uuid")
|
||||
|
||||
# Provider name is derived from workflow run mode
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
if call_id:
|
||||
try:
|
||||
# Get workflow to access organization_id
|
||||
workflow = await db_client.get_workflow_by_id(
|
||||
workflow_run.workflow_id
|
||||
)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
# Use telephony provider abstraction
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") != "error":
|
||||
telephony_cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = (
|
||||
telephony_cost_usd # Keep backward compatibility
|
||||
)
|
||||
|
||||
# Add telephony cost to the total
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${telephony_cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
# Fetch telephony call cost
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
# Store cost information back to the workflow run
|
||||
# We'll add the cost breakdown to the workflow run
|
||||
|
|
@ -106,6 +119,7 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
charge_usd = duration_seconds * org.price_per_second_usd
|
||||
|
||||
cost_info = {
|
||||
**workflow_run.cost_info,
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(cost_breakdown["total"]),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
|
|
@ -118,42 +132,19 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Preserve call ID (provider-agnostic with backward compatibility)
|
||||
if workflow_run.cost_info:
|
||||
# Preserve generic call_id if it exists
|
||||
if "call_id" in workflow_run.cost_info:
|
||||
cost_info["call_id"] = workflow_run.cost_info["call_id"]
|
||||
# Also preserve legacy fields for backward compatibility
|
||||
elif "twilio_call_sid" in workflow_run.cost_info:
|
||||
cost_info["twilio_call_sid"] = workflow_run.cost_info["twilio_call_sid"]
|
||||
elif "vonage_call_uuid" in workflow_run.cost_info:
|
||||
cost_info["vonage_call_uuid"] = workflow_run.cost_info[
|
||||
"vonage_call_uuid"
|
||||
]
|
||||
|
||||
# Update workflow run with cost information
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
# Update organization usage if applicable
|
||||
if org:
|
||||
org_id = org.id
|
||||
try:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
# Pass USD amount if organization has pricing
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
await _update_organization_usage(
|
||||
org, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org_id}: {e}"
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
# Don't fail the whole task if usage update fails
|
||||
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 1bd0ea6b44518040c87074c6086e4fedc0864ca9
|
||||
Subproject commit e180bd3c2abc3cebbdf5e2d7955d9928cca5d219
|
||||
Loading…
Add table
Add a link
Reference in a new issue