mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: ari outbound dialing
This commit is contained in:
parent
1349654c75
commit
e0f43ccf27
11 changed files with 1165 additions and 18 deletions
42
api/alembic/versions/6d2f94baf4b7_add_ari_mode.py
Normal file
42
api/alembic/versions/6d2f94baf4b7_add_ari_mode.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""add ari mode
|
||||
|
||||
Revision ID: 6d2f94baf4b7
|
||||
Revises: 1a7d74d54e8f
|
||||
Create Date: 2026-02-15 13:52:29.285583
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '6d2f94baf4b7'
|
||||
down_revision: Union[str, None] = '1a7d74d54e8f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema='public',
|
||||
enum_name='workflow_run_mode',
|
||||
new_values=['ari', 'twilio', 'vonage', 'vobiz', 'cloudonix', 'webrtc', 'smallwebrtc', 'stasis', 'VOICE', 'CHAT'],
|
||||
affected_columns=[TableReference(table_schema='public', table_name='workflow_runs', column_name='mode')],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema='public',
|
||||
enum_name='workflow_run_mode',
|
||||
new_values=['twilio', 'vonage', 'vobiz', 'cloudonix', 'stasis', 'webrtc', 'smallwebrtc', 'VOICE', 'CHAT'],
|
||||
affected_columns=[TableReference(table_schema='public', table_name='workflow_runs', column_name='mode')],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
|
@ -94,3 +94,27 @@ class OrganizationConfigurationClient(BaseDBClient):
|
|||
"""Get the value of a configuration, returning default if not found."""
|
||||
config = await self.get_configuration(organization_id, key)
|
||||
return config.value if config else default
|
||||
|
||||
async def get_configurations_by_provider(
|
||||
self, key: str, provider: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all organization configurations for a given key filtered by provider.
|
||||
|
||||
Returns a list of dicts with organization_id and the config value.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.key == key,
|
||||
)
|
||||
)
|
||||
configs = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"organization_id": config.organization_id,
|
||||
"value": config.value,
|
||||
}
|
||||
for config in configs
|
||||
if config.value and config.value.get("provider") == provider
|
||||
]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class CallType(Enum):
|
|||
|
||||
|
||||
class WorkflowRunMode(Enum):
|
||||
ARI = "ari"
|
||||
TWILIO = "twilio"
|
||||
VONAGE = "vonage"
|
||||
VOBIZ = "vobiz"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from api.db import db_client
|
|||
from api.db.models import UserModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.telephony_config import (
|
||||
ARIConfigurationRequest,
|
||||
ARIConfigurationResponse,
|
||||
CloudonixConfigurationRequest,
|
||||
CloudonixConfigurationResponse,
|
||||
TelephonyConfigurationResponse,
|
||||
|
|
@ -29,6 +31,7 @@ PROVIDER_MASKED_FIELDS = {
|
|||
"vonage": ["private_key", "api_key", "api_secret"],
|
||||
"vobiz": ["auth_id", "auth_token"],
|
||||
"cloudonix": ["bearer_token"],
|
||||
"ari": ["app_password"],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -125,6 +128,21 @@ async def get_telephony_configuration(user: UserModel = Depends(get_user)):
|
|||
),
|
||||
vobiz=None,
|
||||
)
|
||||
elif stored_provider == "ari":
|
||||
ari_endpoint = config.value.get("ari_endpoint", "")
|
||||
app_name = config.value.get("app_name", "")
|
||||
app_password = config.value.get("app_password", "")
|
||||
from_numbers = config.value.get("from_numbers", [])
|
||||
|
||||
return TelephonyConfigurationResponse(
|
||||
ari=ARIConfigurationResponse(
|
||||
provider="ari",
|
||||
ari_endpoint=ari_endpoint,
|
||||
app_name=app_name,
|
||||
app_password=mask_key(app_password) if app_password else "",
|
||||
from_numbers=from_numbers,
|
||||
),
|
||||
)
|
||||
else:
|
||||
return TelephonyConfigurationResponse()
|
||||
|
||||
|
|
@ -136,6 +154,7 @@ async def save_telephony_configuration(
|
|||
VonageConfigurationRequest,
|
||||
VobizConfigurationRequest,
|
||||
CloudonixConfigurationRequest,
|
||||
ARIConfigurationRequest,
|
||||
],
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
|
|
@ -180,6 +199,14 @@ async def save_telephony_configuration(
|
|||
"domain_id": request.domain_id,
|
||||
"from_numbers": request.from_numbers,
|
||||
}
|
||||
elif request.provider == "ari":
|
||||
config_value = {
|
||||
"provider": "ari",
|
||||
"ari_endpoint": request.ari_endpoint,
|
||||
"app_name": request.app_name,
|
||||
"app_password": request.app_password,
|
||||
"from_numbers": request.from_numbers,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Unsupported provider: {request.provider}"
|
||||
|
|
|
|||
|
|
@ -89,6 +89,33 @@ class CloudonixConfigurationResponse(BaseModel):
|
|||
from_numbers: List[str]
|
||||
|
||||
|
||||
class ARIConfigurationRequest(BaseModel):
|
||||
"""Request schema for Asterisk ARI configuration."""
|
||||
|
||||
provider: str = Field(default="ari")
|
||||
ari_endpoint: str = Field(
|
||||
..., description="ARI base URL (e.g., http://asterisk.example.com:8088)"
|
||||
)
|
||||
app_name: str = Field(
|
||||
..., description="Stasis application name registered in Asterisk"
|
||||
)
|
||||
app_password: str = Field(..., description="ARI user password")
|
||||
from_numbers: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of SIP extensions/numbers for outbound calls (optional)",
|
||||
)
|
||||
|
||||
|
||||
class ARIConfigurationResponse(BaseModel):
|
||||
"""Response schema for ARI configuration with masked sensitive fields."""
|
||||
|
||||
provider: str
|
||||
ari_endpoint: str
|
||||
app_name: str
|
||||
app_password: str # Masked
|
||||
from_numbers: List[str]
|
||||
|
||||
|
||||
class TelephonyConfigurationResponse(BaseModel):
|
||||
"""Top-level telephony configuration response."""
|
||||
|
||||
|
|
@ -96,3 +123,4 @@ class TelephonyConfigurationResponse(BaseModel):
|
|||
vonage: Optional[VonageConfigurationResponse] = None
|
||||
vobiz: Optional[VobizConfigurationResponse] = None
|
||||
cloudonix: Optional[CloudonixConfigurationResponse] = None
|
||||
ari: Optional[ARIConfigurationResponse] = None
|
||||
|
|
|
|||
379
api/services/telephony/ari_manager.py
Normal file
379
api/services/telephony/ari_manager.py
Normal file
|
|
@ -0,0 +1,379 @@
|
|||
"""ARI WebSocket Event Listener Manager.
|
||||
|
||||
Standalone process that:
|
||||
1. Queries the database for all organizations with ARI telephony configuration
|
||||
2. Creates WebSocket connections to each ARI instance
|
||||
3. Handles reconnection logic with exponential backoff
|
||||
4. Processes StasisStart/StasisEnd events
|
||||
5. Periodically refreshes configuration to detect new/removed organizations
|
||||
"""
|
||||
|
||||
from api.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
from typing import Any, Dict, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
|
||||
|
||||
class ARIConnection:
|
||||
"""Manages a single ARI WebSocket connection for an organization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
organization_id: int,
|
||||
ari_endpoint: str,
|
||||
app_name: str,
|
||||
app_password: str,
|
||||
):
|
||||
self.organization_id = organization_id
|
||||
self.ari_endpoint = ari_endpoint.rstrip("/")
|
||||
self.app_name = app_name
|
||||
self.app_password = app_password
|
||||
|
||||
self._ws: Optional[websockets.ClientConnection] = None
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
self._reconnect_delay = 1 # Start with 1 second
|
||||
self._max_reconnect_delay = 300 # Max 300 seconds
|
||||
self._ping_interval = 30 # Send ping every 30 seconds
|
||||
|
||||
@property
|
||||
def ws_url(self) -> str:
|
||||
"""Build the ARI WebSocket URL."""
|
||||
parsed = urlparse(self.ari_endpoint)
|
||||
ws_scheme = "wss" if parsed.scheme == "https" else "ws"
|
||||
return (
|
||||
f"{ws_scheme}://{parsed.netloc}/ari/events"
|
||||
f"?api_key={self.app_name}:{self.app_password}"
|
||||
f"&app={self.app_name}"
|
||||
f"&subscribeAll=true"
|
||||
)
|
||||
|
||||
@property
|
||||
def connection_key(self) -> str:
|
||||
"""Unique key for this connection based on config."""
|
||||
return f"{self.organization_id}:{self.ari_endpoint}:{self.app_name}"
|
||||
|
||||
async def start(self):
|
||||
"""Start the WebSocket connection in a background task."""
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._connection_loop())
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] Started connection to {self.ari_endpoint}"
|
||||
)
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the WebSocket connection."""
|
||||
self._running = False
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] Stopped connection to {self.ari_endpoint}"
|
||||
)
|
||||
|
||||
async def _connection_loop(self):
|
||||
"""Main connection loop with reconnection logic."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_and_listen()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Connection error: {e}. "
|
||||
f"Reconnecting in {self._reconnect_delay}s..."
|
||||
)
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
# Exponential backoff
|
||||
self._reconnect_delay = min(
|
||||
self._reconnect_delay * 2, self._max_reconnect_delay
|
||||
)
|
||||
|
||||
async def _connect_and_listen(self):
|
||||
"""Establish WebSocket connection and listen for events."""
|
||||
ws_url = self.ws_url
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] Connecting to {self.ari_endpoint}..."
|
||||
)
|
||||
|
||||
async for ws in websockets.connect(
|
||||
ws_url,
|
||||
ping_interval=self._ping_interval,
|
||||
ping_timeout=10,
|
||||
close_timeout=5,
|
||||
):
|
||||
try:
|
||||
self._ws = ws
|
||||
|
||||
# Reset reconnect delay on successful connection
|
||||
self._reconnect_delay = 1
|
||||
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] WebSocket connected to {self.ari_endpoint}"
|
||||
)
|
||||
|
||||
async for message in ws:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
if isinstance(message, str):
|
||||
await self._handle_event(message)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[ARI org={self.organization_id}] Received binary message, ignoring"
|
||||
)
|
||||
|
||||
except websockets.ConnectionClosed as e:
|
||||
if not self._running:
|
||||
return
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] WebSocket closed: "
|
||||
f"code={e.code}, reason={e.reason}. Reconnecting..."
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
self._ws = None
|
||||
|
||||
async def _handle_event(self, raw_data: str):
|
||||
"""Handle an ARI WebSocket event."""
|
||||
try:
|
||||
event = json.loads(raw_data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Invalid JSON: {raw_data[:200]}"
|
||||
)
|
||||
return
|
||||
|
||||
event_type = event.get("type", "unknown")
|
||||
channel = event.get("channel", {})
|
||||
channel_id = channel.get("id", "unknown")
|
||||
channel_state = channel.get("state", "unknown")
|
||||
|
||||
if event_type == "StasisStart":
|
||||
app_args = event.get("args", [])
|
||||
caller = channel.get("caller", {})
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] StasisStart: "
|
||||
f"channel={channel_id}, state={channel_state}, "
|
||||
f"caller={caller.get('number', 'unknown')}, "
|
||||
f"args={app_args}"
|
||||
)
|
||||
# TODO: This is where we'll integrate with the pipeline
|
||||
# For now, just log the event
|
||||
|
||||
elif event_type == "StasisEnd":
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] StasisEnd: "
|
||||
f"channel={channel_id}"
|
||||
)
|
||||
|
||||
elif event_type == "ChannelStateChange":
|
||||
logger.debug(
|
||||
f"[ARI org={self.organization_id}] ChannelStateChange: "
|
||||
f"channel={channel_id}, state={channel_state}"
|
||||
)
|
||||
|
||||
elif event_type == "ChannelDestroyed":
|
||||
cause = channel.get("cause", 0)
|
||||
cause_txt = channel.get("cause_txt", "unknown")
|
||||
logger.info(
|
||||
f"[ARI org={self.organization_id}] ChannelDestroyed: "
|
||||
f"channel={channel_id}, cause={cause} ({cause_txt})"
|
||||
)
|
||||
|
||||
elif event_type == "ChannelDtmfReceived":
|
||||
digit = event.get("digit", "")
|
||||
logger.debug(
|
||||
f"[ARI org={self.organization_id}] DTMF: "
|
||||
f"channel={channel_id}, digit={digit}"
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
f"[ARI org={self.organization_id}] Event: {event_type} "
|
||||
f"channel={channel_id}"
|
||||
)
|
||||
|
||||
|
||||
class ARIManager:
|
||||
"""Manages ARI WebSocket connections for all organizations."""
|
||||
|
||||
def __init__(self):
|
||||
self._connections: Dict[str, ARIConnection] = {} # key -> connection
|
||||
self._running = False
|
||||
self._config_refresh_interval = 60 # Check for config changes every 60 seconds
|
||||
|
||||
async def start(self):
|
||||
"""Start the ARI manager."""
|
||||
self._running = True
|
||||
logger.info("ARI Manager starting...")
|
||||
|
||||
# Initial load of configurations
|
||||
await self._refresh_connections()
|
||||
|
||||
# Start periodic config refresh
|
||||
while self._running:
|
||||
await asyncio.sleep(self._config_refresh_interval)
|
||||
if self._running:
|
||||
await self._refresh_connections()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop all connections and clean up."""
|
||||
self._running = False
|
||||
logger.info("ARI Manager stopping...")
|
||||
|
||||
# Stop all connections
|
||||
for conn in self._connections.values():
|
||||
await conn.stop()
|
||||
self._connections.clear()
|
||||
logger.info("ARI Manager stopped")
|
||||
|
||||
async def _refresh_connections(self):
|
||||
"""
|
||||
Refresh connections based on current database configurations.
|
||||
|
||||
- Starts new connections for new ARI configurations
|
||||
- Stops connections for removed configurations
|
||||
- Restarts connections if configuration changed
|
||||
"""
|
||||
try:
|
||||
active_configs = await self._load_ari_configs()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ARI configurations: {e}")
|
||||
return
|
||||
|
||||
active_keys: Set[str] = set()
|
||||
|
||||
for config in active_configs:
|
||||
org_id = config["organization_id"]
|
||||
ari_endpoint = config["ari_endpoint"]
|
||||
app_name = config["app_name"]
|
||||
app_password = config["app_password"]
|
||||
|
||||
conn = ARIConnection(org_id, ari_endpoint, app_name, app_password)
|
||||
key = conn.connection_key
|
||||
|
||||
active_keys.add(key)
|
||||
|
||||
if key not in self._connections:
|
||||
# New configuration - start connection
|
||||
logger.info(
|
||||
f"[ARI Manager] New ARI config for org {org_id}: {ari_endpoint}"
|
||||
)
|
||||
self._connections[key] = conn
|
||||
await conn.start()
|
||||
else:
|
||||
# Existing configuration - check if password changed
|
||||
existing = self._connections[key]
|
||||
if existing.app_password != app_password:
|
||||
logger.info(
|
||||
f"[ARI Manager] Config changed for org {org_id}, reconnecting..."
|
||||
)
|
||||
await existing.stop()
|
||||
self._connections[key] = conn
|
||||
await conn.start()
|
||||
|
||||
# Stop connections for removed configurations
|
||||
removed_keys = set(self._connections.keys()) - active_keys
|
||||
for key in removed_keys:
|
||||
conn = self._connections.pop(key)
|
||||
logger.info(
|
||||
f"[ARI Manager] Removing connection for org {conn.organization_id}"
|
||||
)
|
||||
await conn.stop()
|
||||
|
||||
if active_configs:
|
||||
logger.info(
|
||||
f"[ARI Manager] Active connections: {len(self._connections)} "
|
||||
f"(orgs: {[c['organization_id'] for c in active_configs]})"
|
||||
)
|
||||
else:
|
||||
logger.debug("[ARI Manager] No ARI configurations found")
|
||||
|
||||
async def _load_ari_configs(self) -> list:
|
||||
"""Load all ARI telephony configurations from the database."""
|
||||
rows = await db_client.get_configurations_by_provider(
|
||||
OrganizationConfigurationKey.TELEPHONY_CONFIGURATION.value, "ari"
|
||||
)
|
||||
|
||||
configs = []
|
||||
for row in rows:
|
||||
org_id = row["organization_id"]
|
||||
value = row["value"]
|
||||
|
||||
ari_endpoint = value.get("ari_endpoint")
|
||||
app_name = value.get("app_name")
|
||||
app_password = value.get("app_password")
|
||||
|
||||
if not all([ari_endpoint, app_name, app_password]):
|
||||
logger.warning(
|
||||
f"[ARI Manager] Incomplete ARI config for org {org_id}, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
{
|
||||
"organization_id": org_id,
|
||||
"ari_endpoint": ari_endpoint,
|
||||
"app_name": app_name,
|
||||
"app_password": app_password,
|
||||
}
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
async def main():
|
||||
"""Entry point for the ARI manager process."""
|
||||
manager = ARIManager()
|
||||
|
||||
# Handle graceful shutdown
|
||||
loop = asyncio.get_running_loop()
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler():
|
||||
logger.info("Received shutdown signal")
|
||||
shutdown_event.set()
|
||||
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
# Start manager in background
|
||||
manager_task = asyncio.create_task(manager.start())
|
||||
|
||||
# Wait for shutdown signal
|
||||
await shutdown_event.wait()
|
||||
|
||||
# Clean up
|
||||
await manager.stop()
|
||||
manager_task.cancel()
|
||||
try:
|
||||
await manager_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("ARI Manager exited cleanly")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -11,6 +11,7 @@ from loguru import logger
|
|||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.telephony.base import TelephonyProvider
|
||||
from api.services.telephony.providers.ari_provider import ARIProvider
|
||||
from api.services.telephony.providers.cloudonix_provider import CloudonixProvider
|
||||
from api.services.telephony.providers.twilio_provider import TwilioProvider
|
||||
from api.services.telephony.providers.vobiz_provider import VobizProvider
|
||||
|
|
@ -75,6 +76,14 @@ async def load_telephony_config(organization_id: int) -> Dict[str, Any]:
|
|||
"domain_id": config.value.get("domain_id"),
|
||||
"from_numbers": config.value.get("from_numbers", []),
|
||||
}
|
||||
elif provider == "ari":
|
||||
return {
|
||||
"provider": "ari",
|
||||
"ari_endpoint": config.value.get("ari_endpoint"),
|
||||
"app_name": config.value.get("app_name"),
|
||||
"app_password": config.value.get("app_password"),
|
||||
"from_numbers": config.value.get("from_numbers", []),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown provider in config: {provider}")
|
||||
|
||||
|
|
@ -115,6 +124,9 @@ async def get_telephony_provider(organization_id: int) -> TelephonyProvider:
|
|||
elif provider_type == "cloudonix":
|
||||
return CloudonixProvider(config)
|
||||
|
||||
elif provider_type == "ari":
|
||||
return ARIProvider(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown telephony provider: {provider_type}")
|
||||
|
||||
|
|
@ -127,4 +139,4 @@ async def get_all_telephony_providers() -> List[Type[TelephonyProvider]]:
|
|||
Returns:
|
||||
List of provider classes that can be used for webhook detection
|
||||
"""
|
||||
return [CloudonixProvider, TwilioProvider, VobizProvider, VonageProvider]
|
||||
return [ARIProvider, CloudonixProvider, TwilioProvider, VobizProvider, VonageProvider]
|
||||
|
|
|
|||
416
api/services/telephony/providers/ari_provider.py
Normal file
416
api/services/telephony/providers/ari_provider.py
Normal file
|
|
@ -0,0 +1,416 @@
|
|||
"""
|
||||
Asterisk ARI (Asterisk REST Interface) implementation of the TelephonyProvider interface.
|
||||
|
||||
Uses ARI REST API to originate calls into a Stasis application.
|
||||
The ARI WebSocket event listener runs as a separate process (ari_manager.py).
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.telephony.base import (
|
||||
CallInitiationResult,
|
||||
NormalizedInboundData,
|
||||
TelephonyProvider,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
class ARIProvider(TelephonyProvider):
|
||||
"""
|
||||
Asterisk ARI implementation of TelephonyProvider.
|
||||
|
||||
Uses ARI REST API for call control and relies on a separate
|
||||
ari_manager process for WebSocket event listening.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = WorkflowRunMode.ARI.value
|
||||
WEBHOOK_ENDPOINT = None # ARI uses WebSocket events, not webhooks
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialize ARIProvider with configuration.
|
||||
|
||||
Args:
|
||||
config: Dictionary containing:
|
||||
- ari_endpoint: ARI base URL (e.g., http://asterisk:8088)
|
||||
- app_name: Stasis application name
|
||||
- app_password: ARI user password
|
||||
- from_numbers: List of SIP extensions/numbers (optional)
|
||||
"""
|
||||
self.ari_endpoint = config.get("ari_endpoint", "").rstrip("/")
|
||||
self.app_name = config.get("app_name", "")
|
||||
self.app_password = config.get("app_password", "")
|
||||
self.from_numbers = config.get("from_numbers", [])
|
||||
|
||||
if isinstance(self.from_numbers, str):
|
||||
self.from_numbers = [self.from_numbers]
|
||||
|
||||
self.base_url = f"{self.ari_endpoint}/ari"
|
||||
|
||||
def _get_auth(self) -> aiohttp.BasicAuth:
|
||||
"""Generate BasicAuth for ARI API requests."""
|
||||
return aiohttp.BasicAuth(self.app_name, self.app_password)
|
||||
|
||||
async def initiate_call(
|
||||
self,
|
||||
to_number: str,
|
||||
webhook_url: str,
|
||||
workflow_run_id: Optional[int] = None,
|
||||
from_number: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallInitiationResult:
|
||||
"""
|
||||
Initiate an outbound call via ARI.
|
||||
|
||||
Creates a channel in Asterisk using the ARI channels endpoint.
|
||||
The channel is placed into the Stasis application where
|
||||
the ari_manager will receive the StasisStart event.
|
||||
"""
|
||||
if not self.validate_config():
|
||||
raise ValueError("ARI provider not properly configured")
|
||||
|
||||
endpoint = f"{self.base_url}/channels"
|
||||
|
||||
# Build the SIP endpoint string
|
||||
# to_number can be a SIP URI or extension
|
||||
if to_number.startswith("SIP/") or to_number.startswith("PJSIP/"):
|
||||
sip_endpoint = to_number
|
||||
else:
|
||||
# Default to PJSIP technology
|
||||
sip_endpoint = f"PJSIP/{to_number}"
|
||||
|
||||
# Prepare channel creation data
|
||||
params = {
|
||||
"endpoint": sip_endpoint,
|
||||
"app": self.app_name,
|
||||
"appArgs": f"workflow_run_id={workflow_run_id}" if workflow_run_id else "",
|
||||
}
|
||||
|
||||
if from_number:
|
||||
params["callerId"] = from_number
|
||||
|
||||
# Add variables for tracking
|
||||
variables = {}
|
||||
if workflow_run_id:
|
||||
variables["WORKFLOW_RUN_ID"] = str(workflow_run_id)
|
||||
if kwargs.get("workflow_id"):
|
||||
variables["WORKFLOW_ID"] = str(kwargs["workflow_id"])
|
||||
if kwargs.get("user_id"):
|
||||
variables["USER_ID"] = str(kwargs["user_id"])
|
||||
|
||||
data = {}
|
||||
if variables:
|
||||
data["variables"] = variables
|
||||
|
||||
logger.info(
|
||||
f"[ARI] Initiating call to {sip_endpoint} "
|
||||
f"via app={self.app_name}, workflow_run_id={workflow_run_id}"
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
endpoint,
|
||||
params=params,
|
||||
json=data if data else None,
|
||||
auth=self._get_auth(),
|
||||
) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
if response.status != 200:
|
||||
logger.error(
|
||||
f"[ARI] Channel creation failed: "
|
||||
f"HTTP {response.status} - {response_text}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to create ARI channel: {response_text}",
|
||||
)
|
||||
|
||||
response_data = json.loads(response_text)
|
||||
channel_id = response_data.get("id", "")
|
||||
|
||||
logger.info(
|
||||
f"[ARI] Channel created: {channel_id} "
|
||||
f"state={response_data.get('state')}"
|
||||
)
|
||||
|
||||
return CallInitiationResult(
|
||||
call_id=channel_id,
|
||||
status=response_data.get("state", "created"),
|
||||
provider_metadata={
|
||||
"call_id": channel_id,
|
||||
"channel_name": response_data.get("name", ""),
|
||||
},
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
||||
async def get_call_status(self, call_id: str) -> Dict[str, Any]:
|
||||
"""Get channel status from ARI."""
|
||||
if not self.validate_config():
|
||||
raise ValueError("ARI provider not properly configured")
|
||||
|
||||
endpoint = f"{self.base_url}/channels/{call_id}"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(endpoint, auth=self._get_auth()) as response:
|
||||
if response.status != 200:
|
||||
error_data = await response.text()
|
||||
raise Exception(f"Failed to get channel status: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
async def get_available_phone_numbers(self) -> List[str]:
|
||||
"""Return configured extensions/numbers."""
|
||||
return self.from_numbers
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate ARI configuration."""
|
||||
return bool(self.ari_endpoint and self.app_name and self.app_password)
|
||||
|
||||
async def verify_webhook_signature(
|
||||
self, url: str, params: Dict[str, Any], signature: str
|
||||
) -> bool:
|
||||
"""ARI does not use webhook signatures - events come via WebSocket."""
|
||||
return True
|
||||
|
||||
async def get_webhook_response(
|
||||
self, workflow_id: int, user_id: int, workflow_run_id: int
|
||||
) -> str:
|
||||
"""ARI does not use webhook responses - call control is via REST API."""
|
||||
logger.warning(
|
||||
"get_webhook_response called for ARI - this should not happen. "
|
||||
"ARI uses REST API for call control, not webhooks."
|
||||
)
|
||||
return ""
|
||||
|
||||
async def get_call_cost(self, call_id: str) -> Dict[str, Any]:
|
||||
"""ARI/Asterisk does not provide call cost information."""
|
||||
return {
|
||||
"cost_usd": 0.0,
|
||||
"duration": 0,
|
||||
"status": "unknown",
|
||||
"error": "ARI does not support cost retrieval",
|
||||
}
|
||||
|
||||
def parse_status_callback(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse ARI event data into generic status callback format.
|
||||
|
||||
ARI events come from the WebSocket listener, not HTTP callbacks.
|
||||
"""
|
||||
# Map ARI channel states to common status format
|
||||
state_map = {
|
||||
"Up": "answered",
|
||||
"Down": "completed",
|
||||
"Ringing": "ringing",
|
||||
"Ring": "ringing",
|
||||
"Busy": "busy",
|
||||
"Unavailable": "failed",
|
||||
}
|
||||
|
||||
channel_state = data.get("channel", {}).get("state", "")
|
||||
event_type = data.get("type", "")
|
||||
|
||||
# Determine status from event type
|
||||
if event_type == "StasisStart":
|
||||
status = "answered"
|
||||
elif event_type == "StasisEnd":
|
||||
status = "completed"
|
||||
elif event_type == "ChannelDestroyed":
|
||||
status = "completed"
|
||||
else:
|
||||
status = state_map.get(channel_state, channel_state.lower())
|
||||
|
||||
channel = data.get("channel", {})
|
||||
return {
|
||||
"call_id": channel.get("id", ""),
|
||||
"status": status,
|
||||
"from_number": channel.get("caller", {}).get("number"),
|
||||
"to_number": channel.get("dialplan", {}).get("exten"),
|
||||
"direction": None,
|
||||
"duration": None,
|
||||
"extra": data,
|
||||
}
|
||||
|
||||
async def handle_websocket(
|
||||
self,
|
||||
websocket: "WebSocket",
|
||||
workflow_id: int,
|
||||
user_id: int,
|
||||
workflow_run_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
ARI WebSocket handling is done by the ari_manager process.
|
||||
This method is a placeholder for the base class requirement.
|
||||
|
||||
TODO: Implement pipeline integration when ready.
|
||||
"""
|
||||
logger.warning(
|
||||
f"handle_websocket called for ARI provider - "
|
||||
f"pipeline integration not yet implemented for workflow_run {workflow_run_id}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=4501, reason="ARI pipeline integration not yet implemented"
|
||||
)
|
||||
|
||||
# ======== INBOUND CALL METHODS ========
|
||||
|
||||
@classmethod
|
||||
def can_handle_webhook(
|
||||
cls, webhook_data: Dict[str, Any], headers: Dict[str, str]
|
||||
) -> bool:
|
||||
"""
|
||||
ARI does not use HTTP webhooks for inbound calls.
|
||||
Inbound calls are received via the ARI WebSocket event listener.
|
||||
"""
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def parse_inbound_webhook(webhook_data: Dict[str, Any]) -> NormalizedInboundData:
|
||||
"""Parse ARI event data into normalized inbound format."""
|
||||
channel = webhook_data.get("channel", {})
|
||||
caller = channel.get("caller", {})
|
||||
connected = channel.get("connected", {})
|
||||
|
||||
return NormalizedInboundData(
|
||||
provider=ARIProvider.PROVIDER_NAME,
|
||||
call_id=channel.get("id", ""),
|
||||
from_number=caller.get("number", ""),
|
||||
to_number=channel.get("dialplan", {}).get("exten", ""),
|
||||
direction="inbound",
|
||||
call_status=channel.get("state", ""),
|
||||
account_id=None,
|
||||
raw_data=webhook_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_account_id(config_data: dict, webhook_account_id: str) -> bool:
|
||||
"""ARI doesn't use account IDs for validation."""
|
||||
return True
|
||||
|
||||
def normalize_phone_number(self, phone_number: str) -> str:
|
||||
"""Normalize phone number - ARI uses extensions as-is."""
|
||||
return phone_number or ""
|
||||
|
||||
async def verify_inbound_signature(
|
||||
self, url: str, webhook_data: Dict[str, Any], signature: str
|
||||
) -> bool:
|
||||
"""ARI authenticates via WebSocket connection credentials, not signatures."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def generate_inbound_response(
|
||||
websocket_url: str, workflow_run_id: int = None
|
||||
) -> tuple:
|
||||
"""ARI does not generate HTTP responses for inbound calls."""
|
||||
from fastapi import Response
|
||||
|
||||
return Response(content="", status_code=204)
|
||||
|
||||
@staticmethod
|
||||
def generate_error_response(error_type: str, message: str) -> tuple:
|
||||
"""Generate a generic JSON error response."""
|
||||
from fastapi import Response
|
||||
|
||||
return Response(
|
||||
content=json.dumps({"error": error_type, "message": message}),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_validation_error_response(error_type) -> tuple:
|
||||
"""Generate JSON error response for validation failures."""
|
||||
from fastapi import Response
|
||||
from api.errors.telephony_errors import TELEPHONY_ERROR_MESSAGES, TelephonyError
|
||||
|
||||
message = TELEPHONY_ERROR_MESSAGES.get(
|
||||
error_type, TELEPHONY_ERROR_MESSAGES[TelephonyError.GENERAL_AUTH_FAILED]
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=json.dumps({"error": str(error_type), "message": message}),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# ======== CALL TRANSFER METHODS ========
|
||||
|
||||
def supports_transfers(self) -> bool:
|
||||
"""ARI does not currently support call transfers."""
|
||||
return False
|
||||
|
||||
async def transfer_call(
|
||||
self,
|
||||
destination: str,
|
||||
transfer_id: str,
|
||||
conference_name: str,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""ARI call transfers are not yet implemented."""
|
||||
raise NotImplementedError("ARI provider does not support call transfers")
|
||||
|
||||
# ======== ARI-SPECIFIC METHODS ========
|
||||
|
||||
async def hangup_channel(self, channel_id: str, reason: str = "normal") -> bool:
|
||||
"""Hang up an ARI channel."""
|
||||
endpoint = f"{self.base_url}/channels/{channel_id}"
|
||||
params = {"reason_code": reason}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
endpoint, params=params, auth=self._get_auth()
|
||||
) as response:
|
||||
if response.status in (200, 204):
|
||||
logger.info(f"[ARI] Channel {channel_id} hung up")
|
||||
return True
|
||||
else:
|
||||
error = await response.text()
|
||||
logger.error(
|
||||
f"[ARI] Failed to hangup channel {channel_id}: {error}"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[ARI] Exception hanging up channel {channel_id}: {e}")
|
||||
return False
|
||||
|
||||
async def answer_channel(self, channel_id: str) -> bool:
|
||||
"""Answer an ARI channel."""
|
||||
endpoint = f"{self.base_url}/channels/{channel_id}/answer"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
endpoint, auth=self._get_auth()
|
||||
) as response:
|
||||
if response.status in (200, 204):
|
||||
logger.info(f"[ARI] Channel {channel_id} answered")
|
||||
return True
|
||||
else:
|
||||
error = await response.text()
|
||||
logger.error(
|
||||
f"[ARI] Failed to answer channel {channel_id}: {error}"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[ARI] Exception answering channel {channel_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_ws_url(self) -> str:
|
||||
"""Get the ARI WebSocket URL for event listening."""
|
||||
parsed = urlparse(self.ari_endpoint)
|
||||
ws_scheme = "wss" if parsed.scheme == "https" else "ws"
|
||||
return (
|
||||
f"{ws_scheme}://{parsed.netloc}/ari/events"
|
||||
f"?api_key={self.app_name}:{self.app_password}"
|
||||
f"&app={self.app_name}"
|
||||
f"&subscribeAll=true"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue