mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
chore: upgrade pipecat (#36)
This commit is contained in:
parent
e690753275
commit
491e6edd36
3 changed files with 46 additions and 25 deletions
|
|
@ -1,8 +1,18 @@
|
|||
"""WebSocket-based WebRTC signaling endpoint with ICE trickling support."""
|
||||
"""WebSocket-based WebRTC signaling endpoint with ICE trickling support.
|
||||
|
||||
This implementation uses WebSocket-based signaling instead of HTTP PATCH for ICE candidates,
|
||||
which is suitable for multi-worker FastAPI deployments where local _pcs_map cannot be shared.
|
||||
|
||||
Uses the SmallWebRTC API contract:
|
||||
- SmallWebRTCConnection for peer connection management
|
||||
- candidate_from_sdp() for parsing ICE candidates
|
||||
- add_ice_candidate() for trickling support
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
from aiortc.sdp import candidate_from_sdp
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -51,14 +61,14 @@ class SignalingManager:
|
|||
# Cleanup
|
||||
self._connections.pop(connection_id, None)
|
||||
|
||||
# Clean up peer connections for this workflow run
|
||||
# Since we store connections by client pc_id directly,
|
||||
# we need to iterate through all connections and check which ones to clean up
|
||||
# We could track connections per workflow_run, but for now let's clean all
|
||||
# connections that were using this WebSocket
|
||||
# Clean up all peer connections for this workflow run
|
||||
# Note: In a WebSocket-based signaling approach (vs HTTP PATCH),
|
||||
# we maintain our own connection map instead of relying on
|
||||
# SmallWebRTCRequestHandler's _pcs_map. This is suitable for
|
||||
# multi-worker FastAPI deployments where state cannot be shared.
|
||||
for pc_id in list(self._peer_connections.keys()):
|
||||
pc = self._peer_connections.pop(pc_id, None)
|
||||
if pc and pc._signaling_ws == websocket:
|
||||
if pc:
|
||||
await pc.disconnect()
|
||||
logger.debug(f"Disconnected peer connection: {pc_id}")
|
||||
|
||||
|
|
@ -113,14 +123,13 @@ class SignalingManager:
|
|||
}
|
||||
)
|
||||
else:
|
||||
# Create new connection with WebSocket support
|
||||
# Pass client's pc_id as the name parameter
|
||||
pc = SmallWebRTCConnection(
|
||||
ice_servers, signaling_ws=ws, enable_trickling=True, name=pc_id
|
||||
)
|
||||
# Create new connection using correct SmallWebRTC API
|
||||
pc = SmallWebRTCConnection(ice_servers=ice_servers, connection_timeout_secs=20)
|
||||
# Set the pc_id before initialization so it's available in get_answer()
|
||||
pc._pc_id = pc_id
|
||||
|
||||
# Initialize with trickling support
|
||||
answer = await pc.initialize_with_trickling(sdp=sdp, type=type_)
|
||||
# Initialize connection with offer
|
||||
await pc.initialize(sdp=sdp, type=type_)
|
||||
|
||||
# Store peer connection using client's pc_id
|
||||
self._peer_connections[pc_id] = pc
|
||||
|
|
@ -138,14 +147,17 @@ class SignalingManager:
|
|||
)
|
||||
)
|
||||
|
||||
# Send answer immediately (without ICE candidates)
|
||||
# Get answer after initialization
|
||||
answer = pc.get_answer()
|
||||
|
||||
# Send answer immediately (ICE candidates will be sent separately via trickling)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "answer",
|
||||
"payload": {
|
||||
"sdp": answer.sdp,
|
||||
"type": answer.type,
|
||||
"pc_id": pc.pc_id, # This will be the same as pc_id we passed
|
||||
"sdp": answer["sdp"],
|
||||
"type": answer["type"],
|
||||
"pc_id": answer["pc_id"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -153,9 +165,14 @@ class SignalingManager:
|
|||
async def _handle_ice_candidate(
|
||||
self, ws: WebSocket, payload: dict, workflow_run_id: int
|
||||
):
|
||||
"""Handle incoming ICE candidate from client."""
|
||||
"""Handle incoming ICE candidate from client.
|
||||
|
||||
Uses SmallWebRTC's native ICE trickling support via add_ice_candidate().
|
||||
Candidates are parsed using aiortc's candidate_from_sdp() for proper formatting,
|
||||
consistent with SmallWebRTCRequestHandler.handle_patch_request().
|
||||
"""
|
||||
pc_id = payload.get("pc_id")
|
||||
candidate = payload.get("candidate")
|
||||
candidate_data = payload.get("candidate")
|
||||
|
||||
if not pc_id:
|
||||
logger.warning("Received ICE candidate without pc_id")
|
||||
|
|
@ -166,8 +183,13 @@ class SignalingManager:
|
|||
logger.warning(f"No peer connection found for pc_id: {pc_id}")
|
||||
return
|
||||
|
||||
if candidate:
|
||||
if candidate_data:
|
||||
try:
|
||||
# Parse the ICE candidate using aiortc's parser (same as SmallWebRTCRequestHandler)
|
||||
candidate = candidate_from_sdp(candidate_data["candidate"])
|
||||
candidate.sdpMid = candidate_data.get("sdpMid")
|
||||
candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
|
||||
|
||||
await pc.add_ice_candidate(candidate)
|
||||
logger.debug(f"Added ICE candidate for pc_id: {pc_id}")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ from pipecat.frames.frames import (
|
|||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import (
|
||||
BaseOutputTransport,
|
||||
TransportClientNotConnectedException,
|
||||
BaseOutputTransport
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
|
|
@ -204,7 +203,7 @@ class StasisRTPOutputTransport(BaseOutputTransport):
|
|||
async def write_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
"""Write audio frame to RTP stream."""
|
||||
if self._client.is_closing:
|
||||
raise TransportClientNotConnectedException()
|
||||
return False
|
||||
|
||||
if not self._client.is_connected:
|
||||
# If not connected yet, just simulate playback delay.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue