mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
246 lines
8.8 KiB
Python
246 lines
8.8 KiB
Python
"""WebSocket-based WebRTC signaling endpoint with ICE trickling support.
|
|
|
|
This implementation uses WebSocket-based signaling instead of HTTP PATCH for ICE candidates,
which is suitable for multi-worker FastAPI deployments where a worker-local _pcs_map cannot
be shared: the offer and every trickled candidate for a call travel over the same WebSocket.

Uses the SmallWebRTC API contract:
- SmallWebRTCConnection for peer connection management
- candidate_from_sdp() for parsing ICE candidates
- add_ice_candidate() for trickling support
|
|
"""
|
|
|
|
import asyncio
|
|
from typing import Dict
|
|
|
|
from aiortc.sdp import candidate_from_sdp
|
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
from loguru import logger
|
|
|
|
from api.db.models import UserModel
|
|
from api.services.auth.depends import get_user_ws
|
|
from api.services.pipecat.run_pipeline import run_pipeline_smallwebrtc
|
|
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
|
from pipecat.utils.context import set_current_run_id
|
|
|
|
# Router for the endpoints declared in this module; every route gets the "/ws" path prefix.
router = APIRouter(prefix="/ws")
|
|
|
|
# ICE servers configuration
|
|
# Passed as `ice_servers` to every new SmallWebRTCConnection; only a single STUN URL is listed.
ice_servers = ["stun:stun.l.google.com:19302"]
|
|
|
|
|
|
class SignalingManager:
    """Manages WebSocket connections and WebRTC peer connections.

    The manager keeps its state in plain in-memory containers:

    - ``_connections``: ``"workflow_id:workflow_run_id:user_id"`` -> WebSocket
    - ``_peer_connections``: ``pc_id`` -> SmallWebRTCConnection
    - ``_owned_pc_ids``: ``id(websocket)`` -> set of ``pc_id`` values that were
      created through that signaling socket
    - ``_background_tasks``: pipeline tasks that are still running

    A closing signaling socket only tears down the peer connections it
    created itself, so other clients served by the same manager keep
    their calls.
    """

    def __init__(self):
        self._connections: Dict[str, WebSocket] = {}
        self._peer_connections: Dict[str, SmallWebRTCConnection] = {}
        # pc_ids created through each signaling socket, keyed by id(websocket).
        self._owned_pc_ids: Dict[int, set] = {}
        # Strong references to running pipeline tasks. The event loop keeps
        # only weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set = set()

    async def handle_websocket(
        self,
        websocket: WebSocket,
        workflow_id: int,
        workflow_run_id: int,
        user: UserModel,
    ):
        """Accept a signaling WebSocket and process its messages until it closes.

        Args:
            websocket: The not-yet-accepted WebSocket.
            workflow_id: Workflow the call belongs to.
            workflow_run_id: Workflow run the call belongs to.
            user: Authenticated user; ``user.id`` is part of the connection id.

        On exit (disconnect or error) the socket is unregistered and every
        peer connection created through this socket is disconnected.
        """
        await websocket.accept()
        connection_id = f"{workflow_id}:{workflow_run_id}:{user.id}"
        self._connections[connection_id] = websocket
        owner_key = id(websocket)
        self._owned_pc_ids[owner_key] = set()

        try:
            while True:
                message = await websocket.receive_json()
                await self._handle_message(
                    websocket, message, workflow_id, workflow_run_id, user
                )
        except WebSocketDisconnect:
            logger.info(f"WebSocket disconnected for {connection_id}")
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"WebSocket error for {connection_id}: {e}")
        finally:
            # A newer socket may have registered under the same connection id;
            # only remove the entry if it still refers to this socket.
            if self._connections.get(connection_id) is websocket:
                self._connections.pop(connection_id, None)

            # Disconnect only the peer connections created through this
            # socket. Iterate over a copy: the "closed" handler registered in
            # _handle_offer discards ids from the same set.
            owned_pc_ids = self._owned_pc_ids.pop(owner_key, set())
            for pc_id in list(owned_pc_ids):
                pc = self._peer_connections.pop(pc_id, None)
                if pc is None:
                    continue
                try:
                    await pc.disconnect()
                    logger.debug(f"Disconnected peer connection: {pc_id}")
                except Exception as e:
                    # One failing disconnect must not skip the remaining ones.
                    logger.error(
                        f"Failed to disconnect peer connection {pc_id}: {e}"
                    )

    async def _handle_message(
        self,
        ws: WebSocket,
        message: dict,
        workflow_id: int,
        workflow_run_id: int,
        user: UserModel,
    ):
        """Dispatch one decoded signaling message by its ``type`` field.

        Supported types are ``"offer"``, ``"ice-candidate"`` and
        ``"renegotiate"``; anything else is logged and ignored.
        """
        msg_type = message.get("type")
        # `or {}` also covers an explicit `"payload": null`.
        payload = message.get("payload") or {}

        if msg_type == "offer":
            await self._handle_offer(ws, payload, workflow_id, workflow_run_id, user)
        elif msg_type == "ice-candidate":
            await self._handle_ice_candidate(ws, payload, workflow_run_id)
        elif msg_type == "renegotiate":
            await self._handle_renegotiation(ws, payload, workflow_id, workflow_run_id)
        else:
            logger.warning(f"Ignoring unsupported signaling message type: {msg_type!r}")

    async def _send_answer(self, ws: WebSocket, sdp, type_, pc_id):
        """Send an ``answer`` message carrying the given SDP, type and pc_id."""
        await ws.send_json(
            {
                "type": "answer",
                "payload": {"sdp": sdp, "type": type_, "pc_id": pc_id},
            }
        )

    async def _handle_offer(
        self,
        ws: WebSocket,
        payload: dict,
        workflow_id: int,
        workflow_run_id: int,
        user: UserModel,
    ):
        """Handle offer message and create answer with ICE trickling.

        If ``payload["pc_id"]`` names a known peer connection, that connection
        is renegotiated. Otherwise a new SmallWebRTCConnection is created,
        registered, and the pipeline is started as a background task.

        Payload keys read: ``pc_id``, ``sdp``, ``type``, ``call_context_vars``.
        """
        pc_id = payload.get("pc_id")
        sdp = payload.get("sdp")
        type_ = payload.get("type")
        call_context_vars = payload.get("call_context_vars", {})

        # Set run context for logging
        set_current_run_id(workflow_run_id)

        if pc_id and pc_id in self._peer_connections:
            # Reuse existing connection
            logger.info(f"Reusing existing connection for pc_id: {pc_id}")
            pc = self._peer_connections[pc_id]
            await pc.renegotiate(sdp=sdp, type=type_, restart_pc=False)

            # Send updated answer
            answer = pc.get_answer()
            await self._send_answer(ws, answer["sdp"], "answer", pc_id)
            return

        # Create new connection using correct SmallWebRTC API
        pc = SmallWebRTCConnection(ice_servers=ice_servers, connection_timeout_secs=20)
        if pc_id:
            # Set the pc_id before initialization so it's available in get_answer()
            pc._pc_id = pc_id
        else:
            # No client id supplied: keep the connection's own id rather than
            # overwriting it with None and registering under a None key.
            pc_id = pc.pc_id

        # Initialize connection with offer
        await pc.initialize(sdp=sdp, type=type_)

        # Store peer connection and record which signaling socket owns it
        self._peer_connections[pc_id] = pc
        owned_pc_ids = self._owned_pc_ids.setdefault(id(ws), set())
        owned_pc_ids.add(pc_id)

        # Setup closed handler
        @pc.event_handler("closed")
        async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
            logger.info(f"PeerConnection closed: {webrtc_connection.pc_id}")
            self._peer_connections.pop(webrtc_connection.pc_id, None)
            # Forget ownership too, so a later connection that reuses this id
            # is not torn down when this socket closes.
            owned_pc_ids.discard(webrtc_connection.pc_id)

        # Start pipeline in background, holding a reference until it finishes
        pipeline_task = asyncio.create_task(
            run_pipeline_smallwebrtc(
                pc, workflow_id, workflow_run_id, user.id, call_context_vars
            )
        )
        self._background_tasks.add(pipeline_task)
        pipeline_task.add_done_callback(self._background_tasks.discard)

        # Get answer after initialization
        answer = pc.get_answer()

        # Send answer immediately (ICE candidates will be sent separately via trickling)
        await self._send_answer(ws, answer["sdp"], answer["type"], answer["pc_id"])

    async def _handle_ice_candidate(
        self, ws: WebSocket, payload: dict, workflow_run_id: int
    ):
        """Handle incoming ICE candidate from client.

        Uses SmallWebRTC's native ICE trickling support via add_ice_candidate().
        Candidates are parsed using aiortc's candidate_from_sdp() for proper formatting,
        consistent with SmallWebRTCRequestHandler.handle_patch_request().

        Payload keys read: ``pc_id`` and ``candidate`` (a mapping with
        ``candidate``, ``sdpMid`` and ``sdpMLineIndex``). A missing/empty
        ``candidate`` is treated as the end-of-candidates marker.
        """
        pc_id = payload.get("pc_id")
        candidate_data = payload.get("candidate")

        if not pc_id:
            logger.warning("Received ICE candidate without pc_id")
            return

        pc = self._peer_connections.get(pc_id)
        if not pc:
            logger.warning(f"No peer connection found for pc_id: {pc_id}")
            return

        if not candidate_data:
            logger.debug(f"End of ICE candidates for pc_id: {pc_id}")
            return

        try:
            candidate_sdp = candidate_data["candidate"]
            if not candidate_sdp:
                # Browsers signal end-of-candidates with an empty candidate
                # string; there is nothing to parse in that case.
                logger.debug(f"End of ICE candidates for pc_id: {pc_id}")
                return

            # Parse the ICE candidate using aiortc's parser (same as SmallWebRTCRequestHandler)
            candidate = candidate_from_sdp(candidate_sdp)
            candidate.sdpMid = candidate_data.get("sdpMid")
            candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")

            await pc.add_ice_candidate(candidate)
            logger.debug(f"Added ICE candidate for pc_id: {pc_id}")
        except Exception as e:
            # Best effort: a bad candidate must not close the signaling socket.
            logger.error(f"Failed to add ICE candidate: {e}")

    async def _handle_renegotiation(
        self, ws: WebSocket, payload: dict, workflow_id: int, workflow_run_id: int
    ):
        """Handle renegotiation request.

        Payload keys read: ``pc_id``, ``sdp``, ``type``, ``restart_pc``
        (defaults to False). Sends an ``error`` message if ``pc_id`` is
        missing or unknown, otherwise the updated ``answer``.
        """
        pc_id = payload.get("pc_id")
        sdp = payload.get("sdp")
        type_ = payload.get("type")
        restart_pc = payload.get("restart_pc", False)

        if not pc_id or pc_id not in self._peer_connections:
            await ws.send_json(
                {"type": "error", "payload": {"message": "Peer connection not found"}}
            )
            return

        pc = self._peer_connections[pc_id]
        await pc.renegotiate(sdp=sdp, type=type_, restart_pc=restart_pc)

        # Send updated answer, echoing the client's pc_id
        answer = pc.get_answer()
        await self._send_answer(ws, answer["sdp"], "answer", pc_id)
|
|
|
|
|
|
# Create singleton instance
|
|
# Module-level instance used by the endpoint below; its connection maps live in this process's memory only.
signaling_manager = SignalingManager()
|
|
|
|
|
|
@router.websocket("/signaling/{workflow_id}/{workflow_run_id}")
async def signaling_websocket(
    websocket: WebSocket,
    workflow_id: int,
    workflow_run_id: int,
    user: UserModel = Depends(get_user_ws),
):
    """Entry point for WebRTC signaling over WebSocket.

    ``workflow_id`` and ``workflow_run_id`` come from the route path and
    ``user`` is resolved through the ``get_user_ws`` dependency. The whole
    session is delegated to the module-level ``signaling_manager``.
    """
    await signaling_manager.handle_websocket(
        websocket=websocket,
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        user=user,
    )
|