"""WebSocket-based WebRTC signaling endpoint with ICE trickling support. This implementation uses WebSocket-based signaling instead of HTTP PATCH for ICE candidates, which is suitable for multi-worker FastAPI deployments where local _pcs_map cannot be shared. Uses the SmallWebRTC API contract: - SmallWebRTCConnection for peer connection management - candidate_from_sdp() for parsing ICE candidates - add_ice_candidate() for trickling support TURN Authentication: - Uses time-limited credentials (TURN REST API) when TURN_SECRET is configured - Credentials are generated per-connection using HMAC-SHA1 - Falls back to static credentials if TURN_SECRET is not set (legacy mode) """ import asyncio import ipaddress import os from datetime import UTC, datetime from typing import Dict, List, Optional from aiortc import RTCIceServer from aiortc.sdp import candidate_from_sdp from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect from loguru import logger from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection from pipecat.utils.run_context import set_current_org_id, set_current_run_id from starlette.websockets import WebSocketState from api.constants import ENVIRONMENT, FORCE_TURN_RELAY from api.db import db_client from api.db.models import UserModel from api.enums import Environment from api.routes.turn_credentials import ( TURN_HOST, TURN_PORT, TURN_SECRET, generate_turn_credentials, ) from api.services.auth.depends import get_user_ws from api.services.pipecat.run_pipeline import run_pipeline_smallwebrtc from api.services.pipecat.ws_sender_registry import ( register_ws_sender, unregister_ws_sender, ) from api.services.quota_service import check_dograh_quota router = APIRouter(prefix="/ws") def is_private_ip_candidate(candidate_str: str) -> bool: """Check if ICE candidate contains a private IP address or CGNAT IP Address. Parses the candidate string to extract the IP address and checks if it's private. This is used to filter out host candidates with private IPs in non-local environments, preventing TURN relay errors when coturn blocks private IP ranges or CGNAT IP Addresses. Args: candidate_str: ICE candidate string, e.g., "candidate:123 1 udp 2122260223 192.168.50.24 63603 typ host ..." Returns: True if the candidate contains a private IP, False otherwise. """ try: parts = candidate_str.split() # Find "typ" and get the IP which is 2 positions before it if "typ" in parts: typ_index = parts.index("typ") ip_str = parts[typ_index - 2] ip = ipaddress.ip_address(ip_str) is_cgnat = ip in ipaddress.ip_network("100.64.0.0/10") return ip.is_private or is_cgnat except (ValueError, IndexError): pass return False def filter_outbound_sdp(sdp: str) -> str: """Strip ICE candidates from an outbound answer SDP based on env config. Two filters apply: 1. In non-LOCAL environments, drop host candidates with private/CGNAT IPs. aiortc gathers host candidates from every interface on the box, including Docker bridges (172.17.0.1, 172.18.0.1). Advertising those to the browser causes coturn "peer IP X denied" errors when the browser asks TURN to permit them. 2. When FORCE_TURN_RELAY is set, drop every non-relay candidate so the only path the browser can use is via TURN. Lets you verify TURN connectivity end-to-end — if TURN is broken, the call simply fails. """ if ENVIRONMENT == Environment.LOCAL.value and not FORCE_TURN_RELAY: return sdp lines = sdp.split("\r\n") filtered: List[str] = [] dropped_non_relay = 0 kept_relay = 0 for line in lines: if line.startswith("a=candidate:"): candidate_str = line[2:] if FORCE_TURN_RELAY and " typ relay" not in candidate_str: dropped_non_relay += 1 continue if ENVIRONMENT != Environment.LOCAL.value and is_private_ip_candidate( candidate_str ): continue if FORCE_TURN_RELAY: kept_relay += 1 filtered.append(line) if FORCE_TURN_RELAY: if kept_relay == 0: logger.warning( "FORCE_TURN_RELAY is on but the answer SDP has no relay candidates " f"(dropped {dropped_non_relay} non-relay). TURN may be unreachable; " "the connection will fail." ) else: logger.info( f"FORCE_TURN_RELAY: kept {kept_relay} relay candidates, " f"dropped {dropped_non_relay} non-relay" ) return "\r\n".join(filtered) def get_ice_servers(user_id: Optional[str] = None) -> List[RTCIceServer]: """Build ICE servers configuration including TURN if configured. Args: user_id: Optional user ID for generating time-limited TURN credentials. If provided and TURN_SECRET is configured, uses TURN REST API. Returns: List of RTCIceServer configurations for WebRTC peer connection. """ servers: List[RTCIceServer] = [RTCIceServer(urls="stun:stun.l.google.com:19302")] # Check if TURN is configured if not TURN_HOST: return servers # Use time-limited credentials if TURN_SECRET is configured (recommended) if TURN_SECRET and user_id: try: credentials = generate_turn_credentials(user_id) servers.append( RTCIceServer( urls=credentials["uris"], username=credentials["username"], credential=credentials["password"], ) ) logger.info( f"TURN server configured with time-limited credentials, TTL: {credentials['ttl']}s" ) return servers except Exception as e: logger.error(f"Failed to generate TURN credentials: {e}") # Fallback to static credentials (legacy mode - not recommended for production) turn_username = os.getenv("TURN_USERNAME") turn_password = os.getenv("TURN_PASSWORD") if turn_username and turn_password: servers.append( RTCIceServer( urls=[ f"turn:{TURN_HOST}:{TURN_PORT}", f"turn:{TURN_HOST}:{TURN_PORT}?transport=tcp", ], username=turn_username, credential=turn_password, ) ) logger.warning( f"TURN server configured with static credentials (consider using TURN_SECRET for time-limited auth)" ) return servers class SignalingManager: """Manages WebSocket connections and WebRTC peer connections.""" def __init__(self): self._connections: Dict[str, WebSocket] = {} self._peer_connections: Dict[str, SmallWebRTCConnection] = {} async def handle_websocket( self, websocket: WebSocket, workflow_id: int, workflow_run_id: int, user: UserModel, ): """Handle WebSocket connection for signaling.""" await websocket.accept() connection_id = f"{workflow_id}:{workflow_run_id}:{user.id}" self._connections[connection_id] = websocket try: while True: message = await websocket.receive_json() await self._handle_message( websocket, message, workflow_id, workflow_run_id, user ) except WebSocketDisconnect: logger.info(f"WebSocket disconnected for {connection_id}") except Exception as e: logger.error(f"WebSocket error for {connection_id}: {e}") finally: # Cleanup self._connections.pop(connection_id, None) # Unregister WebSocket sender for real-time feedback unregister_ws_sender(workflow_run_id) # Clean up all peer connections for this workflow run # Note: In a WebSocket-based signaling approach (vs HTTP PATCH), # we maintain our own connection map instead of relying on # SmallWebRTCRequestHandler's _pcs_map. This is suitable for # multi-worker FastAPI deployments where state cannot be shared. for pc_id in list(self._peer_connections.keys()): pc = self._peer_connections.pop(pc_id, None) if pc: await pc.disconnect() logger.debug(f"Disconnected peer connection: {pc_id}") async def _handle_message( self, ws: WebSocket, message: dict, workflow_id: int, workflow_run_id: int, user: UserModel, ): """Handle incoming WebSocket messages.""" msg_type = message.get("type") payload = message.get("payload", {}) if msg_type == "offer": await self._handle_offer(ws, payload, workflow_id, workflow_run_id, user) elif msg_type == "ice-candidate": await self._handle_ice_candidate(ws, payload, workflow_run_id) elif msg_type == "renegotiate": await self._handle_renegotiation(ws, payload, workflow_id, workflow_run_id) async def _handle_offer( self, ws: WebSocket, payload: dict, workflow_id: int, workflow_run_id: int, user: UserModel, ): """Handle offer message and create answer with ICE trickling.""" pc_id = payload.get("pc_id") sdp = payload.get("sdp") type_ = payload.get("type") call_context_vars = payload.get("call_context_vars", {}) # Set run context for logging and tracing. org_id must be set before # pc.initialize() so that aiortc's internal tasks inherit it. set_current_run_id(workflow_run_id) org_id = await db_client.get_workflow_organization_id(workflow_id) if org_id: set_current_org_id(org_id) # Check Dograh quota before initiating the call (apply per-workflow # model_overrides so we evaluate the keys this workflow will use). quota_result = await check_dograh_quota(user, workflow_id=workflow_id) if not quota_result.has_quota: # Send error response for quota issues await ws.send_json( { "type": "error", "payload": { "error_type": quota_result.error_code, "message": quota_result.error_message, }, } ) return if pc_id and pc_id in self._peer_connections: # Reuse existing connection logger.info(f"Reusing existing connection for pc_id: {pc_id}") pc = self._peer_connections[pc_id] await pc.renegotiate(sdp=sdp, type=type_, restart_pc=False) # Send updated answer answer = pc.get_answer() await ws.send_json( { "type": "answer", "payload": { "sdp": filter_outbound_sdp(answer["sdp"]), "type": "answer", "pc_id": pc_id, }, } ) else: # Create new connection using correct SmallWebRTC API # Generate ICE servers with time-limited TURN credentials for this user user_ice_servers = get_ice_servers(user_id=str(user.id)) pc = SmallWebRTCConnection( ice_servers=user_ice_servers, connection_timeout_secs=60 ) # Set the pc_id before initialization so it's available in get_answer() pc._pc_id = pc_id # Initialize connection with offer await pc.initialize(sdp=sdp, type=type_) # Store peer connection using client's pc_id self._peer_connections[pc_id] = pc # Register WebSocket sender for real-time feedback async def ws_sender(message: dict): if ws.application_state == WebSocketState.CONNECTED: await ws.send_json(message) register_ws_sender(workflow_run_id, ws_sender) # Setup closed handler @pc.event_handler("closed") async def handle_disconnected(webrtc_connection: SmallWebRTCConnection): logger.info(f"PeerConnection closed: {webrtc_connection.pc_id}") self._peer_connections.pop(webrtc_connection.pc_id, None) # Start pipeline in background asyncio.create_task( run_pipeline_smallwebrtc( pc, workflow_id, workflow_run_id, user.id, call_context_vars, user_provider_id=str(user.provider_id), ) ) # Get answer after initialization answer = pc.get_answer() # Send answer immediately (ICE candidates will be sent separately via trickling) await ws.send_json( { "type": "answer", "payload": { "sdp": filter_outbound_sdp(answer["sdp"]), "type": answer["type"], "pc_id": answer["pc_id"], }, } ) async def _handle_ice_candidate( self, ws: WebSocket, payload: dict, workflow_run_id: int ): """Handle incoming ICE candidate from client. Uses SmallWebRTC's native ICE trickling support via add_ice_candidate(). Candidates are parsed using aiortc's candidate_from_sdp() for proper formatting, consistent with SmallWebRTCRequestHandler.handle_patch_request(). In non-local environments, private IP candidates are filtered out to prevent TURN relay errors when coturn blocks private IP ranges (denied-peer-ip). """ pc_id = payload.get("pc_id") candidate_data = payload.get("candidate") if not pc_id: logger.warning("Received ICE candidate without pc_id") return pc = self._peer_connections.get(pc_id) if not pc: logger.warning(f"No peer connection found for pc_id: {pc_id}") return if candidate_data: candidate_str = candidate_data.get("candidate", "") # Filter out private IP candidates in non-local environments # This prevents TURN relay errors when coturn blocks private IP ranges if ENVIRONMENT != Environment.LOCAL.value and is_private_ip_candidate( candidate_str ): logger.debug( f"Skipping private IP candidate in {ENVIRONMENT}: {candidate_str[:50]}..." ) return try: # Parse the ICE candidate using aiortc's parser (same as SmallWebRTCRequestHandler) candidate = candidate_from_sdp(candidate_str) candidate.sdpMid = candidate_data.get("sdpMid") candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex") await pc.add_ice_candidate(candidate) logger.debug(f"Added ICE candidate for pc_id: {pc_id}") except Exception as e: logger.error(f"Failed to add ICE candidate: {e}") else: logger.debug(f"End of ICE candidates for pc_id: {pc_id}") async def _handle_renegotiation( self, ws: WebSocket, payload: dict, workflow_id: int, workflow_run_id: int ): """Handle renegotiation request.""" pc_id = payload.get("pc_id") sdp = payload.get("sdp") type_ = payload.get("type") restart_pc = payload.get("restart_pc", False) if not pc_id or pc_id not in self._peer_connections: await ws.send_json( {"type": "error", "payload": {"message": "Peer connection not found"}} ) return pc = self._peer_connections[pc_id] await pc.renegotiate(sdp=sdp, type=type_, restart_pc=restart_pc) # Send updated answer answer = pc.get_answer() await ws.send_json( { "type": "answer", "payload": { "sdp": filter_outbound_sdp(answer["sdp"]), "type": "answer", "pc_id": pc_id, # Use the client's pc_id }, } ) # Create singleton instance signaling_manager = SignalingManager() @router.websocket("/signaling/{workflow_id}/{workflow_run_id}") async def signaling_websocket( websocket: WebSocket, workflow_id: int, workflow_run_id: int, user: UserModel = Depends(get_user_ws), ): """WebSocket endpoint for WebRTC signaling with ICE trickling.""" workflow_run = await db_client.get_workflow_run(workflow_run_id, user.id) if not workflow_run: logger.warning(f"workflow run {workflow_run_id} not found for user {user.id}") raise HTTPException(status_code=400, detail="Bad workflow_run_id") await signaling_manager.handle_websocket( websocket, workflow_id, workflow_run_id, user ) @router.websocket("/public/signaling/{session_token}") async def public_signaling_websocket( websocket: WebSocket, session_token: str, ): """Public WebSocket endpoint for WebRTC signaling with embed tokens. This endpoint: 1. Validates the session token from embed initialization 2. Retrieves the associated workflow run 3. Handles WebRTC signaling without requiring authentication """ # Validate session token embed_session = await db_client.get_embed_session_by_token(session_token) if not embed_session: await websocket.close(code=1008, reason="Invalid session token") return # Check if session is expired if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC): await websocket.close(code=1008, reason="Session expired") return # Get the embed token for user information embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id) if not embed_token: await websocket.close(code=1008, reason="Invalid embed token") return # Create a minimal user object for compatibility with signaling manager # Use the embed token creator as the user user = await db_client.get_user_by_id(embed_token.created_by) if not user: await websocket.close(code=1008, reason="Invalid user") return # Handle the WebSocket connection using the existing signaling manager await signaling_manager.handle_websocket( websocket, embed_token.workflow_id, embed_session.workflow_run_id, user )