dograh/api/services/telephony/ari_manager.py

748 lines
29 KiB
Python

"""Standalone ARI Manager Service for distributed architecture.
This service maintains the single WebSocket connection to Asterisk ARI
and distributes events to multiple FastAPI workers via Redis pub/sub.
ARIManager creates an instance of ARIClientSupervisor and registers the callbacks
on_channel_start and on_channel_end. It is responsible for taking in the
caller_channel and setting up an ARIManagerConnection, i.e. creating the bridge
for externalMedia.
"""
import asyncio
import json
import os
import signal
import time
from typing import Dict, Optional
from api.constants import REDIS_URL
# --- Add logging setup before importing loguru ---
from api.logging_config import setup_logging
from api.services.telephony.stasis_event_protocol import (
BaseWorkerToARIManagerCommand,
DisconnectCommand,
RedisChannels,
RedisKeys,
SocketClosedCommand,
StasisEndEvent,
StasisStartEvent,
TransferCommand,
parse_command,
)
logging_queue_listener = setup_logging()
import redis.asyncio as aioredis
import redis.exceptions
from loguru import logger
from pipecat.utils.enums import EndTaskReason
from api.services.telephony.ari_client import Channel
from api.services.telephony.ari_client_manager import (
ARIClientManager,
setup_ari_client_supervisor,
)
from api.services.telephony.ari_manager_connection import ARIManagerConnection
class ARIManager:
"""Manages ARI connection and distributes events to workers via Redis."""
def __init__(self, redis_client: aioredis.Redis):
    """Initialise empty in-memory bookkeeping around the given Redis client.

    Only attributes are assigned here; the ARI supervisor and the background
    tasks are created later by ``run()``.

    Args:
        redis_client: Async Redis client used by the other methods for
            pub/sub, port bookkeeping and worker discovery.
    """
    # --- External handles -------------------------------------------------
    self.redis = redis_client
    self.stasis_manager: Optional[ARIClientManager] = None
    self._ari_client_supervisor = None

    # --- Lifecycle --------------------------------------------------------
    self._running = False
    self._worker_discovery_task: Optional[asyncio.Task] = None

    # --- Port pool: even ports only (4000, 4002, ..., 4998) ---------------
    self._port_range = range(4000, 5000, 2)

    # --- Per-channel state, keyed by caller channel id --------------------
    # Channels managed by this instance
    self._active_channels: set[str] = set()
    # Connection object for each channel
    self._channel_connections: Dict[str, ARIManagerConnection] = {}
    # Worker each channel was handed to
    self._channel_to_worker: Dict[str, str] = {}
    # Two flags that must both be True before a channel is fully cleaned up
    self._channel_disposed: Dict[str, bool] = {}
    self._socket_closed: Dict[str, bool] = {}
    # Command-monitor task and its pubsub connection for each channel
    self._tasks: Dict[str, asyncio.Task] = {}
    self._pubsubs: Dict[str, aioredis.client.PubSub] = {}

    # --- Worker discovery -------------------------------------------------
    # Cached list of workers currently reporting as ready
    self._active_workers: list[str] = []
async def on_channel_start(self, caller_channel: Channel, call_context_vars: dict):
    """Callback for a new caller channel: reserve a port and start the call.

    A port is allocated for the channel, an ``ARIManagerConnection`` is built
    on it, the channel's tracking entries are initialised, and the rest of
    the setup is delegated to ``_on_stasis_call``.

    Args:
        caller_channel: Channel object of the incoming call.
        call_context_vars: Context variables forwarded unchanged to the worker.

    Any exception is logged and the channel's port (if one was allocated) is
    released; nothing is re-raised from the ``try`` body.
    """
    try:
        cid = caller_channel.id

        # Port allocation is atomic on the Redis side, so concurrent calls
        # cannot end up with the same port.
        media_port = await self._get_and_allocate_port_atomic(cid)
        media_host = os.getenv("ARI_STASIS_APP_ENDPOINT")
        connection = ARIManagerConnection(
            caller_channel=caller_channel,
            host=media_host,
            port=media_port,
        )

        # Register the connection and start both cleanup flags at False.
        self._channel_connections[cid] = connection
        self._channel_disposed[cid] = False
        self._socket_closed[cid] = False

        await self._on_stasis_call(connection, call_context_vars)
    except Exception as exc:
        logger.exception(f"Error handling new channel {caller_channel.id}: {exc}")
        # Give the port back so a failed setup does not keep it reserved.
        await self._release_port_for_channel(caller_channel.id)
async def on_channel_end(self, channel_id: str):
    """Callback for a channel-end notification.

    For a tracked caller channel: publish a ``StasisEndEvent`` to the worker
    assigned to it, notify the connection, mark the channel as disposed and
    attempt the final cleanup. Channel ends for an ExternalMedia channel are
    only logged; unknown channel ids are ignored.

    Args:
        channel_id: Id of the channel that ended.
    """
    logger.info(f"channelID: {channel_id} Received channel end notification")

    if channel_id not in self._channel_connections:
        # TODO: We are currently not handling StasisEnd on ExternalMedia
        is_external_media = any(
            conn.em_channel_id and conn.em_channel_id == channel_id
            for conn in self._channel_connections.values()
        )
        if is_external_media:
            logger.debug(
                f"channelID: {channel_id} ExternalMedia StasisEnd - Ignoring"
            )
        return

    connection = self._channel_connections[channel_id]
    if not connection or not channel_id:
        return

    # Tell the owning worker right away that the caller hung up.
    end_event = StasisEndEvent(
        channel_id=channel_id,
        reason=EndTaskReason.USER_HANGUP.value,
    )
    target_worker = self._get_worker_for_channel(channel_id)
    await self.redis.publish(
        RedisChannels.worker_events(target_worker), end_event.to_json()
    )
    logger.info(f"channelID: {channel_id} Published StasisEnd event")

    await connection.notify_channel_end()

    # First half of the two-flag handshake; the other half is the worker's
    # socket-closed command.
    if channel_id in self._channel_disposed:
        self._channel_disposed[channel_id] = True

    await self._check_and_cleanup_channel(channel_id)
async def _on_stasis_call(
    self, connection: ARIManagerConnection, call_context_vars: dict
):
    """Set up the call, pick a worker and hand the call over through Redis.

    Steps:
      1. ``connection.setup_call()`` (creates the bridge and external media).
      2. Mark the caller channel as active.
      3. Publish a ``StasisStartEvent`` on the selected worker's event channel.
      4. Start the per-channel command monitor task.

    If no worker is available the call is disconnected and everything that
    was reserved for it is handed back: the channel is removed from the
    active set, its port is released and its socket-closed flag is set, so
    that the channel-end notification can complete the cleanup. Without this
    the port would stay allocated (active channels are skipped by the
    orphan cleanup) and the tracking entries would never be removed, because
    no worker exists to send the socket-closed command.

    Args:
        connection: Connection created for the caller channel.
        call_context_vars: Context variables forwarded to the worker.

    Exceptions are logged here and not propagated to the caller.
    """
    try:
        # Setup the connection (create bridge and external media)
        await connection.setup_call()

        if not connection.is_connected():
            # NOTE(review): the port reserved by on_channel_start is not
            # released on this path; it is only reclaimed later by the
            # orphaned-port cleanup. Confirm whether that is intended.
            logger.warning("Connection is not connected, skipping")
            return

        # Extract all necessary information after bridge is created
        channel_id = connection.caller_channel_id
        em_channel_id = connection.em_channel_id
        bridge_id = connection.bridge_id

        # Track this channel as active
        self._active_channels.add(channel_id)

        # Create event with all connection details
        event = StasisStartEvent(
            channel_id=channel_id,
            caller_channel_id=channel_id,
            em_channel_id=em_channel_id,
            bridge_id=bridge_id,
            local_addr=list(connection.local_addr),
            remote_addr=list(connection.remote_addr)
            if connection.remote_addr
            else None,
            call_context_vars=call_context_vars,
        )

        # Select worker using round-robin
        worker_id = await self._select_worker()
        if worker_id is None:
            logger.error(f"channelID: {channel_id} No active workers available")
            try:
                await connection.disconnect()
            finally:
                # No worker will ever open (or close) the media socket for
                # this call, so treat the socket as already closed and give
                # the reserved resources back.
                self._active_channels.discard(channel_id)
                if channel_id in self._socket_closed:
                    self._socket_closed[channel_id] = True
                await self._release_port_for_channel(channel_id)
                # Completes the cleanup now if the channel end was already
                # seen; otherwise on_channel_end will finish it later.
                await self._check_and_cleanup_channel(channel_id)
            return

        # Track channel to worker mapping
        self._channel_to_worker[channel_id] = worker_id
        channel = RedisChannels.worker_events(worker_id)

        # Publish event to specific worker
        await self.redis.publish(channel, event.to_json())
        logger.info(
            f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
        )

        # Start monitoring for commands from workers
        self._tasks[channel_id] = asyncio.create_task(
            self._monitor_channel_commands(channel_id, connection)
        )
    except Exception as e:
        logger.exception(f"Error handling stasis call: {e}")
async def _get_and_allocate_port_atomic(self, channel_id: str) -> int:
    """Reserve a port for ``channel_id`` with a single Redis Lua script.

    The script returns the port already recorded for the channel if there is
    one; otherwise it walks the configured range and records the first port
    that has no entry in ``port_channels``. Because the lookup and the writes
    run inside one script, two concurrent callers cannot receive the same
    port.

    If the range is exhausted, orphaned allocations are cleaned up and the
    script is run once more.

    Args:
        channel_id: Caller channel id the port is reserved for.

    Returns:
        The port number allocated to the channel.

    Raises:
        RuntimeError: If no port is free even after the orphan cleanup.
    """
    # Script text is kept verbatim (it is a runtime string sent to Redis).
    lua_script = """
local port_range_start = tonumber(ARGV[1])
local port_range_end = tonumber(ARGV[2])
local port_range_step = tonumber(ARGV[3])
local channel_id = KEYS[1]
local timestamp = ARGV[4]
-- Check if channel already has a port allocated
local existing_port = redis.call('HGET', 'channel_ports', channel_id)
if existing_port then
return tonumber(existing_port)
end
-- Find first available port
for port = port_range_start, port_range_end, port_range_step do
local port_str = tostring(port)
local exists = redis.call('HEXISTS', 'port_channels', port_str)
if exists == 0 then
-- Atomically allocate the port
redis.call('HSET', 'channel_ports', channel_id, port)
redis.call('HSET', 'port_channels', port_str, channel_id)
redis.call('HSET', 'channel_allocation_time', channel_id, timestamp)
return port
end
end
return -1 -- No ports available
"""
    # KEYS[1] followed by ARGV[1..4]; reused unchanged for the retry.
    script_args = (
        channel_id,  # KEYS[1]
        min(self._port_range),  # ARGV[1]: first port
        max(self._port_range),  # ARGV[2]: last port
        self._port_range.step,  # ARGV[3]: step
        int(time.time()),  # ARGV[4]: allocation timestamp
    )
    key_count = 1

    port = await self.redis.eval(lua_script, key_count, *script_args)
    if port == -1:
        # Range exhausted: reclaim orphaned allocations, then try once more.
        await self._cleanup_orphaned_ports()
        port = await self.redis.eval(lua_script, key_count, *script_args)
        if port == -1:
            raise RuntimeError(
                "No available ports in configured range after cleanup"
            )

    logger.debug(f"Atomically allocated port {port} for channel {channel_id}")
    return port
