dograh/api/services/telephony/ari_manager.py

749 lines
29 KiB
Python
Raw Permalink Normal View History

2025-09-09 14:37:32 +05:30
"""Standalone ARI Manager Service for distributed architecture.
This service maintains the single WebSocket connection to Asterisk ARI
and distributes events to multiple FastAPI workers via Redis pub/sub.
ARIManager creates an ARIClientSupervisor and registers the callbacks
on_channel_start and on_channel_end with it. It is responsible for taking in the
caller channel and setting up an ARIManagerConnection, i.e. creating the bridge
for externalMedia.
"""
import asyncio
import json
import os
import signal
import time
from typing import Dict, Optional
from api.constants import REDIS_URL
# --- Add logging setup before importing loguru ---
from api.logging_config import setup_logging
from api.services.telephony.stasis_event_protocol import (
BaseWorkerToARIManagerCommand,
DisconnectCommand,
RedisChannels,
RedisKeys,
SocketClosedCommand,
TransferCommand,
parse_command,
)
# Logging is set up here, before loguru is imported below, as required by the
# "Add logging setup before importing loguru" note above. main() stops the
# returned listener on exit when it is not None.
logging_queue_listener = setup_logging()
import redis.asyncio as aioredis
import redis.exceptions
from loguru import logger
from pipecat.utils.enums import EndTaskReason
from api.services.telephony.ari_client import Channel
from api.services.telephony.ari_client_manager import (
ARIClientManager,
setup_ari_client_supervisor,
)
from api.services.telephony.ari_manager_connection import ARIManagerConnection
class ARIManager:
    """Manages ARI connection and distributes events to workers via Redis.

    For every caller channel reported by the ARI client supervisor, the
    manager allocates a port atomically in Redis and sets up an
    ``ARIManagerConnection`` (bridge + externalMedia) with it. It then hands
    the call to one ready worker, chosen round-robin, by publishing a
    ``stasis_start`` event on that worker's events channel. Worker commands
    for the call arrive on a per-channel Redis pub/sub channel.

    A channel's bookkeeping is released only once BOTH of these happened: the
    ARI channel ended ("disposed") and the worker reported its socket closed
    (or no worker was ever assigned to the call).
    """

    # Atomically return the port already owned by KEYS[1] (the channel id), or
    # claim the first free port from ARGV[1] to ARGV[2] (inclusive) stepping by
    # ARGV[3], recording ARGV[4] as the allocation time. Returns -1 when every
    # port in the range is taken.
    _ALLOCATE_PORT_LUA = """
        local port_range_start = tonumber(ARGV[1])
        local port_range_end = tonumber(ARGV[2])
        local port_range_step = tonumber(ARGV[3])
        local channel_id = KEYS[1]
        local timestamp = ARGV[4]
        -- Check if channel already has a port allocated
        local existing_port = redis.call('HGET', 'channel_ports', channel_id)
        if existing_port then
            return tonumber(existing_port)
        end
        -- Find first available port
        for port = port_range_start, port_range_end, port_range_step do
            local port_str = tostring(port)
            local exists = redis.call('HEXISTS', 'port_channels', port_str)
            if exists == 0 then
                -- Atomically allocate the port
                redis.call('HSET', 'channel_ports', channel_id, port)
                redis.call('HSET', 'port_channels', port_str, channel_id)
                redis.call('HSET', 'channel_allocation_time', channel_id, timestamp)
                return port
            end
        end
        return -1 -- No ports available
    """

    # Atomically remove every record of the port owned by KEYS[1]. Returns the
    # released port, or nil when the channel owned none.
    _RELEASE_PORT_LUA = """
        local channel_id = KEYS[1]
        -- Get the port allocated to this channel
        local port = redis.call('HGET', 'channel_ports', channel_id)
        if port then
            -- Atomically clean up all related entries
            redis.call('HDEL', 'channel_ports', channel_id)
            redis.call('HDEL', 'port_channels', port)
            redis.call('HDEL', 'channel_allocation_time', channel_id)
            return port
        end
        return nil
    """

    def __init__(self, redis_client: aioredis.Redis):
        self.redis = redis_client
        self.stasis_manager: Optional[ARIClientManager] = None
        self._running = False
        self._ari_client_supervisor = None
        # Command-monitor task per caller channel id
        self._tasks: Dict[str, asyncio.Task] = {}
        # Command pub/sub connection per caller channel id
        self._pubsubs: Dict[str, aioredis.client.PubSub] = {}
        # Channels whose call setup succeeded on this instance
        self._active_channels: set[str] = set()
        self._port_range = range(4000, 5000, 2)  # Even ports only
        # Connections by caller channel id (tracked from port allocation on)
        self._channel_connections: Dict[str, ARIManagerConnection] = {}
        # Cleanup flags; a channel is cleaned up when both are True
        self._channel_disposed: Dict[str, bool] = {}
        self._socket_closed: Dict[str, bool] = {}
        # Cached list of ready workers, refreshed by _discover_workers
        self._active_workers: list[str] = []
        self._worker_discovery_task: Optional[asyncio.Task] = None
        # Worker assigned to each caller channel
        self._channel_to_worker: Dict[str, str] = {}

    async def on_channel_start(self, caller_channel: Channel, call_context_vars: dict):
        """Handle a new caller channel reported by ARIClientManager.

        Allocates a port, creates and tracks an ``ARIManagerConnection`` and
        hands the call to a worker. If the call could not be handed to any
        worker, the channel is marked so that it is fully cleaned up as soon as
        the channel ends (no worker will ever report its socket closed).

        Args:
            caller_channel: The caller's ARI channel.
            call_context_vars: Context forwarded verbatim to the worker.
        """
        channel_id = caller_channel.id
        try:
            # Atomically allocate port for this channel (prevents race conditions)
            port = await self._get_and_allocate_port_atomic(channel_id)
            connection = ARIManagerConnection(
                caller_channel=caller_channel,
                host=os.getenv("ARI_STASIS_APP_ENDPOINT"),
                port=port,
            )
            self._channel_connections[channel_id] = connection
            # Both flags must become True before the channel is cleaned up
            self._channel_disposed[channel_id] = False
            self._socket_closed[channel_id] = False
        except Exception as e:
            logger.exception(f"Error handling new channel {channel_id}: {e}")
            # Release port if allocation failed
            await self._release_port_for_channel(channel_id)
            return

        if await self._on_stasis_call(connection, call_context_vars):
            return
        try:
            await self._abandon_channel(channel_id)
        except Exception as e:
            logger.exception(
                f"channelID: {channel_id} Error releasing unassigned channel: {e}"
            )

    async def on_channel_end(self, channel_id: str):
        """Handle channel end notification from ARIClientManager.

        For a tracked caller channel, this does the following:
        - publishes ``stasis_end`` to the assigned worker, if there is one;
        - notifies the connection;
        - marks the channel disposed and attempts cleanup.

        The flag and cleanup steps run even if publishing or notifying fails.
        Ends of ExternalMedia channels are ignored.
        """
        logger.info(f"channelID: {channel_id} Received channel end notification")
        connection = self._channel_connections.get(channel_id)
        if connection is None:
            # TODO: We are currently not handling StasisEnd on ExternalMedia
            if any(
                conn.em_channel_id and conn.em_channel_id == channel_id
                for conn in self._channel_connections.values()
            ):
                logger.debug(
                    f"channelID: {channel_id} ExternalMedia StasisEnd - Ignoring"
                )
            return

        try:
            worker_id = self._get_worker_for_channel(channel_id)
            if worker_id:
                event = {
                    "type": "stasis_end",
                    "channel_id": channel_id,
                    "reason": EndTaskReason.USER_HANGUP.value,
                }
                await self.redis.publish(
                    RedisChannels.worker_events(worker_id), json.dumps(event)
                )
                logger.info(f"channelID: {channel_id} Published StasisEnd event")
            else:
                # No worker owns this call; publishing would target a bogus channel
                logger.debug(
                    f"channelID: {channel_id} No worker assigned, skipping StasisEnd publish"
                )
            await connection.notify_channel_end()
        finally:
            if channel_id in self._channel_disposed:
                self._channel_disposed[channel_id] = True
            await self._check_and_cleanup_channel(channel_id)

    async def _on_stasis_call(
        self, connection: ARIManagerConnection, call_context_vars: dict
    ) -> bool:
        """Set up the bridge for a call and hand it to a worker.

        Returns:
            True if a worker was assigned and ``stasis_start`` was published;
            False if setup failed, the connection is not connected, or no
            worker is available. Exceptions are logged, not raised.
        """
        try:
            # Setup the connection (create bridge and external media)
            await connection.setup_call()
            if not connection.is_connected():
                logger.warning("Connection is not connected, skipping")
                return False

            channel_id = connection.caller_channel_id
            self._active_channels.add(channel_id)

            event = {
                "type": "stasis_start",
                "channel_id": channel_id,
                "caller_channel_id": channel_id,
                "em_channel_id": connection.em_channel_id,
                "bridge_id": connection.bridge_id,
                "local_addr": list(connection.local_addr),
                "remote_addr": list(connection.remote_addr),
                "call_context_vars": call_context_vars,
            }

            worker_id = await self._select_worker()
            if worker_id is None:
                logger.error(f"channelID: {channel_id} No active workers available")
                await connection.disconnect()
                return False

            self._channel_to_worker[channel_id] = worker_id
            # Subscribe BEFORE publishing: Redis pub/sub drops messages that
            # have no subscriber, so a command the worker sends right after
            # receiving stasis_start would otherwise be lost.
            pubsub = await self._subscribe_channel_commands(channel_id)

            await self.redis.publish(
                RedisChannels.worker_events(worker_id), json.dumps(event)
            )
            logger.info(
                f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
            )

            self._tasks[channel_id] = asyncio.create_task(
                self._monitor_channel_commands(channel_id, connection, pubsub)
            )
            return True
        except Exception as e:
            logger.exception(f"Error handling stasis call: {e}")
            return False

    async def _abandon_channel(self, channel_id: str):
        """Release resources of a channel that no worker will ever serve.

        Sets the socket-closed flag on the (absent) worker's behalf, frees the
        port and runs the cleanup check. Cleanup completes immediately if the
        channel already ended, otherwise when ``on_channel_end`` runs.
        """
        if channel_id in self._socket_closed:
            self._socket_closed[channel_id] = True
        await self._release_port_for_channel(channel_id)
        await self._check_and_cleanup_channel(channel_id)

    async def _get_and_allocate_port_atomic(self, channel_id: str) -> int:
        """Atomically find and allocate an available port using a Redis Lua script.

        The script executes atomically in Redis, so two concurrent calls can
        never be handed the same port. If the channel already owns a port, that
        port is returned. When the range is exhausted, orphaned allocations are
        cleaned up once and the allocation is retried.

        Args:
            channel_id: Caller channel the port is allocated for.

        Returns:
            The allocated port.

        Raises:
            RuntimeError: If no port is available even after orphan cleanup.
        """
        script_args = (
            self._port_range[0],  # ARGV[1]: first port
            self._port_range[-1],  # ARGV[2]: last port (Lua loop is inclusive)
            self._port_range.step,  # ARGV[3]
            int(time.time()),  # ARGV[4]: allocation timestamp
        )

        async def try_allocate() -> int:
            return await self.redis.eval(
                self._ALLOCATE_PORT_LUA, 1, channel_id, *script_args
            )

        port = await try_allocate()
        if port == -1:
            # If all ports exhausted, clean up orphaned ports and retry
            await self._cleanup_orphaned_ports()
            port = await try_allocate()
            if port == -1:
                raise RuntimeError(
                    "No available ports in configured range after cleanup"
                )
        logger.debug(f"Atomically allocated port {port} for channel {channel_id}")
        return port

    async def _release_port_for_channel(self, channel_id: str):
        """Atomically release the port held by ``channel_id``, if any.

        All related hash entries are removed in one Lua script, so a release
        is never partially applied. Releasing a channel without a port is a
        no-op.
        """
        port = await self.redis.eval(self._RELEASE_PORT_LUA, 1, channel_id)
        if port:
            logger.debug(f"Atomically released port {port} for channel {channel_id}")
        else:
            logger.debug(f"No port was allocated for channel {channel_id}")

    async def _discover_workers(self):
        """Refresh the cached list of ready workers from Redis every 5 seconds.

        Only workers whose active-key JSON has ``"status": "ready"`` are
        included. Runs until ``self._running`` is False or the task is
        cancelled.
        """
        try:
            while self._running:
                try:
                    worker_ids = await self.redis.smembers(RedisKeys.workers_set())
                    active_workers = []
                    for raw_id in worker_ids:
                        worker_id = (
                            raw_id.decode() if isinstance(raw_id, bytes) else raw_id
                        )
                        worker_data = await self.redis.get(
                            RedisKeys.worker_active(worker_id)
                        )
                        if not worker_data:
                            continue
                        try:
                            data = json.loads(worker_data)
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid worker data for {worker_id}")
                            continue
                        # Only include workers that are ready (not draining)
                        if data.get("status") == "ready":
                            active_workers.append(worker_id)
                    # Replace the cached list in one assignment
                    self._active_workers = active_workers
                    logger.info(f"Discovered {len(active_workers)} active workers")
                except Exception as e:
                    logger.error(f"Error discovering workers: {e}")
                await asyncio.sleep(5)
        except asyncio.CancelledError:
            logger.debug("Worker discovery task cancelled")

    async def _select_worker(self) -> Optional[str]:
        """Select a ready worker round-robin, or return None if there is none.

        The round-robin counter lives in Redis so it survives restarts. If the
        Redis increment fails, the first cached worker is returned.
        """
        if not self._active_workers:
            return None
        try:
            index = await self.redis.incr(RedisKeys.round_robin_index())
        except Exception as e:
            logger.error(f"Error selecting worker: {e}")
            # Fallback to first worker if Redis operation fails
            return self._active_workers[0] if self._active_workers else None
        # Re-read after the await: discovery may have replaced or emptied the list
        workers = self._active_workers
        if not workers:
            return None
        return workers[(index - 1) % len(workers)]

    def _get_worker_for_channel(self, channel_id: str) -> str:
        """Return the worker assigned to ``channel_id``, or "" if none."""
        return self._channel_to_worker.get(channel_id, "")

    async def _subscribe_channel_commands(
        self, channel_id: str
    ) -> aioredis.client.PubSub:
        """Subscribe to the command channel of ``channel_id`` and track the pubsub."""
        pubsub = self.redis.pubsub()
        # Register before subscribing, so cleanup/shutdown can close the
        # connection even if subscribe() fails.
        self._pubsubs[channel_id] = pubsub
        await pubsub.subscribe(RedisChannels.channel_commands(channel_id))
        return pubsub

    async def _monitor_channel_commands(
        self,
        channel_id: str,
        connection: ARIManagerConnection,
        pubsub: Optional[aioredis.client.PubSub] = None,
    ):
        """Listen for commands from workers for this channel.

        Args:
            channel_id: Caller channel whose command channel is consumed.
            connection: Connection the commands are applied to.
            pubsub: An already-subscribed pubsub. If omitted, one is created
                and subscribed here.

        The loop ends when cleanup closes the pubsub or the task is cancelled.
        """
        # TODO: Not sure if its a good idea to monitor command for every channel
        # using pubsub. What happens if there are more number of calls than number
        # of tcp connections redis can support? We can do something similar to
        # Campaign Orchestrator, where we can subscribe to one channel and have
        # commands for every channel there.
        try:
            if pubsub is None:
                pubsub = await self._subscribe_channel_commands(channel_id)
            logger.debug(f"channelID: {channel_id} Monitoring commands for channel")
            async for message in pubsub.listen():
                if message["type"] != "message":
                    continue  # e.g. subscribe confirmations
                try:
                    command = parse_command(message["data"])
                    if command:
                        await self._handle_worker_command(
                            channel_id, command, connection
                        )
                    else:
                        logger.warning(
                            f"Failed to parse command for {channel_id}: {message['data']}"
                        )
                except Exception as e:
                    logger.exception(f"Error handling command for {channel_id}: {e}")
        except asyncio.CancelledError:
            logger.debug(f"channelID: {channel_id} Command monitor cancelled")
            raise  # Re-raise to maintain proper cancellation semantics
        except (ConnectionError, redis.exceptions.ConnectionError):
            # Cleanup closes the pubsub before cancelling this task, so the
            # code flow is expected to arrive here.
            pass
        except Exception as e:
            logger.exception(f"Error in command monitor for {channel_id}: {e}")

    async def _handle_worker_command(
        self,
        channel_id: str,
        command: BaseWorkerToARIManagerCommand,
        connection: ARIManagerConnection,
    ):
        """Execute a command received from the worker serving ``channel_id``."""
        if isinstance(command, DisconnectCommand):
            logger.info(
                f"channelID: {channel_id} Worker requested disconnect: {command.reason}"
            )
            await connection.disconnect(command.reason)
        elif isinstance(command, TransferCommand):
            logger.info(f"channelID: {channel_id} Worker requested transfer")
            await connection.transfer(command.context)
        elif isinstance(command, SocketClosedCommand):
            logger.info(f"channelID: {channel_id} Worker notified socket closed")
            if channel_id in self._socket_closed:
                self._socket_closed[channel_id] = True
            # Release port immediately; the worker no longer uses it
            await self._release_port_for_channel(channel_id)
            await self._check_and_cleanup_channel(channel_id)
        else:
            logger.warning(
                f"channelID: {channel_id} Received unknown command: {command}"
            )

    async def _check_and_cleanup_channel(self, channel_id: str):
        """Release all per-channel state once the channel is disposed AND its socket closed.

        May be called from the channel's own monitor task (via
        SocketClosedCommand), so the monitor is never awaited from itself.
        """
        channel_disposed = self._channel_disposed.get(channel_id, False)
        socket_closed = self._socket_closed.get(channel_id, False)
        logger.debug(
            f"channelID: {channel_id} Check cleanup - disposed: {channel_disposed}, socket_closed: {socket_closed}"
        )
        if not (channel_disposed and socket_closed):
            return

        # Drop the flags first, so a concurrent check cannot start a second cleanup
        self._channel_disposed.pop(channel_id, None)
        self._socket_closed.pop(channel_id, None)
        self._active_channels.discard(channel_id)
        self._channel_connections.pop(channel_id, None)
        self._channel_to_worker.pop(channel_id, None)

        # Close pubsub first (before cancelling the task) so its listen loop ends
        await self._close_channel_pubsub(channel_id)
        await self._stop_channel_task(channel_id)
        logger.info(f"channelID: {channel_id} Completed cleanup of all resources")

    async def _close_channel_pubsub(self, channel_id: str):
        """Unsubscribe and close the command pubsub of ``channel_id`` (best effort)."""
        pubsub = self._pubsubs.pop(channel_id, None)
        if pubsub is None:
            return
        try:
            try:
                await pubsub.unsubscribe(RedisChannels.channel_commands(channel_id))
            finally:
                # Always release the connection, even if unsubscribe failed
                await pubsub.aclose()
            logger.debug(f"channelID: {channel_id} Closed pubsub connection in cleanup")
        except Exception as e:
            logger.warning(f"Error closing pubsub for {channel_id}: {e}")

    async def _stop_channel_task(self, channel_id: str):
        """Stop and forget the command-monitor task of ``channel_id``."""
        task = self._tasks.pop(channel_id, None)
        if task is None:
            return
        if task is asyncio.current_task():
            # Called from inside the monitor itself: cancelling and awaiting it
            # would make the task await itself. Its listen loop ends on its own
            # now that the pubsub is closed.
            logger.debug(f"channelID: {channel_id} Monitor task will exit on its own")
            return
        if not task.done():
            task.cancel()
            try:
                await task
                logger.debug(f"channelID: {channel_id} Task completed after cancel")
            except asyncio.CancelledError:
                logger.debug(f"channelID: {channel_id} Task cancelled successfully")
            except Exception as e:
                logger.warning(f"channelID: {channel_id} Task raised exception: {e}")
            return
        logger.debug(f"channelID: {channel_id} Monitor task already completed")
        # Awaiting a cancelled task would raise CancelledError; inspect it instead
        if not task.cancelled() and task.exception() is not None:
            logger.warning(
                f"channelID: {channel_id} Completed task had exception: {task.exception()}"
            )

    async def _cleanup_orphaned_ports(self):
        """Release port allocations left behind by channels this instance no longer tracks.

        When no channel is tracked at all (e.g. right after startup), every
        allocation in Redis is treated as orphaned and released. Otherwise only
        untracked allocations older than 5 minutes are released. Errors are
        logged, not raised.
        """
        try:
            channel_ports = await self.redis.hgetall("channel_ports")
            if not channel_ports:
                return
            logger.info(
                f"Found {len(channel_ports)} existing port allocations, checking for orphans..."
            )
            cleaned = 0
            current_time = int(time.time())
            max_age_seconds = 3600  # Older untracked allocations are logged as "stale"
            orphan_age_seconds = 300  # Untracked allocations younger than this are kept
            # Also count channels still in setup: they are in
            # _channel_connections before they reach _active_channels.
            tracked = self._active_channels.union(self._channel_connections)

            if not tracked:
                for channel_id, port in channel_ports.items():
                    allocation_time = await self.redis.hget(
                        "channel_allocation_time", channel_id
                    )
                    age_str = ""
                    if allocation_time:
                        age_str = f" (aged {current_time - int(allocation_time)}s)"
                    await self._release_port_for_channel(channel_id)
                    logger.info(
                        f"Cleaned up orphaned port {port} for channel {channel_id}{age_str}"
                    )
                    cleaned += 1
            else:
                for channel_id, port in channel_ports.items():
                    if channel_id in tracked:
                        continue
                    allocation_time = await self.redis.hget(
                        "channel_allocation_time", channel_id
                    )
                    if not allocation_time:
                        continue
                    age = current_time - int(allocation_time)
                    if age > max_age_seconds:
                        message = f"Cleaned up stale port {port} for channel {channel_id} (aged {age}s)"
                    elif age > orphan_age_seconds:
                        message = f"Cleaned up orphaned port {port} for untracked channel {channel_id}"
                    else:
                        continue
                    await self._release_port_for_channel(channel_id)
                    logger.info(message)
                    cleaned += 1

            if cleaned > 0:
                logger.info(f"Cleaned up {cleaned} orphaned port allocations")
        except Exception as e:
            logger.exception(f"Error during orphaned port cleanup: {e}")

    async def _periodic_cleanup(self):
        """Run orphaned-port cleanup every 30 minutes while the manager is running."""
        cleanup_interval = 1800  # 30 minutes
        while self._running:
            try:
                await asyncio.sleep(cleanup_interval)
                if self._running:  # Check again after sleep
                    logger.info("Running periodic orphaned port cleanup...")
                    await self._cleanup_orphaned_ports()
            except asyncio.CancelledError:
                logger.debug("Periodic cleanup task cancelled")
                break
            except Exception as e:
                logger.exception(f"Error in periodic cleanup: {e}")

    async def _close_supervisor(self):
        """Close the ARI client supervisor once; later calls are no-ops."""
        supervisor, self._ari_client_supervisor = self._ari_client_supervisor, None
        if supervisor:
            await supervisor.close()

    async def run(self):
        """Main run loop for ARI Manager.

        Connects to ARI, starts worker discovery and periodic port cleanup, and
        loops until ``self._running`` becomes False. The periodic task and the
        supervisor are released in ``finally``. That block also runs when
        main() cancels this task after ``shutdown()``.
        """
        self._running = True
        cleanup_task: Optional[asyncio.Task] = None
        try:
            self._ari_client_supervisor = await setup_ari_client_supervisor(
                self.on_channel_start, self.on_channel_end
            )
            if not self._ari_client_supervisor:
                logger.error("Failed to setup ARI connection")
                return

            self._worker_discovery_task = asyncio.create_task(self._discover_workers())
            # Wait a moment for initial worker discovery
            await asyncio.sleep(1)
            logger.info(
                f"ARI Manager started with {len(self._active_workers)} active workers"
            )

            # Clean up any orphaned ports from previous runs
            await self._cleanup_orphaned_ports()
            cleanup_task = asyncio.create_task(self._periodic_cleanup())

            while self._running:
                await asyncio.sleep(1)
            logger.debug("ARIManager._running is false. Will cleanup and shutdown")
        except Exception as e:
            logger.exception(f"ARI Manager error: {e}")
        finally:
            if cleanup_task is not None:
                cleanup_task.cancel()
                try:
                    await cleanup_task
                except asyncio.CancelledError:
                    pass
            await self._close_supervisor()
            logger.info("ARI Manager stopped")

    async def shutdown(self):
        """Graceful shutdown: stop ARI and discovery, free ports, stop monitors."""
        logger.info("Shutting down ARI Manager...")
        # Close supervisor first to prevent reconnection attempts
        await self._close_supervisor()

        if self._worker_discovery_task:
            self._worker_discovery_task.cancel()
            try:
                await self._worker_discovery_task
            except asyncio.CancelledError:
                pass
            self._worker_discovery_task = None

        # Now set running to False
        self._running = False

        if self._active_channels:
            logger.info(f"Cleaning up {len(self._active_channels)} active channels...")
            for channel_id in list(self._active_channels):
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Released port for active channel {channel_id} during shutdown"
                )
            self._active_channels.clear()

        self._channel_disposed.clear()
        self._socket_closed.clear()

        tasks = list(self._tasks.values())
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        self._tasks.clear()

        # Cancelled monitors do not close their pubsubs; release them here
        for channel_id in list(self._pubsubs):
            await self._close_channel_pubsub(channel_id)
