diff --git a/api/routes/webrtc_signaling.py b/api/routes/webrtc_signaling.py index 522d862..45bff78 100644 --- a/api/routes/webrtc_signaling.py +++ b/api/routes/webrtc_signaling.py @@ -1,8 +1,18 @@ -"""WebSocket-based WebRTC signaling endpoint with ICE trickling support.""" +"""WebSocket-based WebRTC signaling endpoint with ICE trickling support. + +This implementation uses WebSocket-based signaling instead of HTTP PATCH for ICE candidates, +which is suitable for multi-worker FastAPI deployments where local _pcs_map cannot be shared. + +Uses the SmallWebRTC API contract: +- SmallWebRTCConnection for peer connection management +- candidate_from_sdp() for parsing ICE candidates +- add_ice_candidate() for trickling support +""" import asyncio from typing import Dict +from aiortc.sdp import candidate_from_sdp from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect from loguru import logger @@ -51,14 +61,14 @@ class SignalingManager: # Cleanup self._connections.pop(connection_id, None) - # Clean up peer connections for this workflow run - # Since we store connections by client pc_id directly, - # we need to iterate through all connections and check which ones to clean up - # We could track connections per workflow_run, but for now let's clean all - # connections that were using this WebSocket + # Clean up all peer connections for this workflow run + # Note: In a WebSocket-based signaling approach (vs HTTP PATCH), + # we maintain our own connection map instead of relying on + # SmallWebRTCRequestHandler's _pcs_map. This is suitable for + # multi-worker FastAPI deployments where state cannot be shared. for pc_id in list(self._peer_connections.keys()): pc = self._peer_connections.pop(pc_id, None) - if pc and pc._signaling_ws == websocket: + if pc: await pc.disconnect() logger.debug(f"Disconnected peer connection: {pc_id}") @@ -113,14 +123,13 @@ class SignalingManager: } ) else: - # Create new connection with WebSocket support - # Pass client's pc_id as the name parameter - pc = SmallWebRTCConnection( - ice_servers, signaling_ws=ws, enable_trickling=True, name=pc_id - ) + # Create new connection using correct SmallWebRTC API + pc = SmallWebRTCConnection(ice_servers=ice_servers, connection_timeout_secs=20) + # Set the pc_id before initialization so it's available in get_answer() + pc._pc_id = pc_id - # Initialize with trickling support - answer = await pc.initialize_with_trickling(sdp=sdp, type=type_) + # Initialize connection with offer + await pc.initialize(sdp=sdp, type=type_) # Store peer connection using client's pc_id self._peer_connections[pc_id] = pc @@ -138,14 +147,17 @@ class SignalingManager: ) ) - # Send answer immediately (without ICE candidates) + # Get answer after initialization + answer = pc.get_answer() + + # Send answer immediately (ICE candidates will be sent separately via trickling) await ws.send_json( { "type": "answer", "payload": { - "sdp": answer.sdp, - "type": answer.type, - "pc_id": pc.pc_id, # This will be the same as pc_id we passed + "sdp": answer["sdp"], + "type": answer["type"], + "pc_id": answer["pc_id"], }, } ) @@ -153,9 +165,14 @@ class SignalingManager: async def _handle_ice_candidate( self, ws: WebSocket, payload: dict, workflow_run_id: int ): - """Handle incoming ICE candidate from client.""" + """Handle incoming ICE candidate from client. + + Uses SmallWebRTC's native ICE trickling support via add_ice_candidate(). + Candidates are parsed using aiortc's candidate_from_sdp() for proper formatting, + consistent with SmallWebRTCRequestHandler.handle_patch_request(). + """ pc_id = payload.get("pc_id") - candidate = payload.get("candidate") + candidate_data = payload.get("candidate") if not pc_id: logger.warning("Received ICE candidate without pc_id") @@ -166,8 +183,13 @@ class SignalingManager: logger.warning(f"No peer connection found for pc_id: {pc_id}") return - if candidate: + if candidate_data: try: + # Parse the ICE candidate using aiortc's parser (same as SmallWebRTCRequestHandler) + candidate = candidate_from_sdp(candidate_data["candidate"]) + candidate.sdpMid = candidate_data.get("sdpMid") + candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex") + await pc.add_ice_candidate(candidate) logger.debug(f"Added ICE candidate for pc_id: {pc_id}") except Exception as e: diff --git a/api/services/telephony/stasis_rtp_transport.py b/api/services/telephony/stasis_rtp_transport.py index 110d508..cd3716d 100644 --- a/api/services/telephony/stasis_rtp_transport.py +++ b/api/services/telephony/stasis_rtp_transport.py @@ -23,8 +23,7 @@ from pipecat.frames.frames import ( from pipecat.serializers.base_serializer import FrameSerializer from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import ( - BaseOutputTransport, - TransportClientNotConnectedException, + BaseOutputTransport ) from pipecat.transports.base_transport import BaseTransport, TransportParams @@ -204,7 +203,7 @@ class StasisRTPOutputTransport(BaseOutputTransport): async def write_audio_frame(self, frame: OutputAudioRawFrame): """Write audio frame to RTP stream.""" if self._client.is_closing: - raise TransportClientNotConnectedException() + return False if not self._client.is_connected: # If not connected yet, just simulate playback delay. diff --git a/pipecat b/pipecat index 278248a..3b190ce 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit 278248a40cf7a8cb11d32534016ffec099408f8c +Subproject commit 3b190cebd80e9324cccf6be3cd31f1abd17897ca