mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: Trickle ice candidates for faster WebRTC connection
This commit is contained in:
parent
4b4a7ba19a
commit
895af47482
8 changed files with 677 additions and 6 deletions
|
|
@ -1,4 +1,4 @@
|
|||
pipecat-ai[cartesia,deepgram,openai,elevenlabs,groq,google,azure,soundfile,silero,webrtc] @ git+https://github.com/dograh-hq/pipecat.git@c327208
|
||||
pipecat-ai[cartesia,deepgram,openai,elevenlabs,groq,google,azure,soundfile,silero,webrtc] @ git+https://github.com/dograh-hq/pipecat.git@d03d892
|
||||
langfuse==3.4.0
|
||||
fastapi==0.116.2
|
||||
asyncpg==0.30.0
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from api.routes.service_keys import router as service_keys_router
|
|||
from api.routes.superuser import router as superuser_router
|
||||
from api.routes.twilio import router as twilio_router
|
||||
from api.routes.user import router as user_router
|
||||
from api.routes.webrtc_signaling import router as webrtc_signaling_router
|
||||
from api.routes.workflow import router as workflow_router
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -31,6 +32,7 @@ router.include_router(service_keys_router)
|
|||
router.include_router(looptalk_router)
|
||||
router.include_router(organization_usage_router)
|
||||
router.include_router(reports_router)
|
||||
router.include_router(webrtc_signaling_router)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
|
|
|
|||
224
api/routes/webrtc_signaling.py
Normal file
224
api/routes/webrtc_signaling.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""WebSocket-based WebRTC signaling endpoint with ICE trickling support."""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
from loguru import logger
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user_ws
|
||||
from api.services.pipecat.run_pipeline import run_pipeline_smallwebrtc
|
||||
|
||||
router = APIRouter(prefix="/ws")
|
||||
|
||||
# ICE servers configuration
|
||||
ice_servers = ["stun:stun.l.google.com:19302"]
|
||||
|
||||
|
||||
class SignalingManager:
|
||||
"""Manages WebSocket connections and WebRTC peer connections."""
|
||||
|
||||
def __init__(self):
|
||||
self._connections: Dict[str, WebSocket] = {}
|
||||
self._peer_connections: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
async def handle_websocket(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user: UserModel,
|
||||
):
|
||||
"""Handle WebSocket connection for signaling."""
|
||||
await websocket.accept()
|
||||
connection_id = f"{workflow_id}:{workflow_run_id}:{user.id}"
|
||||
self._connections[connection_id] = websocket
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
await self._handle_message(
|
||||
websocket, message, workflow_id, workflow_run_id, user
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected for {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for {connection_id}: {e}")
|
||||
finally:
|
||||
# Cleanup
|
||||
self._connections.pop(connection_id, None)
|
||||
|
||||
# Clean up peer connections for this workflow run
|
||||
# Since we store connections by client pc_id directly,
|
||||
# we need to iterate through all connections and check which ones to clean up
|
||||
# We could track connections per workflow_run, but for now let's clean all
|
||||
# connections that were using this WebSocket
|
||||
for pc_id in list(self._peer_connections.keys()):
|
||||
pc = self._peer_connections.pop(pc_id, None)
|
||||
if pc and pc._signaling_ws == websocket:
|
||||
await pc.disconnect()
|
||||
logger.debug(f"Disconnected peer connection: {pc_id}")
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
ws: WebSocket,
|
||||
message: dict,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user: UserModel,
|
||||
):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
msg_type = message.get("type")
|
||||
payload = message.get("payload", {})
|
||||
|
||||
if msg_type == "offer":
|
||||
await self._handle_offer(ws, payload, workflow_id, workflow_run_id, user)
|
||||
elif msg_type == "ice-candidate":
|
||||
await self._handle_ice_candidate(ws, payload, workflow_run_id)
|
||||
elif msg_type == "renegotiate":
|
||||
await self._handle_renegotiation(ws, payload, workflow_id, workflow_run_id)
|
||||
|
||||
async def _handle_offer(
|
||||
self,
|
||||
ws: WebSocket,
|
||||
payload: dict,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user: UserModel,
|
||||
):
|
||||
"""Handle offer message and create answer with ICE trickling."""
|
||||
pc_id = payload.get("pc_id")
|
||||
sdp = payload.get("sdp")
|
||||
type_ = payload.get("type")
|
||||
call_context_vars = payload.get("call_context_vars", {})
|
||||
|
||||
# Set run context for logging
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
if pc_id and pc_id in self._peer_connections:
|
||||
# Reuse existing connection
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
pc = self._peer_connections[pc_id]
|
||||
await pc.renegotiate(sdp=sdp, type=type_, restart_pc=False)
|
||||
|
||||
# Send updated answer
|
||||
answer = pc.get_answer()
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "answer",
|
||||
"payload": {"sdp": answer["sdp"], "type": "answer", "pc_id": pc_id},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Create new connection with WebSocket support
|
||||
# Pass client's pc_id as the name parameter
|
||||
pc = SmallWebRTCConnection(
|
||||
ice_servers, signaling_ws=ws, enable_trickling=True, name=pc_id
|
||||
)
|
||||
|
||||
# Initialize with trickling support
|
||||
answer = await pc.initialize_with_trickling(sdp=sdp, type=type_)
|
||||
|
||||
# Store peer connection using client's pc_id
|
||||
self._peer_connections[pc_id] = pc
|
||||
|
||||
# Setup closed handler
|
||||
@pc.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"PeerConnection closed: {webrtc_connection.pc_id}")
|
||||
self._peer_connections.pop(webrtc_connection.pc_id, None)
|
||||
|
||||
# Start pipeline in background
|
||||
asyncio.create_task(
|
||||
run_pipeline_smallwebrtc(
|
||||
pc, workflow_id, workflow_run_id, user.id, call_context_vars
|
||||
)
|
||||
)
|
||||
|
||||
# Send answer immediately (without ICE candidates)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "answer",
|
||||
"payload": {
|
||||
"sdp": answer.sdp,
|
||||
"type": answer.type,
|
||||
"pc_id": pc.pc_id, # This will be the same as pc_id we passed
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_ice_candidate(
|
||||
self, ws: WebSocket, payload: dict, workflow_run_id: int
|
||||
):
|
||||
"""Handle incoming ICE candidate from client."""
|
||||
pc_id = payload.get("pc_id")
|
||||
candidate = payload.get("candidate")
|
||||
|
||||
if not pc_id:
|
||||
logger.warning("Received ICE candidate without pc_id")
|
||||
return
|
||||
|
||||
pc = self._peer_connections.get(pc_id)
|
||||
if not pc:
|
||||
logger.warning(f"No peer connection found for pc_id: {pc_id}")
|
||||
return
|
||||
|
||||
if candidate:
|
||||
try:
|
||||
await pc.add_ice_candidate(candidate)
|
||||
logger.debug(f"Added ICE candidate for pc_id: {pc_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add ICE candidate: {e}")
|
||||
else:
|
||||
logger.debug(f"End of ICE candidates for pc_id: {pc_id}")
|
||||
|
||||
async def _handle_renegotiation(
|
||||
self, ws: WebSocket, payload: dict, workflow_id: int, workflow_run_id: int
|
||||
):
|
||||
"""Handle renegotiation request."""
|
||||
pc_id = payload.get("pc_id")
|
||||
sdp = payload.get("sdp")
|
||||
type_ = payload.get("type")
|
||||
restart_pc = payload.get("restart_pc", False)
|
||||
|
||||
if not pc_id or pc_id not in self._peer_connections:
|
||||
await ws.send_json(
|
||||
{"type": "error", "payload": {"message": "Peer connection not found"}}
|
||||
)
|
||||
return
|
||||
|
||||
pc = self._peer_connections[pc_id]
|
||||
await pc.renegotiate(sdp=sdp, type=type_, restart_pc=restart_pc)
|
||||
|
||||
# Send updated answer
|
||||
answer = pc.get_answer()
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "answer",
|
||||
"payload": {
|
||||
"sdp": answer["sdp"],
|
||||
"type": "answer",
|
||||
"pc_id": pc_id, # Use the client's pc_id
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Create singleton instance
|
||||
signaling_manager = SignalingManager()
|
||||
|
||||
|
||||
@router.websocket("/signaling/{workflow_id}/{workflow_run_id}")
|
||||
async def signaling_websocket(
|
||||
websocket: WebSocket,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user: UserModel = Depends(get_user_ws),
|
||||
):
|
||||
"""WebSocket endpoint for WebRTC signaling with ICE trickling."""
|
||||
await signaling_manager.handle_websocket(
|
||||
websocket, workflow_id, workflow_run_id, user
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException
|
||||
from fastapi import Header, HTTPException, Query, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -328,3 +328,26 @@ async def get_superuser(
|
|||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_ws(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(None),
|
||||
) -> UserModel:
|
||||
"""
|
||||
WebSocket authentication dependency.
|
||||
Uses token from query parameters for authentication.
|
||||
"""
|
||||
if not token:
|
||||
await websocket.close(code=1008, reason="Missing authentication token")
|
||||
raise HTTPException(status_code=401, detail="Missing authentication token")
|
||||
|
||||
# Use the same logic as get_user but with token from query
|
||||
authorization = f"Bearer {token}"
|
||||
|
||||
try:
|
||||
user = await get_user(authorization)
|
||||
return user
|
||||
except HTTPException as e:
|
||||
await websocket.close(code=1008, reason=e.detail)
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue