mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
748 lines
29 KiB
Python
748 lines
29 KiB
Python
"""Standalone ARI Manager Service for distributed architecture.
|
|
|
|
This service maintains the single WebSocket connection to Asterisk ARI
|
|
and distributes events to multiple FastAPI workers via Redis pub/sub.
|
|
|
|
ARIManager creates an instance of ARIClientSupervisor and registers the callbacks
|
|
on_channel_start and on_channel_end. It is responsible for taking in the caller channel
and setting up an ARIManagerConnection, i.e. creating the bridge for externalMedia.
|
|
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import signal
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
from api.constants import REDIS_URL
|
|
|
|
# --- Add logging setup before importing loguru ---
|
|
from api.logging_config import setup_logging
|
|
from api.services.telephony.stasis_event_protocol import (
|
|
BaseWorkerToARIManagerCommand,
|
|
DisconnectCommand,
|
|
RedisChannels,
|
|
RedisKeys,
|
|
SocketClosedCommand,
|
|
StasisEndEvent,
|
|
StasisStartEvent,
|
|
TransferCommand,
|
|
parse_command,
|
|
)
|
|
|
|
# Logging is configured here, before `loguru` is imported below (see the
# "Add logging setup before importing loguru" note above). The returned
# listener may be None; main() stops it on exit when it is not None.
logging_queue_listener = setup_logging()
|
|
|
|
import redis.asyncio as aioredis
|
|
import redis.exceptions
|
|
from loguru import logger
|
|
from pipecat.utils.enums import EndTaskReason
|
|
|
|
from api.services.telephony.ari_client import Channel
|
|
from api.services.telephony.ari_client_manager import (
|
|
ARIClientManager,
|
|
setup_ari_client_supervisor,
|
|
)
|
|
from api.services.telephony.ari_manager_connection import ARIManagerConnection
|
|
|
|
|
|
class ARIManager:
    """Own the ARI connection and distribute channel events to workers via Redis.

    Per caller channel handled by this instance:

    1. ``on_channel_start`` atomically allocates a port in Redis, tracks a new
       ``ARIManagerConnection`` and lets ``_on_stasis_call`` set up the bridge
       and externalMedia channel.
    2. ``_on_stasis_call`` picks a worker round-robin, subscribes to the
       channel's command stream, publishes a ``StasisStartEvent`` to that worker
       and starts a task that executes the worker's commands.
    3. Everything kept for the channel is released once BOTH the channel has
       ended (``on_channel_end``) and the worker side is finished
       (``SocketClosedCommand``, or the call never reached a worker).
    """

    # Atomically return the channel's existing port, or reserve the first free
    # port from ARGV[1] to ARGV[2] (inclusive) stepping by ARGV[3].
    # KEYS[1] = channel ID, ARGV[4] = allocation timestamp. Returns -1 when no
    # port in the range is free.
    _ALLOCATE_PORT_LUA = """
local port_range_start = tonumber(ARGV[1])
local port_range_end = tonumber(ARGV[2])
local port_range_step = tonumber(ARGV[3])
local channel_id = KEYS[1]
local timestamp = ARGV[4]

-- Check if channel already has a port allocated
local existing_port = redis.call('HGET', 'channel_ports', channel_id)
if existing_port then
    return tonumber(existing_port)
end

-- Find first available port
for port = port_range_start, port_range_end, port_range_step do
    local port_str = tostring(port)
    local exists = redis.call('HEXISTS', 'port_channels', port_str)
    if exists == 0 then
        -- Atomically allocate the port
        redis.call('HSET', 'channel_ports', channel_id, port)
        redis.call('HSET', 'port_channels', port_str, channel_id)
        redis.call('HSET', 'channel_allocation_time', channel_id, timestamp)
        return port
    end
end

return -1 -- No ports available
"""

    # Atomically drop every record of the channel's port allocation.
    # KEYS[1] = channel ID. Returns the released port, or nil if none was held.
    _RELEASE_PORT_LUA = """
local channel_id = KEYS[1]

-- Get the port allocated to this channel
local port = redis.call('HGET', 'channel_ports', channel_id)

if port then
    -- Atomically clean up all related entries
    redis.call('HDEL', 'channel_ports', channel_id)
    redis.call('HDEL', 'port_channels', port)
    redis.call('HDEL', 'channel_allocation_time', channel_id)
    return port
end