async def _release_port_for_channel(self, channel_id: str):
    """Release the port recorded for ``channel_id``, if any.

    A Lua script removes the channel's entries from ``channel_ports``,
    ``port_channels`` and ``channel_allocation_time`` in one step, so the
    three hashes cannot be left partially updated.

    Args:
        channel_id: Caller channel id whose port should be released.
    """
    # Script text is kept verbatim (it is a runtime string sent to Redis).
    lua_script = """
local channel_id = KEYS[1]
-- Get the port allocated to this channel
local port = redis.call('HGET', 'channel_ports', channel_id)
if port then
-- Atomically clean up all related entries
redis.call('HDEL', 'channel_ports', channel_id)
redis.call('HDEL', 'port_channels', port)
redis.call('HDEL', 'channel_allocation_time', channel_id)
return port
end
return nil
"""
    released = await self.redis.eval(lua_script, 1, channel_id)
    if not released:
        logger.debug(f"No port was allocated for channel {channel_id}")
        return
    logger.debug(f"Atomically released port {released} for channel {channel_id}")
async def _discover_workers(self):
    """Background loop that refreshes the cached list of ready workers.

    Every 5 seconds, while the manager is running, the worker-id set is read
    from Redis and each worker's record is fetched; workers whose JSON record
    has ``status == "ready"`` make up the new ``_active_workers`` list.
    Errors in one round are logged and the loop continues. Cancellation ends
    the loop quietly.
    """
    try:
        while self._running:
            try:
                member_ids = await self.redis.smembers(RedisKeys.workers_set())

                ready = []
                for raw_id in member_ids:
                    # Ids may arrive as bytes depending on client settings.
                    wid = raw_id.decode() if isinstance(raw_id, bytes) else raw_id

                    record = await self.redis.get(RedisKeys.worker_active(wid))
                    if not record:
                        continue

                    try:
                        info = json.loads(record)
                    except json.JSONDecodeError:
                        logger.warning(f"Invalid worker data for {wid}")
                        continue

                    # Only include workers that are ready (not draining)
                    if info.get("status") == "ready":
                        ready.append(wid)

                # Swap in the new list in a single assignment.
                self._active_workers = ready
                logger.info(f"Discovered {len(ready)} active workers")
            except Exception as exc:
                logger.error(f"Error discovering workers: {exc}")

            # Check every 5 seconds
            await asyncio.sleep(5)
    except asyncio.CancelledError:
        logger.debug("Worker discovery task cancelled")
async def _select_worker(self) -> Optional[str]:
    """Pick the next worker in round-robin order.

    The rotation counter lives in Redis (incremented on every call) so the
    position is kept outside this process.

    Returns:
        A worker id, or ``None`` when no active workers are cached. If the
        Redis increment (or the indexing) fails, the error is logged and the
        first cached worker is returned instead, if there is one.
    """
    if not self._active_workers:
        return None

    try:
        counter = await self.redis.incr(RedisKeys.round_robin_index())
        # INCR starts at 1, so shift to a zero-based slot.
        slot = (counter - 1) % len(self._active_workers)
        return self._active_workers[slot]
    except Exception as exc:
        logger.error(f"Error selecting worker: {exc}")

    # Fallback to first worker if the Redis operation failed
    if self._active_workers:
        return self._active_workers[0]
    return None
def _get_worker_for_channel(self, channel_id: str) -> str:
    """Return the worker id recorded for ``channel_id``.

    Args:
        channel_id: Caller channel id.

    Returns:
        The worker id stored in ``_channel_to_worker``, or an empty string
        when the channel has no recorded worker.
    """
    assignments = self._channel_to_worker
    return assignments[channel_id] if channel_id in assignments else ""