async def main():
    """Main entry point for ARI Manager service.

    Connects to Redis and runs ``ARIManager.run()``. On SIGTERM/SIGINT it
    performs a graceful ``shutdown()`` and then cancels the run task if that
    task is still running. Redis and the logging queue listener are always
    closed on exit.
    """
    # Setup Redis connection
    redis = await aioredis.from_url(REDIS_URL, decode_responses=True)
    manager = ARIManager(redis)

    # Set by the signal handler; lets main() coordinate a clean shutdown
    shutdown_event = asyncio.Event()

    def signal_handler(signum):
        logger.info(f"Received shutdown signal {signum}")
        shutdown_event.set()

    # Inside a coroutine, get_running_loop() is the supported way to reach the
    # loop; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler, sig)

    manager_task = asyncio.create_task(manager.run())
    shutdown_task = asyncio.create_task(shutdown_event.wait())
    try:
        # Wait for either normal completion or shutdown signal
        done, _pending = await asyncio.wait(
            [manager_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
        )
        if shutdown_task in done:
            await manager.shutdown()
        # run() may already have returned during shutdown(); cancel only if not
        if not manager_task.done():
            manager_task.cancel()
            try:
                await manager_task
            except asyncio.CancelledError:
                pass
    finally:
        # If run() returned on its own, the shutdown waiter is still pending
        if not shutdown_task.done():
            shutdown_task.cancel()
        await redis.aclose()
        # --- Ensure Axiom logging listener is stopped gracefully ---
        if logging_queue_listener is not None:
            logging_queue_listener.stop()
if __name__ == "__main__":
    # Mirror log output into a size-rotated file for this standalone service,
    # then run the service until it exits.
    log_file = "logs/ari_manager.log"
    logger.add(log_file, rotation="10 MB")
    asyncio.run(main())