Merge pull request #8 from dograh-hq/feat/tricke-ice

feat: Trickle ice candidates for faster WebRTC connection
This commit is contained in:
Sabiha Khan 2025-09-24 15:00:17 +05:30 committed by GitHub
commit 034c551931
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 706 additions and 5 deletions

View file

@ -12,6 +12,7 @@ from api.routes.service_keys import router as service_keys_router
from api.routes.superuser import router as superuser_router
from api.routes.twilio import router as twilio_router
from api.routes.user import router as user_router
from api.routes.webrtc_signaling import router as webrtc_signaling_router
from api.routes.workflow import router as workflow_router
router = APIRouter(
@ -31,6 +32,7 @@ router.include_router(service_keys_router)
router.include_router(looptalk_router)
router.include_router(organization_usage_router)
router.include_router(reports_router)
router.include_router(webrtc_signaling_router)
@router.get("/health")

View file

@ -0,0 +1,224 @@
"""WebSocket-based WebRTC signaling endpoint with ICE trickling support."""
import asyncio
from typing import Dict
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
from pipecat.utils.context import set_current_run_id
from api.db.models import UserModel
from api.services.auth.depends import get_user_ws
from api.services.pipecat.run_pipeline import run_pipeline_smallwebrtc
router = APIRouter(prefix="/ws")
# ICE servers configuration
ice_servers = ["stun:stun.l.google.com:19302"]
class SignalingManager:
"""Manages WebSocket connections and WebRTC peer connections."""
def __init__(self):
self._connections: Dict[str, WebSocket] = {}
self._peer_connections: Dict[str, SmallWebRTCConnection] = {}
async def handle_websocket(
self,
websocket: WebSocket,
workflow_id: int,
workflow_run_id: int,
user: UserModel,
):
"""Handle WebSocket connection for signaling."""
await websocket.accept()
connection_id = f"{workflow_id}:{workflow_run_id}:{user.id}"
self._connections[connection_id] = websocket
try:
while True:
message = await websocket.receive_json()
await self._handle_message(
websocket, message, workflow_id, workflow_run_id, user
)
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for {connection_id}")
except Exception as e:
logger.error(f"WebSocket error for {connection_id}: {e}")
finally:
# Cleanup
self._connections.pop(connection_id, None)
# Clean up peer connections for this workflow run
# Since we store connections by client pc_id directly,
# we need to iterate through all connections and check which ones to clean up
# We could track connections per workflow_run, but for now let's clean all
# connections that were using this WebSocket
for pc_id in list(self._peer_connections.keys()):
pc = self._peer_connections.pop(pc_id, None)
if pc and pc._signaling_ws == websocket:
await pc.disconnect()
logger.debug(f"Disconnected peer connection: {pc_id}")
async def _handle_message(
self,
ws: WebSocket,
message: dict,
workflow_id: int,
workflow_run_id: int,
user: UserModel,
):
"""Handle incoming WebSocket messages."""
msg_type = message.get("type")
payload = message.get("payload", {})
if msg_type == "offer":
await self._handle_offer(ws, payload, workflow_id, workflow_run_id, user)
elif msg_type == "ice-candidate":
await self._handle_ice_candidate(ws, payload, workflow_run_id)
elif msg_type == "renegotiate":
await self._handle_renegotiation(ws, payload, workflow_id, workflow_run_id)
async def _handle_offer(
self,
ws: WebSocket,
payload: dict,
workflow_id: int,
workflow_run_id: int,
user: UserModel,
):
"""Handle offer message and create answer with ICE trickling."""
pc_id = payload.get("pc_id")
sdp = payload.get("sdp")
type_ = payload.get("type")
call_context_vars = payload.get("call_context_vars", {})
# Set run context for logging
set_current_run_id(workflow_run_id)
if pc_id and pc_id in self._peer_connections:
# Reuse existing connection
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
pc = self._peer_connections[pc_id]
await pc.renegotiate(sdp=sdp, type=type_, restart_pc=False)
# Send updated answer
answer = pc.get_answer()
await ws.send_json(
{
"type": "answer",
"payload": {"sdp": answer["sdp"], "type": "answer", "pc_id": pc_id},
}
)
else:
# Create new connection with WebSocket support
# Pass client's pc_id as the name parameter
pc = SmallWebRTCConnection(
ice_servers, signaling_ws=ws, enable_trickling=True, name=pc_id
)
# Initialize with trickling support
answer = await pc.initialize_with_trickling(sdp=sdp, type=type_)
# Store peer connection using client's pc_id
self._peer_connections[pc_id] = pc
# Setup closed handler
@pc.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
logger.info(f"PeerConnection closed: {webrtc_connection.pc_id}")
self._peer_connections.pop(webrtc_connection.pc_id, None)
# Start pipeline in background
asyncio.create_task(
run_pipeline_smallwebrtc(
pc, workflow_id, workflow_run_id, user.id, call_context_vars
)
)
# Send answer immediately (without ICE candidates)
await ws.send_json(
{
"type": "answer",
"payload": {
"sdp": answer.sdp,
"type": answer.type,
"pc_id": pc.pc_id, # This will be the same as pc_id we passed
},
}
)
async def _handle_ice_candidate(
self, ws: WebSocket, payload: dict, workflow_run_id: int
):
"""Handle incoming ICE candidate from client."""
pc_id = payload.get("pc_id")
candidate = payload.get("candidate")
if not pc_id:
logger.warning("Received ICE candidate without pc_id")
return
pc = self._peer_connections.get(pc_id)
if not pc:
logger.warning(f"No peer connection found for pc_id: {pc_id}")
return
if candidate:
try:
await pc.add_ice_candidate(candidate)
logger.debug(f"Added ICE candidate for pc_id: {pc_id}")
except Exception as e:
logger.error(f"Failed to add ICE candidate: {e}")
else:
logger.debug(f"End of ICE candidates for pc_id: {pc_id}")
async def _handle_renegotiation(
self, ws: WebSocket, payload: dict, workflow_id: int, workflow_run_id: int
):
"""Handle renegotiation request."""
pc_id = payload.get("pc_id")
sdp = payload.get("sdp")
type_ = payload.get("type")
restart_pc = payload.get("restart_pc", False)
if not pc_id or pc_id not in self._peer_connections:
await ws.send_json(
{"type": "error", "payload": {"message": "Peer connection not found"}}
)
return
pc = self._peer_connections[pc_id]
await pc.renegotiate(sdp=sdp, type=type_, restart_pc=restart_pc)
# Send updated answer
answer = pc.get_answer()
await ws.send_json(
{
"type": "answer",
"payload": {
"sdp": answer["sdp"],
"type": "answer",
"pc_id": pc_id, # Use the client's pc_id
},
}
)
# Create singleton instance
signaling_manager = SignalingManager()
@router.websocket("/signaling/{workflow_id}/{workflow_run_id}")
async def signaling_websocket(
websocket: WebSocket,
workflow_id: int,
workflow_run_id: int,
user: UserModel = Depends(get_user_ws),
):
"""WebSocket endpoint for WebRTC signaling with ICE trickling."""
await signaling_manager.handle_websocket(
websocket, workflow_id, workflow_run_id, user
)

View file

@ -1,7 +1,7 @@
from typing import Annotated, Optional
import httpx
from fastapi import Header, HTTPException
from fastapi import Header, HTTPException, Query, WebSocket
from loguru import logger
from pydantic import ValidationError
@ -328,3 +328,26 @@ async def get_superuser(
)
return user
async def get_user_ws(
websocket: WebSocket,
token: str = Query(None),
) -> UserModel:
"""
WebSocket authentication dependency.
Uses token from query parameters for authentication.
"""
if not token:
await websocket.close(code=1008, reason="Missing authentication token")
raise HTTPException(status_code=401, detail="Missing authentication token")
# Use the same logic as get_user but with token from query
authorization = f"Bearer {token}"
try:
user = await get_user(authorization)
return user
except HTTPException as e:
await websocket.close(code=1008, reason=e.detail)
raise

View file

@ -11,7 +11,7 @@ import {
ConnectionStatus,
WorkflowConfigErrorDialog
} from "./components";
import { useWebRTC } from "./hooks";
import { useWebSocketRTC } from "./hooks";
const BrowserCall = ({ workflowId, workflowRunId, accessToken, initialContextVariables }: {
workflowId: number,
@ -40,7 +40,7 @@ const BrowserCall = ({ workflowId, workflowRunId, accessToken, initialContextVar
start,
stop,
isStarting
} = useWebRTC({ workflowId, workflowRunId, accessToken, initialContextVariables });
} = useWebSocketRTC({ workflowId, workflowRunId, accessToken, initialContextVariables });
// Poll for recording availability after call ends
useEffect(() => {

View file

@ -1,2 +1,3 @@
export * from './useDeviceInputs';
export * from './useWebRTC';
export * from './useWebSocketRTC';

View file

@ -3,7 +3,6 @@ import { useRef, useState } from "react";
import { offerApiV1PipecatRtcOfferPost, validateUserConfigurationsApiV1UserConfigurationsUserValidateGet, validateWorkflowApiV1WorkflowWorkflowIdValidatePost } from "@/client/sdk.gen";
import { WorkflowValidationError } from "@/components/flow/types";
import logger from '@/lib/logger';
import { getRandomId } from "@/lib/utils";
import { sdpFilterCodec } from "../utils";
import { useDeviceInputs } from "./useDeviceInputs";
@ -42,7 +41,19 @@ export const useWebRTC = ({ workflowId, workflowRunId, accessToken, initialConte
const audioRef = useRef<HTMLAudioElement>(null);
const pcRef = useRef<RTCPeerConnection | null>(null);
const timeStartRef = useRef<number | null>(null);
const pc_id = 'PC-' + getRandomId().toString();
// Generate a cryptographically secure unique ID
const generateSecureId = () => {
// Use Web Crypto API to generate random bytes
const array = new Uint8Array(16);
crypto.getRandomValues(array);
// Convert to hex string
return 'PC-' + Array.from(array)
.map(b => b.toString(16).padStart(2, '0'))
.join('');
};
const pc_id = generateSecureId();
const createPeerConnection = () => {
const config: RTCConfiguration = {

View file

@ -0,0 +1,440 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { client } from "@/client/client.gen";
import { validateUserConfigurationsApiV1UserConfigurationsUserValidateGet, validateWorkflowApiV1WorkflowWorkflowIdValidatePost } from "@/client/sdk.gen";
import { WorkflowValidationError } from "@/components/flow/types";
import logger from '@/lib/logger';
import { sdpFilterCodec } from "../utils";
import { useDeviceInputs } from "./useDeviceInputs";
interface UseWebSocketRTCProps {
workflowId: number;
workflowRunId: number;
accessToken: string | null;
initialContextVariables?: Record<string, string> | null;
}
export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initialContextVariables }: UseWebSocketRTCProps) => {
const [connectionStatus, setConnectionStatus] = useState<'idle' | 'connecting' | 'connected' | 'failed'>('idle');
const [connectionActive, setConnectionActive] = useState(false);
const [isCompleted, setIsCompleted] = useState(false);
const [apiKeyModalOpen, setApiKeyModalOpen] = useState(false);
const [apiKeyError, setApiKeyError] = useState<string | null>(null);
const [workflowConfigModalOpen, setWorkflowConfigModalOpen] = useState(false);
const [workflowConfigError, setWorkflowConfigError] = useState<string | null>(null);
const [isStarting, setIsStarting] = useState(false);
const initialContext = initialContextVariables || {};
const {
audioInputs,
selectedAudioInput,
setSelectedAudioInput,
permissionError,
setPermissionError
} = useDeviceInputs();
const useStun = true;
const useAudio = true;
const audioCodec = 'default';
const audioRef = useRef<HTMLAudioElement>(null);
const pcRef = useRef<RTCPeerConnection | null>(null);
const wsRef = useRef<WebSocket | null>(null);
const timeStartRef = useRef<number | null>(null);
// Generate a cryptographically secure unique ID
const generateSecureId = () => {
// Use Web Crypto API to generate random bytes
const array = new Uint8Array(16);
crypto.getRandomValues(array);
// Convert to hex string
return 'PC-' + Array.from(array)
.map(b => b.toString(16).padStart(2, '0'))
.join('');
};
const pc_id = useRef(generateSecureId());
// Get WebSocket URL from client configuration
const getWebSocketUrl = useCallback(() => {
// Get base URL from client configuration
const baseUrl = client.getConfig().baseUrl || 'http://127.0.0.1:8000';
// Convert HTTP to WS protocol
const wsUrl = baseUrl.replace(/^http/, 'ws');
return `${wsUrl}/api/v1/ws/signaling/${workflowId}/${workflowRunId}?token=${accessToken}`;
}, [workflowId, workflowRunId, accessToken]);
const createPeerConnection = () => {
const config: RTCConfiguration = {
iceServers: useStun ? [{ urls: ['stun:stun.l.google.com:19302'] }] : []
};
const pc = new RTCPeerConnection(config);
// Set up ICE candidate trickling
pc.addEventListener('icecandidate', (event) => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
const message = {
type: 'ice-candidate',
payload: {
candidate: event.candidate ? {
candidate: event.candidate.candidate,
sdpMid: event.candidate.sdpMid,
sdpMLineIndex: event.candidate.sdpMLineIndex
} : null,
pc_id: pc_id.current
}
};
wsRef.current.send(JSON.stringify(message));
if (event.candidate) {
logger.debug(`Sending ICE candidate: ${event.candidate.candidate}`);
} else {
logger.debug('Sending end-of-candidates signal');
}
}
});
pc.addEventListener('iceconnectionstatechange', () => {
logger.info(`ICE connection state changed: ${pc.iceConnectionState}`);
if (pc.iceConnectionState === 'connected' || pc.iceConnectionState === 'completed') {
setConnectionStatus('connected');
} else if (pc.iceConnectionState === 'failed') {
setConnectionStatus('failed');
} else if (pc.iceConnectionState === 'disconnected') {
// Server-initiated disconnect - clean up gracefully
logger.info('Server initiated disconnect - cleaning up connection');
// Close WebSocket if still open
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
}
// Mark as completed to trigger recording check
setConnectionActive(false);
setIsCompleted(true);
setConnectionStatus('idle');
// Clean up peer connection
if (pc.getTransceivers) {
pc.getTransceivers().forEach((transceiver) => {
if (transceiver.stop) {
transceiver.stop();
}
});
}
pc.getSenders().forEach((sender) => {
if (sender.track) {
sender.track.stop();
}
});
}
});
pc.addEventListener('track', (evt) => {
if (evt.track.kind === 'audio' && audioRef.current) {
audioRef.current.srcObject = evt.streams[0];
}
});
pcRef.current = pc;
return pc;
};
const connectWebSocket = useCallback(() => {
return new Promise<void>((resolve, reject) => {
const wsUrl = getWebSocketUrl();
logger.info(`Connecting to WebSocket: ${wsUrl}`);
const ws = new WebSocket(wsUrl);
ws.onopen = () => {
logger.info('WebSocket connected');
wsRef.current = ws;
resolve();
};
ws.onerror = (error) => {
logger.error('WebSocket error:', error);
reject(error);
};
ws.onclose = () => {
logger.info('WebSocket closed');
wsRef.current = null;
// Don't set failed status if already completed (graceful disconnect)
if (connectionActive && !isCompleted) {
setConnectionStatus('failed');
}
};
ws.onmessage = async (event) => {
try {
const message = JSON.parse(event.data);
switch (message.type) {
case 'answer':
// Set remote description immediately (may have no candidates)
const answer = message.payload;
logger.debug('Received answer from server');
if (pcRef.current) {
await pcRef.current.setRemoteDescription({
type: 'answer',
sdp: answer.sdp
});
setConnectionActive(true);
logger.info('Remote description set');
}
break;
case 'ice-candidate':
// Add ICE candidate from server
const candidate = message.payload.candidate;
if (candidate && pcRef.current) {
try {
await pcRef.current.addIceCandidate({
candidate: candidate.candidate,
sdpMid: candidate.sdpMid,
sdpMLineIndex: candidate.sdpMLineIndex
});
logger.debug(`Added remote ICE candidate: ${candidate.candidate}`);
} catch (e) {
logger.error('Failed to add ICE candidate:', e);
}
} else if (!candidate) {
logger.debug('Received end-of-candidates signal from server');
}
break;
case 'error':
logger.error('Server error:', message.payload);
break;
default:
logger.warn('Unknown message type:', message.type);
}
} catch (e) {
logger.error('Failed to handle WebSocket message:', e);
}
};
});
}, [getWebSocketUrl, connectionActive, isCompleted]);
const negotiate = async () => {
const pc = pcRef.current;
const ws = wsRef.current;
if (!pc || !ws || ws.readyState !== WebSocket.OPEN) {
logger.error('Cannot negotiate: PC or WebSocket not ready');
return;
}
try {
// Create offer
const offer = await pc.createOffer();
await pc.setLocalDescription(offer);
const localDescription = pc.localDescription;
if (!localDescription) return;
let sdp = localDescription.sdp;
if (audioCodec !== 'default') {
sdp = sdpFilterCodec('audio', audioCodec, sdp);
}
// Send offer immediately via WebSocket (without waiting for ICE gathering)
const message = {
type: 'offer',
payload: {
sdp: sdp,
type: 'offer',
pc_id: pc_id.current,
workflow_id: workflowId,
workflow_run_id: workflowRunId,
call_context_vars: initialContext
}
};
ws.send(JSON.stringify(message));
logger.info('Sent offer via WebSocket (ICE trickling enabled)');
} catch (e) {
logger.error(`Negotiation failed: ${e}`);
setConnectionStatus('failed');
}
};
const start = async () => {
if (isStarting || !accessToken) return;
setIsStarting(true);
setConnectionStatus('connecting');
try {
// Validate API keys
const response = await validateUserConfigurationsApiV1UserConfigurationsUserValidateGet({
headers: {
'Authorization': `Bearer ${accessToken}`,
},
query: {
validity_ttl_seconds: 86400
},
});
if (response.error) {
setApiKeyModalOpen(true);
let msg = 'API Key Error';
const detail = (response.error as unknown as { detail?: { errors: { model: string; message: string }[] } }).detail;
if (Array.isArray(detail)) {
msg = detail
.map((e: { model: string; message: string }) => `${e.model}: ${e.message}`)
.join('\n');
}
setApiKeyError(msg);
setConnectionStatus('failed');
return;
}
// Validate workflow
const workflowResponse = await validateWorkflowApiV1WorkflowWorkflowIdValidatePost({
path: {
workflow_id: workflowId,
},
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
if (workflowResponse.error) {
setWorkflowConfigModalOpen(true);
let msg = 'Workflow validation failed';
const errorDetail = workflowResponse.error as { detail?: { errors: WorkflowValidationError[] } };
if (errorDetail?.detail?.errors) {
msg = errorDetail.detail.errors
.map(err => `${err.kind}: ${err.message}`)
.join('\n');
}
setWorkflowConfigError(msg);
setConnectionStatus('failed');
return;
}
// Connect WebSocket first
await connectWebSocket();
// Create peer connection
timeStartRef.current = null;
const pc = createPeerConnection();
// Set up media constraints
const constraints: MediaStreamConstraints = {
audio: false,
};
if (useAudio) {
const audioConstraints: MediaTrackConstraints = {};
if (selectedAudioInput) {
audioConstraints.deviceId = { exact: selectedAudioInput };
}
constraints.audio = Object.keys(audioConstraints).length ? audioConstraints : true;
}
// Get user media and negotiate
if (constraints.audio) {
try {
const stream = await navigator.mediaDevices.getUserMedia(constraints);
stream.getTracks().forEach((track) => {
pc.addTrack(track, stream);
});
await negotiate();
} catch (err) {
logger.error(`Could not acquire media: ${err}`);
setPermissionError('Could not acquire media');
setConnectionStatus('failed');
}
} else {
await negotiate();
}
} catch (error) {
logger.error('Failed to start connection:', error);
setConnectionStatus('failed');
} finally {
setIsStarting(false);
}
};
const stop = () => {
setConnectionActive(false);
setIsCompleted(true);
setConnectionStatus('idle');
// Close WebSocket
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
}
// Close peer connection
const pc = pcRef.current;
if (!pc) return;
if (pc.getTransceivers) {
pc.getTransceivers().forEach((transceiver) => {
if (transceiver.stop) {
transceiver.stop();
}
});
}
pc.getSenders().forEach((sender) => {
if (sender.track) {
sender.track.stop();
}
});
setTimeout(() => {
if (pcRef.current) {
pcRef.current.close();
pcRef.current = null;
}
}, 500);
};
// Cleanup on unmount
useEffect(() => {
return () => {
if (wsRef.current) {
wsRef.current.close();
}
if (pcRef.current) {
pcRef.current.close();
}
};
}, []);
return {
audioRef,
audioInputs,
selectedAudioInput,
setSelectedAudioInput,
connectionActive,
permissionError,
isCompleted,
apiKeyModalOpen,
setApiKeyModalOpen,
apiKeyError,
workflowConfigError,
workflowConfigModalOpen,
setWorkflowConfigModalOpen,
connectionStatus,
start,
stop,
isStarting,
initialContext
};
};