mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
362 lines
14 KiB
Python
362 lines
14 KiB
Python
"""Low-level RTP transport for Asterisk externalMedia sessions.

stasis_rtp_client.py
~~~~~~~~~~~~~~~~~~~~

* Sends and receives **proper RTP/UDP** (PT 0 PCMU/μ-law).
* Uses 20 ms frames (160 bytes payload) by default; automatically
  chunks or concatenates data so timestamps stay correct.
* Verifies the RTP header on the receive path (SSRC and PT).
"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import secrets
|
|||
|
|
import socket
|
|||
|
|
import struct
|
|||
|
|
from typing import TYPE_CHECKING, AsyncIterator, Optional
|
|||
|
|
|
|||
|
|
from loguru import logger
|
|||
|
|
from pipecat.utils.enums import EndTaskReason
|
|||
|
|
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
|
|||
|
|
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
|
|||
|
|
|
|||
|
|
# ─────────────────────────────────────────────────────────────────── helpers
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Fixed 12-byte RTP header, network byte order:
#   B  version / padding / extension / CSRC count
#   B  marker / payload type
#   H  sequence number
#   I  timestamp
#   I  SSRC
_RTP_HDR = struct.Struct("!BBHII")  # v/p/x/cc, m/pt, seq, ts, ssrc
_PT_PCMU = 0  # static payload type for μ-law
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _RTPEncoder:
    """Builds PCMU RTP headers for the packets we SEND to Asterisk.

    Keeps the per-stream state (SSRC, sequence number, timestamp) and
    advances it on every call to :meth:`pack`.
    """

    def __init__(self):
        self.ssrc = secrets.randbits(32)
        self.seq = secrets.randbits(16)
        self.ts = 0  # incremented by #payload bytes

    def pack(self, payload: bytes, mark=False) -> bytes:
        """Return *payload* prefixed with a 12-byte RTP header.

        Args:
            payload: Raw payload bytes to wrap.
            mark: When true, the marker bit is set in the header.

        Returns:
            Header followed by the unmodified payload.
        """
        b0 = 0x80  # V=2
        b1 = (0x80 if mark else 0x00) | _PT_PCMU
        hdr = _RTP_HDR.pack(b0, b1, self.seq, self.ts, self.ssrc)
        self.seq = (self.seq + 1) & 0xFFFF
        # 1 sample/byte @ 8 kHz.  The header field is 32 bits wide, so wrap
        # like the sequence number does; an unbounded counter would make
        # struct.pack raise once it no longer fits in "I".
        self.ts = (self.ts + len(payload)) & 0xFFFFFFFF
        return hdr + payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _RTPDecoder:
    """Very forgiving RTP decoder.

    Latches on the first valid packet and then insists
    that SSRC & PT match afterwards.  Returns *None* if the packet
    should be ignored.
    """

    def __init__(self):
        self.peer_ssrc: int | None = None  # learned from first packet

    def unpack(self, packet: bytes) -> bytes | None:
        """Validate *packet* and return its payload.

        Args:
            packet: One received datagram.

        Returns:
            The payload bytes, or ``None`` when the packet is too short, is
            not RTP version 2, carries a payload type other than PCMU, is
            structurally truncated, or comes from an SSRC other than the one
            latched from the first accepted packet.
        """
        if len(packet) < _RTP_HDR.size:
            return None
        b0, b1, _seq, _ts, ssrc = _RTP_HDR.unpack_from(packet)
        if (b0 & 0xC0) != 0x80:  # RTP v2?
            return None
        if (b1 & 0x7F) != _PT_PCMU:  # payload-type 0?
            return None

        # The payload does not always start right after the fixed header:
        # CC 32-bit CSRC identifiers may follow it ...
        start = _RTP_HDR.size + 4 * (b0 & 0x0F)
        end = len(packet)
        # ... then, if X is set, one extension header whose second 16-bit
        # word gives the extension length in 32-bit words ...
        if b0 & 0x10:
            if end < start + 4:
                return None
            ext_words = int.from_bytes(packet[start + 2 : start + 4], "big")
            start += 4 + 4 * ext_words
        if start > end:
            return None  # truncated CSRC list / extension
        # ... and, if P is set, the last octet counts the trailing padding
        # octets (itself included).
        if b0 & 0x20:
            if end == start:
                return None
            pad = packet[end - 1]
            if pad == 0 or pad > end - start:
                return None
            end -= pad

        # Latch only after the packet proved structurally sound.
        if self.peer_ssrc is None:
            self.peer_ssrc = ssrc  # latch on first good packet
        elif ssrc != self.peer_ssrc:
            return None  # stray stream – drop
        return packet[start:end]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ──────────────────────────────────────────────────────────────── client
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StasisRTPClient:
    """Low-level wrapper around StasisRTPConnection.

    Public API
    ──────────
    • await setup(start_frame)          kept for parity (only counts the caller)
    • await connect()
    • async for payload in receive():   # μ-law bytes (20 ms each)
          …
    • await send(data)                  # any length; will be chunked
    • await disconnect()
    """

    _FRAME_SIZE = 160  # 20 ms @ 8 kHz PCMU

    def __init__(
        self,
        connection: "StasisRTPConnection",
        callbacks: "StasisRTPCallbacks",
    ):
        """Initialize Stasis RTP client.

        Args:
            connection: RTP connection parameters.
            callbacks: Callback handlers for transport events.
        """
        from typing import Any

        self._connection = connection
        self._callbacks = callbacks
        self._encoder = _RTPEncoder()
        self._decoder = _RTPDecoder()

        self._recv_sock: Optional[socket.socket] = None
        self._send_sock: Optional[socket.socket] = None
        self._closing = False
        self._recv_sock_ready = asyncio.Event()  # Signal when recv socket is ready
        self._leave_counter = 0  # Track input/output transport usage
        self._fallback_disconnect_timer: Optional[asyncio.Task] = (
            None  # Safety timer for disconnect
        )

        # ── wire event handlers to the connection ────────────────
        @self._connection.event_handler("connected")
        async def _on_connected(_: Any):
            await self._setup_sockets()
            await self._callbacks.on_client_connected(
                self._connection.caller_channel_id
            )

        @self._connection.event_handler("disconnected")
        async def _on_disconnected(_: Any, reason: str):
            # Cancel the safety timer if it exists. We start the safety timer when
            # sending disconnect or transfer from the engine, i.e when the disconnect()
            # method of the StasisRTPClient is called during wind down of the pipeline.
            # We start the timer so that if we don't get the remote hangup in a given
            # duration, we will call client disconnected handler.
            if (
                self._fallback_disconnect_timer
                and not self._fallback_disconnect_timer.done()
            ):
                self._fallback_disconnect_timer.cancel()
                self._fallback_disconnect_timer = None

            if not self._closing:
                # Mark the client as closing, so that when the pipeline is
                # cancelled or getting closed, we don't try start the fallback
                # disconnect timer and return safely from disconnect
                self._closing = True

            # Always notify: in the engine-initiated path the timer above was
            # just cancelled, so this is the only notification that fires.
            await self._callbacks.on_client_disconnected(
                self._connection.caller_channel_id, reason
            )

    # ─── public helpers ──────────────────────────────────────────

    async def setup(self, _):
        """Setup method for compatibility.

        Registers one more user of this client; ``disconnect()`` only tears
        down once it has been called as many times as ``setup()``.
        """
        self._leave_counter += 1

    async def connect(self):
        """Connect to the RTP socket."""
        if self._connection.is_connected():
            return
        await self._connection.connect()

    async def disconnect(
        self,
        reason: str = EndTaskReason.UNKNOWN.value,
        call_transfer_context: Optional[dict] = None,  # Keep parameter for backward compatibility
    ):
        """Disconnect from the RTP socket.

        Args:
            reason: Why the call is ending; forwarded to the connection's
                ``disconnect`` unless it equals ``USER_QUALIFIED``.
            call_transfer_context: Unused; accepted for backward compatibility.
        """
        # Decrement leave counter when disconnect is called
        logger.debug(f"StasisRTPClient.disconnect leave_counter: {self._leave_counter}")
        self._leave_counter -= 1
        if self._leave_counter > 0:
            # Early return - InputTransport called first, OutputTransport will call later
            return

        # Only proceed when counter reaches 0 (OutputTransport's call)
        # Close sockets
        logger.debug("Going to close sockets")
        await self._close_sockets()

        if self._closing:
            # We might have received the disconnected callback from the
            # StasisRTPConnection due to user hangup. Sockets are closed
            # above, nothing more to do.
            return
        self._closing = True

        # Create a safety timer that will call on_client_disconnected if we don't
        # get StasisEnd from the dialer within 5 seconds. StasisEnd is needed to
        # trigger on_client_disconnected handler in the event_handlers
        async def _fallback_disconnect_timeout():
            await asyncio.sleep(5.0)
            logger.warning(
                "Disconnect event not received within 5 seconds, calling on_client_disconnected as fallback"
            )
            # NOTE(review): unlike the event-driven path, no `reason` is passed
            # here — confirm the callback's second argument is optional.
            await self._callbacks.on_client_disconnected(
                self._connection.caller_channel_id
            )

        self._fallback_disconnect_timer = asyncio.create_task(
            _fallback_disconnect_timeout()
        )

        # Only call disconnect if not a transfer (transfer already handled in PipecatEngine)
        # NOTE: Transfer now happens immediately in PipecatEngine.send_end_task_frame()
        if reason != EndTaskReason.USER_QUALIFIED.value:
            try:
                await self._connection.disconnect(reason)
            except Exception as exc:
                logger.error(f"Failed to disconnect RTP connection: {exc}")
        else:
            logger.debug(
                "Skipping disconnect call for USER_QUALIFIED - transfer already initiated by engine"
            )

    # ─── socket management ──────────────────────────────────────

    async def _setup_sockets(self):
        """Create the non-blocking UDP receive and send sockets if missing."""
        if self._recv_sock and self._send_sock:
            return

        logger.debug(
            f"Setting up Sockets - local {self._connection.local_addr}, remote: {self._connection.remote_addr}"
        )

        # receive socket – bind to local address provided by connection
        if not self._recv_sock:
            rs = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                rs.setblocking(False)
                rs.bind(self._connection.local_addr)
            except Exception:
                rs.close()  # don't leak the descriptor when bind fails
                raise
            self._recv_sock = rs
            self._recv_sock_ready.set()  # Signal that recv socket is ready

        # send socket – connect to remote (Asterisk) address
        if not self._send_sock:
            ss = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                ss.setblocking(False)
                ss.connect(self._connection.remote_addr)
            except Exception:
                ss.close()  # don't leak the descriptor when connect fails
                raise
            self._send_sock = ss

        logger.debug(
            f"Socket setup complete - recv_fd: {self._recv_sock.fileno()}, send_fd: {self._send_sock.fileno()}"
        )

    async def _close_sockets(self):
        """Safely close sockets with proper error handling."""
        for sock_name, sock in [("recv", self._recv_sock), ("send", self._send_sock)]:
            if sock:
                try:
                    # Shutdown the socket first to break any pending operations
                    sock.shutdown(socket.SHUT_RDWR)
                except OSError:
                    # Socket might already be closed or in a bad state
                    pass
                try:
                    sock.close()
                except Exception as exc:
                    logger.debug(f"Error closing {sock_name} socket: {exc}")

        self._recv_sock = None
        self._send_sock = None
        self._recv_sock_ready.clear()  # Reset the event for potential reconnection

        # Notify the connection that sockets are closed so ARI Manager can clean up ports
        await self._connection.notify_sockets_closed()

        logger.debug("Closed sockets in StasisRTPClient")

    # ─── receive path ────────────────────────────────────────────

    async def receive(self) -> AsyncIterator[bytes]:
        """Async generator yielding μ-law frames (exactly 160 bytes each).

        Silently drops any packet the decoder rejects; packets whose payload
        is not exactly one frame long are dropped with a warning.  The
        generator ends when the client is closing, the receive socket is
        gone, the task is cancelled or the socket raises.
        """
        loop = asyncio.get_running_loop()

        # Wait for recv socket to be created
        try:
            await self._recv_sock_ready.wait()
        except asyncio.CancelledError:
            return

        logger.debug("Going to receive from the socket now")

        while not self._closing:
            # _close_sockets() resets the attribute to None; stop cleanly
            # instead of handing None to the event loop.
            sock = self._recv_sock
            if sock is None:
                logger.debug("RTP receive socket is gone, stopping receive loop")
                break
            try:
                # each loop gets 172 bytes UDP packet, which is 160 bytes of
                # audio data (Asterisk sends 20ms audio chunks with 8k sample rate)
                # and 12 bytes of RTP header
                data = await loop.sock_recv(sock, 2048)
            except asyncio.CancelledError:
                logger.debug("RTP receive task cancelled")
                break
            except (OSError, socket.error) as exc:
                logger.warning(f"RTP receive failed (socket closed): {exc}")
                break
            except Exception as exc:
                logger.debug(f"Unexpected error in receive: {exc}")
                break

            payload = self._decoder.unpack(data)
            if payload is None:
                continue  # header failed validation

            # In practice Asterisk sends 20 ms frames – check just in case.
            if len(payload) != self._FRAME_SIZE:
                logger.warning(f"Dropping non-20 ms packet len={len(payload)}")
                continue
            yield payload

    # ─── send path ───────────────────────────────────────────────

    async def send(self, data: bytes):
        """Send μ-law data of arbitrary length.

        Splits the data into 160-byte chunks (padding the last one) before
        RTP-wrapping.  Does nothing when the client is closing or has no send
        socket; stops at the first chunk that fails to send.
        """
        if self._closing or not self._send_sock:
            return
        loop = asyncio.get_running_loop()

        # chunk/pad to 160-byte frames
        chunks = self._chunk_ulaw(data, self._FRAME_SIZE)
        for i, chunk in enumerate(chunks):
            # The socket can be closed while we were awaiting the previous
            # chunk; stop instead of passing None to the event loop.
            sock = self._send_sock
            if sock is None:
                break
            mark = i == 0  # marker goes on the first packet of each send() call
            packet = self._encoder.pack(chunk, mark=mark)
            try:
                await loop.sock_sendall(sock, packet)
            except (OSError, socket.error) as exc:
                logger.warning(f"RTP send failed (socket closed): {exc}")
                break
            except Exception as exc:
                logger.error(f"RTP send failed: {exc}")
                break

    def _chunk_ulaw(self, buf: bytes, size: int) -> list[bytes]:
        """Split μ-law bytes into chunks of exactly *size* bytes.

        • If buf length is not a multiple of *size*, pad the last chunk with 0xFF
          (silence).  That keeps timestamps monotonic.
        • An empty buffer yields an empty list.
        """
        if not buf:
            return []
        if len(buf) % size:
            pad = size - (len(buf) % size)
            buf += b"\xff" * pad
        return [buf[i : i + size] for i in range(0, len(buf), size)]

    # ─── properties ──────────────────────────────────────────────

    @property
    def is_connected(self) -> bool:
        """Check if client is connected."""
        return self._connection.is_connected() and not self._closing

    @property
    def is_closing(self) -> bool:
        """Check if client is closing."""
        return self._closing