dograh/api/services/telephony/stasis_rtp_client.py
Abhishek Kumar 4f2a629340 Initial Commit 🚀 🚀
2025-09-09 14:37:32 +05:30

361 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Low-level RTP transport for Asterisk externalMedia sessions.
stasis_rtp_client.py
~~~~~~~~~~~~~~~~~~~~
* Sends and receives **proper RTP/UDP** (PT 0 PCMU/μ-law).
* Uses 20 ms frames (160 bytes payload) by default; outgoing data is
split into 160-byte chunks and the last chunk is padded with μ-law
silence (0xFF) so timestamps stay correct.
* Verifies the RTP header on the receive path (SSRC and PT).
"""
import asyncio
import secrets
import socket
import struct
from typing import TYPE_CHECKING, Any, AsyncIterator, Optional

from loguru import logger
from pipecat.utils.enums import EndTaskReason

if TYPE_CHECKING:
    from api.services.telephony.stasis_rtp_connection import StasisRTPConnection
    from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
# ─────────────────────────────────────────────────────────────────── helpers
# Fixed 12-byte RTP header, network byte order ("!"):
#   B  version / padding / extension / CSRC-count bits
#   B  marker bit + 7-bit payload type
#   H  sequence number
#   I  timestamp
#   I  SSRC
_RTP_HDR = struct.Struct("!BBHII") # v/p/x/cc, m/pt, seq, ts, ssrc
# Payload type value written into outgoing headers and required on incoming ones.
_PT_PCMU = 0 # static payload type for μ-law
class _RTPEncoder:
    """Builds PCMU RTP headers for the packets we SEND to Asterisk.

    Keeps the per-stream state (SSRC, sequence number, timestamp) and
    advances it on every :meth:`pack` call.
    """

    _SEQ_MASK = 0xFFFF  # sequence number is a 16-bit header field
    _TS_MASK = 0xFFFFFFFF  # timestamp is a 32-bit header field

    def __init__(self):
        self.ssrc = secrets.randbits(32)
        self.seq = secrets.randbits(16)
        self.ts = 0  # incremented by #payload bytes

    def pack(self, payload: bytes, mark=False) -> bytes:
        """Return *payload* prefixed with a 12-byte RTP header.

        Args:
            payload: μ-law audio bytes for this packet.
            mark: When true, the marker bit is set in the header.

        Returns:
            The header followed by *payload*.
        """
        b0 = 0x80  # V=2, no padding, no extension, CC=0
        b1 = (0x80 if mark else 0x00) | _PT_PCMU
        hdr = _RTP_HDR.pack(b0, b1, self.seq, self.ts, self.ssrc)
        self.seq = (self.seq + 1) & self._SEQ_MASK
        # 1 sample/byte @ 8 kHz.  Wrap modulo 2**32 exactly like the sequence
        # number: an unbounded counter would make _RTP_HDR.pack raise
        # struct.error once it no longer fits the 32-bit field.
        self.ts = (self.ts + len(payload)) & self._TS_MASK
        return hdr + payload
class _RTPDecoder:
"""Very forgiving RTP decoder.
Latches on the first valid packet and then insists
that SSRC & PT match afterwards. Returns *None* if the packet
should be ignored.
"""
def __init__(self):
self.peer_ssrc: int | None = None # learned from first packet
def unpack(self, packet: bytes) -> bytes | None:
if len(packet) < _RTP_HDR.size:
return None
b0, b1, seq, ts, ssrc = _RTP_HDR.unpack_from(packet)
if (b0 & 0xC0) != 0x80: # RTP v2?
return None
if (b1 & 0x7F) != _PT_PCMU: # payload-type 0?
return None
if self.peer_ssrc is None:
self.peer_ssrc = ssrc # latch on first good packet
elif ssrc != self.peer_ssrc:
return None # stray stream drop
return packet[_RTP_HDR.size :]
# ──────────────────────────────────────────────────────────────── client
class StasisRTPClient:
    """Low-level wrapper around StasisRTPConnection.

    Owns two UDP sockets (one bound to the connection's local address for
    receiving, one connected to its remote address for sending) and wraps /
    unwraps the audio in PCMU RTP packets.

    Public API
    ──────────
    • await setup(start_frame)          registers one user of this client
    • await connect()
    • async for payload in receive():   # μ-law bytes (20 ms each)
    • await send(data)                  # any length; will be chunked
    • await disconnect()                # last registered user tears down
    """

    _FRAME_SIZE = 160  # 20 ms @ 8 kHz PCMU

    def __init__(
        self,
        connection: "StasisRTPConnection",
        callbacks: "StasisRTPCallbacks",
    ):
        """Initialize Stasis RTP client.

        Args:
            connection: RTP connection parameters.
            callbacks: Callback handlers for transport events.
        """
        self._connection = connection
        self._callbacks = callbacks
        self._encoder = _RTPEncoder()
        self._decoder = _RTPDecoder()
        self._recv_sock: Optional[socket.socket] = None
        self._send_sock: Optional[socket.socket] = None
        self._closing = False
        self._recv_sock_ready = asyncio.Event()  # Signal when recv socket is ready
        self._leave_counter = 0  # Track input/output transport usage
        # Safety timer for disconnect
        self._fallback_disconnect_timer: Optional[asyncio.Task] = None

        # ── wire event handlers to the connection ────────────────
        @self._connection.event_handler("connected")
        async def _on_connected(_: Any):
            await self._setup_sockets()
            await self._callbacks.on_client_connected(
                self._connection.caller_channel_id
            )

        @self._connection.event_handler("disconnected")
        async def _on_disconnected(_: Any, reason: str):
            # Cancel the safety timer if it exists. We start the safety timer when
            # sending disconnect or transfer from the engine, i.e when the disconnect()
            # method of the StasisRTPClient is called during wind down of the pipeline.
            # We start the timer so that if we don't get the remote hangup in a given
            # duration, we will call client disconnected handler.
            timer = self._fallback_disconnect_timer
            if timer and not timer.done():
                timer.cancel()
                self._fallback_disconnect_timer = None

            if not self._closing:
                # Mark the client as closing, so that when the pipeline is
                # cancelled or getting closed, we don't try start the fallback
                # disconnect timer and return safely from disconnect
                self._closing = True

            # NOTE(review): the callback is invoked whether or not we were
            # already closing -- otherwise cancelling the fallback timer above
            # would leave an engine-initiated disconnect without any
            # on_client_disconnected call.  Confirm against the original layout.
            await self._callbacks.on_client_disconnected(
                self._connection.caller_channel_id, reason
            )

    # ─── public helpers ──────────────────────────────────────────
    async def setup(self, _):
        """Register one user of this client.

        Each call increments the leave counter; ``disconnect()`` only tears
        the client down once the counter has been brought back to zero.
        The argument is accepted for interface compatibility and ignored.
        """
        self._leave_counter += 1

    async def connect(self):
        """Connect the underlying connection unless it is already connected."""
        if self._connection.is_connected():
            return
        await self._connection.connect()

    async def disconnect(
        self,
        reason: str = EndTaskReason.UNKNOWN.value,
        call_transfer_context: Optional[dict] = None,  # Keep parameter for backward compatibility
    ):
        """Disconnect from the RTP socket.

        Args:
            reason: Why the call is ending; forwarded to the connection's
                ``disconnect`` unless it is ``USER_QUALIFIED``.
            call_transfer_context: Unused; kept so existing callers keep
                working.
        """
        # Decrement leave counter when disconnect is called
        logger.debug(f"StasisRTPClient.disconnect leave_counter: {self._leave_counter}")
        self._leave_counter -= 1

        if self._leave_counter > 0:
            # Early return - InputTransport called first, OutputTransport will call later
            return

        # Only proceed when counter reaches 0 (OutputTransport's call)
        # Close sockets
        logger.debug("Going to close sockets")
        await self._close_sockets()

        if self._closing:
            # We might have received the disconnected callback from the
            # StasisRTPConnection due to user hangup. The sockets were closed
            # just above, so there is nothing left to do.
            return

        self._closing = True

        # Create a safety timer that will call on_client_disconnected if we don't
        # get StasisEnd from the dialer within 5 seconds. StasisEnd is needed to
        # trigger on_client_disconnected handler in the event_handlers
        async def _fallback_disconnect_timeout():
            await asyncio.sleep(5.0)
            logger.warning(
                "Disconnect event not received within 5 seconds, calling on_client_disconnected as fallback"
            )
            # NOTE(review): unlike the "disconnected" event handler, no reason
            # is passed here -- verify the callback has a default for it.
            await self._callbacks.on_client_disconnected(
                self._connection.caller_channel_id
            )

        self._fallback_disconnect_timer = asyncio.create_task(
            _fallback_disconnect_timeout()
        )

        # Only call disconnect if not a transfer (transfer already handled in PipecatEngine)
        # NOTE: Transfer now happens immediately in PipecatEngine.send_end_task_frame()
        if reason != EndTaskReason.USER_QUALIFIED.value:
            try:
                await self._connection.disconnect(reason)
            except Exception as exc:
                logger.error(f"Failed to disconnect RTP connection: {exc}")
        else:
            logger.debug(
                "Skipping disconnect call for USER_QUALIFIED - transfer already initiated by engine"
            )

    # ─── socket management ──────────────────────────────────────
    async def _setup_sockets(self):
        """Create the receive and send UDP sockets if they do not exist yet.

        Raises:
            Exception: Whatever ``bind()`` / ``connect()`` raised; the
                half-initialised socket is closed before re-raising.
        """
        if self._recv_sock and self._send_sock:
            return

        logger.debug(
            f"Setting up Sockets - local {self._connection.local_addr}, remote: {self._connection.remote_addr}"
        )

        # receive socket bind to local address provided by connection
        if not self._recv_sock:
            rs = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                rs.setblocking(False)
                rs.bind(self._connection.local_addr)
            except Exception:
                # Don't leak the descriptor when the address cannot be bound.
                rs.close()
                raise
            self._recv_sock = rs
            self._recv_sock_ready.set()  # Signal that recv socket is ready

        # send socket connect to remote (Asterisk) address
        if not self._send_sock:
            ss = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                ss.setblocking(False)
                ss.connect(self._connection.remote_addr)
            except Exception:
                # Same as above: close before propagating the failure.
                ss.close()
                raise
            self._send_sock = ss

        logger.debug(
            f"Socket setup complete - recv_fd: {self._recv_sock.fileno()}, send_fd: {self._send_sock.fileno()}"
        )

    async def _close_sockets(self):
        """Safely close sockets with proper error handling."""
        for sock_name, sock in [("recv", self._recv_sock), ("send", self._send_sock)]:
            if sock:
                try:
                    # Shutdown the socket first to break any pending operations
                    sock.shutdown(socket.SHUT_RDWR)
                except OSError:
                    # Socket might already be closed or in a bad state
                    pass

                try:
                    sock.close()
                except Exception as exc:
                    logger.debug(f"Error closing {sock_name} socket: {exc}")

        self._recv_sock = None
        self._send_sock = None
        self._recv_sock_ready.clear()  # Reset the event for potential reconnection

        # Notify the connection that sockets are closed so ARI Manager can clean up ports
        await self._connection.notify_sockets_closed()
        logger.debug("Closed sockets in StasisRTPClient")

    # ─── receive path ────────────────────────────────────────────
    async def receive(self) -> AsyncIterator[bytes]:
        """Async generator yielding μ-law frames (exactly 160 bytes each).

        Silently drops any packet whose RTP header fails validation and
        logs + drops any packet whose payload is not one full frame.
        The generator ends when the client is closing, the socket is gone,
        the task is cancelled, or the receive call fails.
        """
        loop = asyncio.get_running_loop()

        # Wait for recv socket to be created
        try:
            await self._recv_sock_ready.wait()
        except asyncio.CancelledError:
            return

        logger.debug("Going to receive from the socket now")

        while not self._closing:
            # disconnect() closes the sockets *before* it sets _closing, so the
            # socket can be gone while the loop condition is still true.
            sock = self._recv_sock
            if sock is None:
                logger.debug("RTP receive socket is closed, stopping receive loop")
                break

            try:
                # each loop gets 172 bytes UDP packet, which is 160 bytes of
                # audio data (Asterisk sends 20ms audio chunks with 8k sample rate)
                # and 12 bytes of RTP header
                data = await loop.sock_recv(sock, 2048)
            except asyncio.CancelledError:
                logger.debug("RTP receive task cancelled")
                break
            except OSError as exc:
                logger.warning(f"RTP receive failed (socket closed): {exc}")
                break
            except Exception as exc:
                logger.debug(f"Unexpected error in receive: {exc}")
                break

            payload = self._decoder.unpack(data)
            if payload is None:
                continue  # header failed validation

            # In practice Asterisk sends 20 ms frames; check just in case.
            if len(payload) != self._FRAME_SIZE:
                logger.warning(f"Dropping non-20 ms packet len={len(payload)}")
                continue

            yield payload

    # ─── send path ───────────────────────────────────────────────
    async def send(self, data: bytes):
        """Send μ-law data of arbitrary length.

        The data is split into 160-byte chunks (the last one padded with
        silence) and each chunk is sent as one RTP packet.  Nothing is sent
        while the client is closing or before the send socket exists.
        """
        # Snapshot the socket: it may be reset to None by _close_sockets()
        # while we are suspended inside the loop below.
        sock = self._send_sock
        if self._closing or sock is None:
            return

        loop = asyncio.get_running_loop()

        # chunk/pad to 160-byte frames
        chunks = self._chunk_ulaw(data, self._FRAME_SIZE)
        for i, chunk in enumerate(chunks):
            # set marker on the first packet of talk-spurt
            # NOTE(review): this marks the first packet of *every* send() call.
            mark = i == 0
            packet = self._encoder.pack(chunk, mark=mark)
            try:
                await loop.sock_sendall(sock, packet)
            except OSError as exc:
                logger.warning(f"RTP send failed (socket closed): {exc}")
                break
            except Exception as exc:
                logger.error(f"RTP send failed: {exc}")
                break

    def _chunk_ulaw(self, buf: bytes, size: int) -> list[bytes]:
        """Split μ-law bytes into chunks of exactly *size* bytes.

        • If buf length is not a multiple of *size*, pad the last chunk with 0xFF
          (silence).  That keeps timestamps monotonic.
        • Empty input yields an empty list.
        """
        if not buf:
            return []

        remainder = len(buf) % size
        if remainder:
            buf += b"\xff" * (size - remainder)

        return [buf[i : i + size] for i in range(0, len(buf), size)]

    # ─── properties ──────────────────────────────────────────────
    @property
    def is_connected(self) -> bool:
        """Check if client is connected (connection up and not closing)."""
        return self._connection.is_connected() and not self._closing

    @property
    def is_closing(self) -> bool:
        """Check if client is closing."""
        return self._closing