return nil
"""

    def __init__(self, redis_client: aioredis.Redis):
        self.redis = redis_client
        self.stasis_manager: Optional[ARIClientManager] = None
        self._running = False
        self._ari_client_supervisor = None
        # Per-channel command monitor tasks and their pubsub connections
        self._tasks: Dict[str, asyncio.Task] = {}
        self._pubsubs: Dict[str, aioredis.client.PubSub] = {}
        # Channels whose call setup succeeded in this instance
        self._active_channels: set[str] = set()
        self._port_range = range(4000, 5000, 2)  # Even ports only
        # Connections keyed by caller channel ID
        self._channel_connections: Dict[str, ARIManagerConnection] = {}
        # Cleanup flags: a channel is fully cleaned up once both are True
        self._channel_disposed: Dict[str, bool] = {}
        self._socket_closed: Dict[str, bool] = {}
        self._active_workers: list[str] = []  # Cached list of ready workers
        self._worker_discovery_task: Optional[asyncio.Task] = None
        self._channel_to_worker: Dict[str, str] = {}  # Caller channel -> worker

    async def on_channel_start(self, caller_channel: Channel, call_context_vars: dict):
        """Handle a new caller channel reported by the ARI client supervisor.

        Allocates a port, tracks a new ``ARIManagerConnection`` and hands it to
        ``_on_stasis_call``. If the call cannot be handed to a worker, the
        channel's worker side is marked finished so that its port and state are
        released once the channel ends.

        Args:
            caller_channel: The caller's ARI channel.
            call_context_vars: Context forwarded to the worker in the
                ``StasisStartEvent``.
        """
        channel_id = caller_channel.id
        try:
            # Atomically allocate port for this channel (prevents race conditions)
            port = await self._get_and_allocate_port_atomic(channel_id)

            connection = ARIManagerConnection(
                caller_channel=caller_channel,
                host=os.getenv("ARI_STASIS_APP_ENDPOINT"),
                port=port,
            )

            # Track the connection and initialise both cleanup flags
            self._channel_connections[channel_id] = connection
            self._channel_disposed[channel_id] = False
            self._socket_closed[channel_id] = False

            if not await self._on_stasis_call(connection, call_context_vars):
                await self._handle_failed_handoff(channel_id)

        except Exception as e:
            logger.exception(f"Error handling new channel {channel_id}: {e}")
            # Release port if allocation failed
            await self._release_port_for_channel(channel_id)

    async def on_channel_end(self, channel_id: str):
        """Handle a channel-end notification from the ARI client supervisor.

        For a tracked caller channel, this method:
        - tells the assigned worker, if there is one;
        - notifies the connection;
        - marks the channel disposed;
        - attempts cleanup.

        Notifications for externalMedia channels are ignored.
        """
        logger.info(f"channelID: {channel_id} Received channel end notification")

        connection = self._channel_connections.get(channel_id)
        if connection is None:
            # TODO: We are currently not handling StasisEnd on ExternalMedia
            if any(
                conn.em_channel_id and conn.em_channel_id == channel_id
                for conn in self._channel_connections.values()
            ):
                logger.debug(
                    f"channelID: {channel_id} ExternalMedia StasisEnd - Ignoring"
                )
            return

        worker_id = self._get_worker_for_channel(channel_id)
        if worker_id:
            event = StasisEndEvent(
                channel_id=channel_id,
                reason=EndTaskReason.USER_HANGUP.value,
            )
            try:
                await self.redis.publish(
                    RedisChannels.worker_events(worker_id), event.to_json()
                )
                logger.info(f"channelID: {channel_id} Published StasisEnd event")
            except Exception as e:
                # Local teardown below must still happen if Redis is unavailable
                logger.exception(
                    f"channelID: {channel_id} Failed to publish StasisEnd event: {e}"
                )
        else:
            # The call never reached a worker; publishing would target the
            # bogus channel worker_events("")
            logger.info(
                f"channelID: {channel_id} No worker assigned, skipping StasisEnd event"
            )

        # Notify the connection about channel end
        await connection.notify_channel_end()

        # Mark channel as disposed, then check whether both flags are set
        if channel_id in self._channel_disposed:
            self._channel_disposed[channel_id] = True
        await self._check_and_cleanup_channel(channel_id)

    async def _on_stasis_call(
        self, connection: ARIManagerConnection, call_context_vars: dict
    ) -> bool:
        """Set up the call and hand it to a worker via Redis.

        Creates the bridge and externalMedia channel, picks a worker, subscribes
        to the channel's command stream, publishes a ``StasisStartEvent`` to the
        worker and starts the command monitor task.

        Returns:
            True if the call was handed to a worker. False if the connection did
            not come up, the caller hung up during setup, no worker was
            available, or an error occurred (errors are logged, not raised).
        """
        try:
            # Setup the connection (create bridge and external media)
            await connection.setup_call()

            if not connection.is_connected():
                logger.warning("Connection is not connected, skipping")
                return False

            channel_id = connection.caller_channel_id

            if self._channel_disposed.get(channel_id):
                # The caller hung up while the call was being set up; a worker
                # would get a StasisStart but never a matching StasisEnd.
                logger.warning(
                    f"channelID: {channel_id} Channel ended during setup, not handing off"
                )
                await connection.disconnect()
                return False

            # Track this channel as active
            self._active_channels.add(channel_id)

            event = StasisStartEvent(
                channel_id=channel_id,
                caller_channel_id=channel_id,
                em_channel_id=connection.em_channel_id,
                bridge_id=connection.bridge_id,
                local_addr=list(connection.local_addr),
                remote_addr=(
                    list(connection.remote_addr) if connection.remote_addr else None
                ),
                call_context_vars=call_context_vars,
            )

            # Select worker using round-robin
            worker_id = await self._select_worker()
            if worker_id is None:
                logger.error(f"channelID: {channel_id} No active workers available")
                await connection.disconnect()
                return False

            # Subscribe BEFORE announcing the call. Redis pub/sub does not buffer
            # for late subscribers, so a command sent by the worker right after
            # it receives the event would otherwise be lost.
            pubsub = self.redis.pubsub()
            self._pubsubs[channel_id] = pubsub
            await pubsub.subscribe(RedisChannels.channel_commands(channel_id))

            # Track channel to worker mapping and publish to that worker
            self._channel_to_worker[channel_id] = worker_id
            await self.redis.publish(
                RedisChannels.worker_events(worker_id), event.to_json()
            )
            logger.info(
                f"channelID: {channel_id} Published stasis_start event to worker {worker_id}"
            )

            # Start monitoring for commands from workers
            self._tasks[channel_id] = asyncio.create_task(
                self._monitor_channel_commands(channel_id, connection, pubsub)
            )
            return True

        except Exception as e:
            logger.exception(f"Error handling stasis call: {e}")
            return False

    async def _handle_failed_handoff(self, channel_id: str):
        """Finish the worker side of a channel that never reached a worker.

        No worker will ever send ``SocketClosedCommand`` for such a channel.
        Without this, its cleanup (including its port) would never run. The
        port is released by ``_check_and_cleanup_channel`` once the channel has
        also ended, not here.
        """
        self._channel_to_worker.pop(channel_id, None)
        await self._close_pubsub(channel_id)
        if channel_id in self._socket_closed:
            self._socket_closed[channel_id] = True
        await self._check_and_cleanup_channel(channel_id)

    async def _get_and_allocate_port_atomic(self, channel_id: str) -> int:
        """Atomically find and allocate an available port using a Redis Lua script.

        The script executes atomically in Redis, so two concurrent calls cannot
        allocate the same port. A channel that already holds a port gets the
        same port back. If the range is exhausted, orphaned allocations are
        cleaned up and the allocation is retried once.

        Raises:
            RuntimeError: If no port is available even after cleanup.
        """
        script_args = (
            min(self._port_range),  # ARGV[1]
            max(self._port_range),  # ARGV[2]
            self._port_range.step,  # ARGV[3]
            int(time.time()),  # ARGV[4]
        )

        port = await self.redis.eval(self._ALLOCATE_PORT_LUA, 1, channel_id, *script_args)

        if port == -1:
            # If all ports exhausted, clean up orphaned ports and retry
            await self._cleanup_orphaned_ports()
            port = await self.redis.eval(
                self._ALLOCATE_PORT_LUA, 1, channel_id, *script_args
            )
            if port == -1:
                raise RuntimeError(
                    "No available ports in configured range after cleanup"
                )

        logger.debug(f"Atomically allocated port {port} for channel {channel_id}")
        return port

    async def _release_port_for_channel(self, channel_id: str):
        """Atomically release the port held by *channel_id*, if any.

        All related Redis entries are removed in a single Lua script, so a
        release is never partial. Releasing a channel without a port is a no-op.
        """
        port = await self.redis.eval(self._RELEASE_PORT_LUA, 1, channel_id)

        if port:
            logger.debug(f"Atomically released port {port} for channel {channel_id}")
        else:
            logger.debug(f"No port was allocated for channel {channel_id}")

    async def _discover_workers(self):
        """Refresh ``_active_workers`` from Redis every 5 seconds while running.

        A worker is included when it is in the workers set and its
        ``worker_active`` key holds JSON with ``"status": "ready"``; draining
        workers are skipped. A failed round is logged and retried.
        """
        try:
            while self._running:
                try:
                    worker_ids = await self.redis.smembers(RedisKeys.workers_set())

                    active_workers = []
                    for worker_id in worker_ids:
                        if isinstance(worker_id, bytes):
                            worker_id = worker_id.decode()

                        worker_data = await self.redis.get(
                            RedisKeys.worker_active(worker_id)
                        )
                        if not worker_data:
                            continue

                        try:
                            data = json.loads(worker_data)
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid worker data for {worker_id}")
                            continue

                        # Only include workers that are ready (not draining)
                        if data.get("status") == "ready":
                            active_workers.append(worker_id)

                    # Replace the cached list in one assignment
                    self._active_workers = active_workers
                    logger.info(f"Discovered {len(active_workers)} active workers")

                except Exception as e:
                    logger.error(f"Error discovering workers: {e}")

                await asyncio.sleep(5)

        except asyncio.CancelledError:
            logger.debug("Worker discovery task cancelled")

    async def _select_worker(self) -> Optional[str]:
        """Select a worker using round-robin, or None if no worker is active."""
        if not self._active_workers:
            return None

        # Use Redis to maintain round-robin index across restarts
        try:
            index = await self.redis.incr(RedisKeys.round_robin_index())
            return self._active_workers[(index - 1) % len(self._active_workers)]
        except Exception as e:
            logger.error(f"Error selecting worker: {e}")
            # Fallback to first worker if Redis operation fails
            return self._active_workers[0] if self._active_workers else None

    def _get_worker_for_channel(self, channel_id: str) -> str:
        """Return the worker assigned to *channel_id*, or "" if none is assigned."""
        return self._channel_to_worker.get(channel_id, "")

    async def _monitor_channel_commands(
        self,
        channel_id: str,
        connection: ARIManagerConnection,
        pubsub: "Optional[aioredis.client.PubSub]" = None,
    ):
        """Listen for commands from workers for *channel_id* and execute them.

        Args:
            channel_id: Caller channel whose command stream is monitored.
            connection: Connection the commands act on.
            pubsub: An already-subscribed pubsub to listen on. When omitted, a
                new one is created and subscribed here.
        """
        # TODO: Not sure if its a good idea to monitor command for every channel
        # using pubsub. What happens if there are more number of calls than number
        # of tcp connections redis can support? We can do something similar to
        # Campaign Orchestrator, where we can subscribe to one channel and have
        # commands for every channel there.
        try:
            if pubsub is None:
                pubsub = self.redis.pubsub()
                # Store the pubsub connection for cleanup
                self._pubsubs[channel_id] = pubsub
                await pubsub.subscribe(RedisChannels.channel_commands(channel_id))

            logger.debug(f"channelID: {channel_id} Monitoring commands for channel")

            async for message in pubsub.listen():
                if message["type"] != "message":
                    continue
                try:
                    command = parse_command(message["data"])
                    if command:
                        await self._handle_worker_command(
                            channel_id, command, connection
                        )
                    else:
                        logger.warning(
                            f"Failed to parse command for {channel_id}: {message['data']}"
                        )
                except Exception as e:
                    logger.exception(f"Error handling command for {channel_id}: {e}")

        except asyncio.CancelledError:
            logger.debug(f"channelID: {channel_id} Command monitor cancelled")
            raise  # Re-raise to maintain proper cancellation semantics
        except (ConnectionError, redis.exceptions.ConnectionError):
            # Cleanup closes the pubsub before stopping this task, so a read
            # interrupted by that close ends up here
            pass
        except Exception as e:
            logger.exception(f"Error in command monitor for {channel_id}: {e}")

    async def _handle_worker_command(
        self,
        channel_id: str,
        command: BaseWorkerToARIManagerCommand,
        connection: ARIManagerConnection,
    ):
        """Execute one command sent by the worker handling *channel_id*."""
        if isinstance(command, DisconnectCommand):
            logger.info(f"channelID: {channel_id} Worker requested disconnect")
            await connection.disconnect()

        elif isinstance(command, TransferCommand):
            logger.info(f"channelID: {channel_id} Worker requested transfer")
            await connection.transfer(command.context)

        elif isinstance(command, SocketClosedCommand):
            logger.info(f"channelID: {channel_id} Worker notified socket closed")

            # Mark socket as closed
            if channel_id in self._socket_closed:
                self._socket_closed[channel_id] = True

            # Release port immediately
            await self._release_port_for_channel(channel_id)

            # Check if both flags are set to cleanup
            await self._check_and_cleanup_channel(channel_id)

        else:
            logger.warning(
                f"channelID: {channel_id} Received unknown command: {command}"
            )

    async def _check_and_cleanup_channel(self, channel_id: str):
        """Release every resource of *channel_id* once it is disposed AND socket-closed.

        Does nothing until both flags are set. All state is removed with
        ``pop``, so a duplicate or concurrent call finds nothing left to clean.
        """
        channel_disposed = self._channel_disposed.get(channel_id, False)
        socket_closed = self._socket_closed.get(channel_id, False)

        logger.debug(
            f"channelID: {channel_id} Check cleanup - disposed: {channel_disposed}, socket_closed: {socket_closed}"
        )

        if not (channel_disposed and socket_closed):
            return

        # Drop the flags first so a duplicate trigger cannot start a second cleanup
        self._channel_disposed.pop(channel_id, None)
        self._socket_closed.pop(channel_id, None)

        # Remove from active channels, connections and worker mapping
        self._active_channels.discard(channel_id)
        self._channel_connections.pop(channel_id, None)
        self._channel_to_worker.pop(channel_id, None)

        # Close pubsub connection first (before stopping the monitor task)
        await self._close_pubsub(channel_id)
        await self._stop_monitor_task(channel_id)

        # Normally already released on SocketClosedCommand. This also covers
        # calls that never reached a worker; releasing twice is a no-op.
        try:
            await self._release_port_for_channel(channel_id)
        except Exception as e:
            # Left for _cleanup_orphaned_ports (channel is no longer active)
            logger.warning(
                f"channelID: {channel_id} Failed to release port during cleanup: {e}"
            )

        logger.info(f"channelID: {channel_id} Completed cleanup of all resources")

    async def _close_pubsub(self, channel_id: str):
        """Unsubscribe and close the command pubsub of *channel_id*, if any."""
        pubsub = self._pubsubs.pop(channel_id, None)
        if pubsub is None:
            return

        try:
            await pubsub.unsubscribe(RedisChannels.channel_commands(channel_id))
        except Exception as e:
            logger.warning(f"Error closing pubsub for {channel_id}: {e}")

        # Close even if unsubscribing failed, so the connection is not leaked
        try:
            await pubsub.aclose()
            logger.debug(f"channelID: {channel_id} Closed pubsub connection in cleanup")
        except Exception as e:
            logger.warning(f"Error closing pubsub for {channel_id}: {e}")

    async def _stop_monitor_task(self, channel_id: str):
        """Cancel and await the command monitor task of *channel_id*, if any."""
        task = self._tasks.pop(channel_id, None)
        if task is None:
            return

        if task is asyncio.current_task():
            # Cleanup was triggered by a command handled inside this very task.
            # A task cannot cancel-and-await itself. With its pubsub already
            # closed, its listen loop ends once the current command returns.
            logger.debug(
                f"channelID: {channel_id} Cleanup running inside monitor task"
            )
            return

        if task.done():
            logger.debug(f"channelID: {channel_id} Monitor task already completed")
            # Retrieve the outcome without awaiting
            if not task.cancelled() and task.exception() is not None:
                logger.warning(
                    f"channelID: {channel_id} Completed task had exception: {task.exception()}"
                )
            return

        task.cancel()
        try:
            await task
            logger.debug(f"channelID: {channel_id} Task completed after cancel")
        except asyncio.CancelledError:
            logger.debug(f"channelID: {channel_id} Task cancelled successfully")
        except Exception as e:
            logger.warning(f"channelID: {channel_id} Task raised exception: {e}")

    async def _cleanup_orphaned_ports(self):
        """Release port allocations left behind by ungraceful shutdowns or lost channels.

        If this instance has no active channels (e.g. at startup), every
        allocation is treated as orphaned. Otherwise the method only releases
        allocations that meet all of these conditions:
        - the channel is not in ``_active_channels``;
        - the allocation has a recorded time;
        - the allocation is older than 5 minutes.

        Errors are logged, never raised.
        """
        try:
            # Get all channel-port mappings
            channel_ports = await self.redis.hgetall("channel_ports")
            if not channel_ports:
                return

            logger.info(
                f"Found {len(channel_ports)} existing port allocations, checking for orphans..."
            )

            cleaned = 0
            current_time = int(time.time())
            max_age_seconds = 3600  # 1 hour
            untracked_min_age_seconds = 300  # 5 minutes

            if not self._active_channels:
                # No channels in use by this instance: every allocation is orphaned
                for channel_id, port in channel_ports.items():
                    allocation_time = await self.redis.hget(
                        "channel_allocation_time", channel_id
                    )
                    age_str = ""
                    if allocation_time:
                        age_str = f" (aged {current_time - int(allocation_time)}s)"

                    await self._release_port_for_channel(channel_id)
                    logger.info(
                        f"Cleaned up orphaned port {port} for channel {channel_id}{age_str}"
                    )
                    cleaned += 1
            else:
                # During runtime, only clean up channels not being tracked
                for channel_id, port in channel_ports.items():
                    if channel_id in self._active_channels:
                        continue

                    allocation_time = await self.redis.hget(
                        "channel_allocation_time", channel_id
                    )
                    if not allocation_time:
                        continue

                    age = current_time - int(allocation_time)
                    if age > max_age_seconds:
                        # Too old, clean up regardless
                        await self._release_port_for_channel(channel_id)
                        logger.info(
                            f"Cleaned up stale port {port} for channel {channel_id} (aged {age}s)"
                        )
                        cleaned += 1
                    elif age > untracked_min_age_seconds:
                        # Not tracked by this instance and reasonably old
                        await self._release_port_for_channel(channel_id)
                        logger.info(
                            f"Cleaned up orphaned port {port} for untracked channel {channel_id}"
                        )
                        cleaned += 1

            if cleaned > 0:
                logger.info(f"Cleaned up {cleaned} orphaned port allocations")

        except Exception as e:
            logger.exception(f"Error during orphaned port cleanup: {e}")

    async def _periodic_cleanup(self):
        """Run ``_cleanup_orphaned_ports`` every 30 minutes while running."""
        cleanup_interval = 1800  # 30 minutes

        while self._running:
            try:
                await asyncio.sleep(cleanup_interval)
                if self._running:  # Check again after sleep
                    logger.info("Running periodic orphaned port cleanup...")
                    await self._cleanup_orphaned_ports()
            except asyncio.CancelledError:
                logger.debug("Periodic cleanup task cancelled")
                break
            except Exception as e:
                logger.exception(f"Error in periodic cleanup: {e}")

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]):
        """Cancel *task* (if any) and wait for it to finish."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def run(self):
        """Main run loop for ARI Manager.

        Runs the following steps in order:
        1. Clears stale port allocations.
        2. Connects to ARI.
        3. Starts worker discovery and periodic port cleanup.
        4. Idles until ``shutdown()`` clears ``_running``.

        The background tasks started here are stopped on every exit path.
        """
        self._running = True
        cleanup_task: Optional[asyncio.Task] = None

        try:
            # Clean up orphaned ports from previous runs BEFORE accepting calls:
            # the "no active channels" sweep would otherwise release a port that
            # a call still in setup has just been given.
            await self._cleanup_orphaned_ports()

            # Setup ARI connection with supervisor
            self._ari_client_supervisor = await setup_ari_client_supervisor(
                self.on_channel_start, self.on_channel_end
            )
            if not self._ari_client_supervisor:
                logger.error("Failed to setup ARI connection")
                return

            # Start worker discovery task
            self._worker_discovery_task = asyncio.create_task(self._discover_workers())

            # Wait a moment for initial worker discovery
            await asyncio.sleep(1)

            logger.info(
                f"ARI Manager started with {len(self._active_workers)} active workers"
            )

            # Start periodic cleanup task
            cleanup_task = asyncio.create_task(self._periodic_cleanup())

            # Keep running until shutdown
            while self._running:
                await asyncio.sleep(1)

            logger.debug("ARIManager._running is false. Will cleanup and shutdown")

        except Exception as e:
            logger.exception(f"ARI Manager error: {e}")
        finally:
            # Stop background loops even when exiting via error or cancellation
            self._running = False
            await self._cancel_task(cleanup_task)
            discovery_task, self._worker_discovery_task = (
                self._worker_discovery_task,
                None,
            )
            await self._cancel_task(discovery_task)

            if self._ari_client_supervisor:
                await self._ari_client_supervisor.close()
            logger.info("ARI Manager stopped")

    async def shutdown(self):
        """Graceful shutdown.

        Performs the following steps:
        - stops ARI and worker discovery;
        - releases the ports of active channels;
        - stops all command monitors and closes their pubsub connections;
        - clears all per-channel bookkeeping.
        """
        logger.info("Shutting down ARI Manager...")

        # Close supervisor first to prevent reconnection attempts
        if self._ari_client_supervisor:
            await self._ari_client_supervisor.close()

        # Cancel worker discovery task
        discovery_task, self._worker_discovery_task = self._worker_discovery_task, None
        await self._cancel_task(discovery_task)

        # Now set running to False
        self._running = False

        # Clean up all active channel ports before shutting down
        if self._active_channels:
            logger.info(f"Cleaning up {len(self._active_channels)} active channels...")
            # Copy to avoid modification during iteration
            for channel_id in list(self._active_channels):
                await self._release_port_for_channel(channel_id)
                logger.info(
                    f"Released port for active channel {channel_id} during shutdown"
                )
            self._active_channels.clear()

        # Clear flag tracking
        self._channel_disposed.clear()
        self._socket_closed.clear()

        # Cancel all monitoring tasks and wait for them to complete
        tasks = list(self._tasks.values())
        self._tasks.clear()
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        # Close the (now unused) per-channel pubsub connections
        for channel_id in list(self._pubsubs):
            await self._close_pubsub(channel_id)

        self._channel_connections.clear()
        self._channel_to_worker.clear()