async def _monitor_channel_commands(
    self, channel_id: str, connection: ARIManagerConnection
):
    """Subscribe to this channel's command topic and dispatch worker commands.

    A dedicated pubsub is opened, stored in ``_pubsubs`` so it can be closed
    during cleanup, and every ``message`` entry is parsed and passed to
    ``_handle_worker_command``. Failures while handling a single message are
    logged and the loop keeps listening.

    Args:
        channel_id: Caller channel id whose command topic is monitored.
        connection: Connection the commands act on.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    # TODO: Not sure if its a good idea to monitor command for every channel
    # using pubsub. What happens if there are more number of calls than number
    # of tcp connections redis can support? We can do something similar to
    # Campaign Orchestrator, where we can subscribe to one channel and have
    # commands for every channel there.
    topic = RedisChannels.channel_commands(channel_id)

    try:
        subscriber = self.redis.pubsub()
        await subscriber.subscribe(topic)

        # Remember the pubsub so cleanup can close it.
        self._pubsubs[channel_id] = subscriber
        logger.debug(f"channelID: {channel_id} Monitoring commands for channel")

        async for message in subscriber.listen():
            # Skip subscribe/unsubscribe confirmations and the like.
            if message["type"] != "message":
                continue

            try:
                command = parse_command(message["data"])
                if not command:
                    logger.warning(
                        f"Failed to parse command for {channel_id}: {message['data']}"
                    )
                    continue
                await self._handle_worker_command(channel_id, command, connection)
            except Exception as exc:
                logger.exception(f"Error handling command for {channel_id}: {exc}")
    except asyncio.CancelledError:
        logger.debug(f"channelID: {channel_id} Command monitor cancelled")
        raise  # Re-raise to maintain proper cancellation semantics
    except (ConnectionError, redis.exceptions.ConnectionError):
        # Cleanup closes the pubsub before cancelling this task, so a
        # connection error here is the expected way out; nothing to do.
        pass
    except Exception as exc:
        logger.exception(f"Error in command monitor for {channel_id}: {exc}")
async def _handle_worker_command(
    self,
    channel_id: str,
    command: BaseWorkerToARIManagerCommand,
    connection: ARIManagerConnection,
):
    """Carry out one command sent by a worker for ``channel_id``.

    Supported commands:
      * ``DisconnectCommand``   - disconnect the connection.
      * ``TransferCommand``     - transfer the call to ``command.context``.
      * ``SocketClosedCommand`` - record that the worker's socket is closed,
        release the channel's port and attempt the final cleanup.

    Anything else is logged as an unknown command.

    Args:
        channel_id: Caller channel id the command refers to.
        command: Parsed command object.
        connection: Connection the command acts on.
    """
    if isinstance(command, DisconnectCommand):
        logger.info(f"channelID: {channel_id} Worker requested disconnect")
        await connection.disconnect()
        return

    if isinstance(command, TransferCommand):
        logger.info(f"channelID: {channel_id} Worker requested transfer")
        await connection.transfer(command.context)
        return

    if not isinstance(command, SocketClosedCommand):
        logger.warning(
            f"channelID: {channel_id} Received unknown command: {command}"
        )
        return

    logger.info(f"channelID: {channel_id} Worker notified socket closed")

    # Second half of the two-flag handshake (the first is the channel end).
    if channel_id in self._socket_closed:
        self._socket_closed[channel_id] = True

    # The port is free as soon as the worker's socket is gone.
    await self._release_port_for_channel(channel_id)

    # Finishes the cleanup if the channel end was already seen.
    await self._check_and_cleanup_channel(channel_id)
async def _check_and_cleanup_channel(self, channel_id: str):
    """Release all per-channel resources once both cleanup flags are set.

    Nothing happens until the channel is both disposed (channel end seen)
    and socket-closed (worker reported its socket closed). When both hold,
    the channel is removed from every tracking structure, its pubsub is
    closed and its command-monitor task is stopped.

    This method can be reached from inside the monitor task itself (a
    socket-closed command is handled by that task). In that case the task is
    neither cancelled nor awaited here: a task must not await itself, and
    with the pubsub closed the monitor's listen loop ends on its own.

    Args:
        channel_id: Caller channel id to clean up.
    """
    channel_disposed = self._channel_disposed.get(channel_id, False)
    socket_closed = self._socket_closed.get(channel_id, False)
    logger.debug(
        f"channelID: {channel_id} Check cleanup - disposed: {channel_disposed}, socket_closed: {socket_closed}"
    )
    if not (channel_disposed and socket_closed):
        return

    # Remove from active channels, connections and the worker mapping.
    # The worker mapping must be dropped here as well, otherwise it grows by
    # one entry per call for the lifetime of the process.
    self._active_channels.discard(channel_id)
    self._channel_connections.pop(channel_id, None)
    self._channel_to_worker.pop(channel_id, None)

    # Close pubsub connection first (before cancelling task). Popping up
    # front keeps a second, overlapping cleanup from touching the same entry.
    pubsub = self._pubsubs.pop(channel_id, None)
    if pubsub is not None:
        try:
            command_channel = RedisChannels.channel_commands(channel_id)
            await pubsub.unsubscribe(command_channel)
            await pubsub.aclose()
            logger.debug(
                f"channelID: {channel_id} Closed pubsub connection in cleanup"
            )
        except Exception as e:
            logger.warning(f"Error closing pubsub for {channel_id}: {e}")

    # Stop the command monitor task
    task = self._tasks.pop(channel_id, None)
    if task is not None:
        if task is asyncio.current_task():
            # We are running inside the monitor task: do not cancel/await
            # ourselves, just let the listen loop finish.
            logger.debug(
                f"channelID: {channel_id} Cleanup running inside monitor task; letting it exit on its own"
            )
        elif not task.done():
            # Task is still running, cancel it and wait for it to finish
            task.cancel()
            try:
                await task
                logger.debug(
                    f"channelID: {channel_id} Task completed after cancel"
                )
            except asyncio.CancelledError:
                logger.debug(
                    f"channelID: {channel_id} Task cancelled successfully"
                )
            except Exception as e:
                logger.warning(
                    f"channelID: {channel_id} Task raised exception: {e}"
                )
        else:
            # Task already completed; retrieve its exception (if any) so it
            # is reported instead of being silently dropped.
            logger.debug(
                f"channelID: {channel_id} Monitor task already completed"
            )
            if not task.cancelled():
                e = task.exception()
                if e is not None:
                    logger.warning(
                        f"channelID: {channel_id} Completed task had exception: {e}"
                    )

    # Clean up the flag tracking
    self._channel_disposed.pop(channel_id, None)
    self._socket_closed.pop(channel_id, None)
    logger.info(f"channelID: {channel_id} Completed cleanup of all resources")
async def _cleanup_orphaned_ports(self):
    """Release port allocations in Redis that no longer belong to a live call.

    Two modes, decided once at the start of the sweep:

    * No active channels (e.g. right after startup): every recorded
      allocation is treated as orphaned and released.
    * Otherwise: only allocations for channels this instance is not tracking
      are considered, and only when they carry an allocation time older than
      5 minutes. Those older than 1 hour are logged as stale, the rest as
      orphaned.

    All errors are logged and swallowed.
    """
    try:
        channel_ports = await self.redis.hgetall("channel_ports")
        if not channel_ports:
            return

        logger.info(
            f"Found {len(channel_ports)} existing port allocations, checking for orphans..."
        )

        now = int(time.time())
        stale_after = 3600  # 1 hour
        orphan_after = 300  # 5 minutes
        cleaned = 0

        # With nothing tracked, this is a fresh instance: any existing
        # allocation must be left over from a previous run.
        release_everything = not self._active_channels

        for channel_id, port in channel_ports.items():
            if release_everything:
                allocated_at = await self.redis.hget(
                    "channel_allocation_time", channel_id
                )
                age_str = ""
                if allocated_at:
                    age = now - int(allocated_at)
                    age_str = f" (aged {age}s)"
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Cleaned up orphaned port {port} for channel {channel_id}{age_str}"
                )
                cleaned += 1
                continue

            # During runtime, leave channels we are tracking alone.
            if channel_id in self._active_channels:
                continue

            allocated_at = await self.redis.hget(
                "channel_allocation_time", channel_id
            )
            if not allocated_at:
                # No timestamp recorded: cannot judge the age, keep it.
                continue

            age = now - int(allocated_at)
            if age > stale_after:
                # Too old, clean up regardless
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Cleaned up stale port {port} for channel {channel_id} (aged {age}s)"
                )
                cleaned += 1
            elif age > orphan_after:
                # Untracked and reasonably old: most likely orphaned.
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Cleaned up orphaned port {port} for untracked channel {channel_id}"
                )
                cleaned += 1

        if cleaned > 0:
            logger.info(f"Cleaned up {cleaned} orphaned port allocations")
    except Exception as exc:
        logger.exception(f"Error during orphaned port cleanup: {exc}")
async def _periodic_cleanup(self):
    """Background loop that runs the orphaned-port cleanup every 30 minutes.

    The loop ends when the manager stops running or the task is cancelled;
    any other error is logged and the loop carries on.
    """
    interval_seconds = 30 * 60

    while self._running:
        try:
            await asyncio.sleep(interval_seconds)
            # The manager may have been stopped while we were sleeping.
            if not self._running:
                continue
            logger.info("Running periodic orphaned port cleanup...")
            await self._cleanup_orphaned_ports()
        except asyncio.CancelledError:
            logger.debug("Periodic cleanup task cancelled")
            return
        except Exception as exc:
            logger.exception(f"Error in periodic cleanup: {exc}")
async def run(self):
    """Main run loop for ARI Manager.

    Sets up the ARI client supervisor with this manager's channel callbacks,
    starts worker discovery, clears orphaned port allocations, starts the
    periodic cleanup task and then idles until ``_running`` becomes False.

    On every exit path (normal stop, error, or cancellation of this
    coroutine) the periodic cleanup task is cancelled and the supervisor is
    closed. Exceptions derived from ``Exception`` are logged, not raised.
    """
    self._running = True
    # Created further down; tracked here so ``finally`` can always stop it.
    cleanup_task: Optional[asyncio.Task] = None

    # Setup ARI connection with supervisor
    try:
        self._ari_client_supervisor = await setup_ari_client_supervisor(
            self.on_channel_start, self.on_channel_end
        )
        if not self._ari_client_supervisor:
            logger.error("Failed to setup ARI connection")
            return

        # Start worker discovery task
        self._worker_discovery_task = asyncio.create_task(self._discover_workers())

        # Wait a moment for initial worker discovery
        await asyncio.sleep(1)
        logger.info(
            f"ARI Manager started with {len(self._active_workers)} active workers"
        )

        # Clean up any orphaned ports from previous runs
        await self._cleanup_orphaned_ports()

        # Start periodic cleanup task
        cleanup_task = asyncio.create_task(self._periodic_cleanup())

        # Keep running until shutdown
        while self._running:
            await asyncio.sleep(1)
        logger.debug("ARIManager._running is false. Will cleanup and shutdown")
    except Exception as e:
        logger.exception(f"ARI Manager error: {e}")
    finally:
        # Stop the periodic cleanup task here rather than only after the
        # loop, so it is not left running when run() is cancelled or fails.
        if cleanup_task is not None:
            cleanup_task.cancel()
            try:
                await cleanup_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.warning(f"Periodic cleanup task ended with error: {e}")

        if self._ari_client_supervisor:
            await self._ari_client_supervisor.close()
        logger.info("ARI Manager stopped")
async def shutdown(self):
    """Graceful shutdown.

    Order of operations:
      1. Close the ARI supervisor (prevents reconnection attempts).
      2. Cancel the worker discovery task.
      3. Clear the running flag so the loops in ``run`` and the periodic
         cleanup stop.
      4. Release the ports of all active channels and clear the flag maps.
      5. Close every per-channel pubsub, then cancel and await the
         command-monitor tasks.
    """
    logger.info("Shutting down ARI Manager...")

    # Close supervisor first to prevent reconnection attempts
    if self._ari_client_supervisor:
        await self._ari_client_supervisor.close()

    # Cancel worker discovery task
    if self._worker_discovery_task:
        self._worker_discovery_task.cancel()
        try:
            await self._worker_discovery_task
        except asyncio.CancelledError:
            pass
        self._worker_discovery_task = None

    # Now set running to False
    self._running = False

    # Clean up all active channel ports before shutting down
    if self._active_channels:
        logger.info(f"Cleaning up {len(self._active_channels)} active channels...")
        # Iterate over a copy to avoid modification during iteration
        for channel_id in list(self._active_channels):
            await self._release_port_for_channel(channel_id)
            logger.info(
                f"Released port for active channel {channel_id} during shutdown"
            )
        self._active_channels.clear()

    # Clear flag tracking
    self._channel_disposed.clear()
    self._socket_closed.clear()

    # Close the per-channel pubsub connections (before cancelling the tasks,
    # same order as the per-channel cleanup). Best effort: a failure on one
    # channel must not stop the rest of the shutdown.
    for channel_id, pubsub in list(self._pubsubs.items()):
        try:
            await pubsub.unsubscribe(RedisChannels.channel_commands(channel_id))
            await pubsub.aclose()
        except Exception as e:
            logger.warning(f"Error closing pubsub for {channel_id}: {e}")
    self._pubsubs.clear()

    # Cancel all monitoring tasks and wait for them to complete
    monitor_tasks = list(self._tasks.values())
    for task in monitor_tasks:
        task.cancel()
    if monitor_tasks:
        await asyncio.gather(*monitor_tasks, return_exceptions=True)
    self._tasks.clear()
async def main():
    """Main entry point for ARI Manager service.

    Connects to Redis, starts the manager and waits until either the manager
    stops on its own or SIGTERM/SIGINT is received, in which case a graceful
    shutdown is performed. The Redis client and the logging queue listener
    are always closed on the way out.
    """
    # Setup Redis connection. The local is deliberately not called ``redis``
    # so it does not shadow the imported ``redis`` module.
    redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)

    # Create and run manager
    manager = ARIManager(redis_client)

    # Create a shutdown event for clean coordination
    shutdown_event = asyncio.Event()

    # Setup signal handlers on the loop that is running this coroutine
    loop = asyncio.get_running_loop()

    def signal_handler(signum):
        logger.info(f"Received shutdown signal {signum}")
        # Set the shutdown event which will trigger shutdown
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        # ``s=sig`` binds the current signal; a plain closure would see only
        # the last value of the loop variable.
        loop.add_signal_handler(sig, lambda s=sig: signal_handler(s))

    # Run manager with shutdown monitoring
    manager_task = asyncio.create_task(manager.run())
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    try:
        # Wait for either normal completion or shutdown signal
        done, pending = await asyncio.wait(
            [manager_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
        )

        # If shutdown was triggered, perform graceful shutdown
        if shutdown_task in done:
            await manager.shutdown()

        # Cancel the manager task if still running
        if manager_task in pending:
            manager_task.cancel()
            try:
                await manager_task
            except asyncio.CancelledError:
                pass
    finally:
        # If the manager stopped by itself the signal waiter is still
        # pending; cancel it so no task is left dangling.
        if not shutdown_task.done():
            shutdown_task.cancel()
            try:
                await shutdown_task
            except asyncio.CancelledError:
                pass

        try:
            await redis_client.aclose()
        finally:
            # --- Ensure Axiom logging listener is stopped gracefully ---
            # Runs even if closing the Redis client raised.
            if logging_queue_listener is not None:
                logging_queue_listener.stop()
# Script entry point: only runs when this module is executed directly.
if __name__ == "__main__":
    # Configure logging: add a file sink at logs/ari_manager.log with a
    # "10 MB" rotation setting, in addition to whatever setup_logging()
    # configured at import time.
    logger.add("logs/ari_manager.log", rotation="10 MB")
    # Start the event loop and run the service until main() returns.
    asyncio.run(main())