|
|
|
|
|
|
async def main():
    """Main entry point for ARI Manager service.

    Starts ``ARIManager.run()`` and waits for it to finish or for SIGTERM/SIGINT.
    On a signal, a graceful ``ARIManager.shutdown()`` is performed and the run
    task is then cancelled if still running. The Redis client and the logging
    queue listener are closed on every exit path.
    """
    # Setup Redis connection
    redis = await aioredis.from_url(REDIS_URL, decode_responses=True)

    # Create and run manager
    manager = ARIManager(redis)

    # Create a shutdown event for clean coordination
    shutdown_event = asyncio.Event()

    # Inside a coroutine, get_running_loop() is the recommended way to reach the loop
    loop = asyncio.get_running_loop()

    def signal_handler(signum):
        logger.info(f"Received shutdown signal {signum}")
        # Set the shutdown event which will trigger shutdown
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, lambda s=sig: signal_handler(s))

    # Run manager with shutdown monitoring
    manager_task = asyncio.create_task(manager.run())
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    try:
        # Wait for either normal completion or shutdown signal
        done, pending = await asyncio.wait(
            [manager_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED
        )

        if shutdown_task in done:
            # Shutdown was triggered: perform graceful shutdown
            await manager.shutdown()
            # Cancel the manager task if still running
            if manager_task in pending:
                manager_task.cancel()
                try:
                    await manager_task
                except asyncio.CancelledError:
                    pass
        else:
            # run() returned on its own (e.g. ARI setup failed); stop waiting
            # for a signal instead of leaving the waiter task pending
            shutdown_task.cancel()
            try:
                await shutdown_task
            except asyncio.CancelledError:
                pass
    finally:
        await redis.aclose()
        # --- Ensure Axiom logging listener is stopped gracefully ---
        if logging_queue_listener is not None:
            logging_queue_listener.stop()
|
|
|
|
|
|
if __name__ == "__main__":
    # Configure logging: add a loguru file sink rotated at 10 MB before the
    # service starts, so output from main() is captured in it too.
    logger.add("logs/ari_manager.log", rotation="10 MB")
    # Run the ARI Manager service until it stops or receives SIGTERM/SIGINT.
    asyncio.run(main())
|